From d4fb9b791a6cc6d3d2ad473fbb732fd33d095b69 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 29 Sep 2025 18:02:55 +0800 Subject: [PATCH 1/3] [Refactor] Enhance CopyNode Lower method to support disable_tma flag and improve flash attention implementation * Updated the CopyNode Lower method to correctly include the disable_tma flag in the GetCopyInst call. * Refactored the flash attention implementation to selectively disable TMA for specific copy operations while allowing it for others. * Addressed linting issues for improved code quality --- examples/deepseek_v32/fp8_mqa_logits.py | 410 +++++++++++++++++ examples/deepseek_v32/sparse_mla_fwd.py | 292 ++++++++++++ .../deepseek_v32/sparse_mla_fwd_pipelined.py | 418 ++++++++++++++++++ examples/deepseek_v32/utils.py | 174 ++++++++ 4 files changed, 1294 insertions(+) create mode 100644 examples/deepseek_v32/fp8_mqa_logits.py create mode 100644 examples/deepseek_v32/sparse_mla_fwd.py create mode 100644 examples/deepseek_v32/sparse_mla_fwd_pipelined.py create mode 100644 examples/deepseek_v32/utils.py diff --git a/examples/deepseek_v32/fp8_mqa_logits.py b/examples/deepseek_v32/fp8_mqa_logits.py new file mode 100644 index 000000000..6273f71a9 --- /dev/null +++ b/examples/deepseek_v32/fp8_mqa_logits.py @@ -0,0 +1,410 @@ +import itertools +import math +from einops import rearrange +import tilelang +from tilelang import language as T +import torch +from tilelang.autotuner import autotune +from tilelang import tvm +from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q + +from typing import Tuple + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + +def print_red_warning(message): + print(f"\033[31mWARNING: {message}\033[0m") + + +def calc_sim(x, y, name="tensor"): + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print_red_warning(f"{name} all zero") + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): + x_mask = torch.isfinite(x) + y_mask = torch.isfinite(y) + if not torch.all(x_mask == y_mask): + print_red_warning(f"{name} Error: isfinite mask mismatch") + if raise_assert: + assert False + if not torch.isclose( + x.masked_fill(x_mask, 0), + y.masked_fill(y_mask, 0), + rtol=0, + atol=0, + equal_nan=True, + ).all(): + print_red_warning(f"{name} Error: nonfinite value mismatch") + if raise_assert: + assert False + x = x.masked_fill(~x_mask, 0) + y = y.masked_fill(~y_mask, 0) + sim = calc_sim(x, y, name) + diff = 1.0 - sim + if not (0 <= diff <= eps): + print_red_warning(f"{name} Error: {diff}") + if raise_assert: + assert False + return diff + + +def get_configs(): + iter_params = dict( + block_N=[32, 64, 128], + num_stages=[0, 1, 2], + threads=[128, 256], + block_Q=[1, 2, 4], + ) + return [ + {k: v for k, v in zip(iter_params, values)} + for values in itertools.product(*iter_params.values()) + ] + + +class SupplyProg: + def __init__(self): + self.tensors_dict = {} + + def get_key(self, shape, dtype) -> str: + return f"{shape}-{dtype}" + + def supply_prog(self, params): + shapes = [p.shape for p in params] + dtypes = [p.dtype for p in params] + tensor_list = [] + for shape, dtype in zip(shapes, dtypes): + key = self.get_key(shape, dtype) + if key not in self.tensors_dict: + self.tensors_dict[key] = torch.randn(shape, dtype=dtype, device="cuda") + tensor_list.append(self.tensors_dict[key]) + else: + tensor_list.append(self.tensors_dict[key]) + return tensor_list + + +supply_prog = SupplyProg() + + +@tilelang.jit +def mqa_attn_return_logits( + heads, + index_dim, + block_N=256, + num_stages=3, + threads=512, + block_Q=None, +): + if block_Q is None: + block_Q = 128 // heads + dtype = "float8_e4m3" + accum_dtype = "float" + index_dtype = "int32" + + seq_len = tvm.te.var("seq_len") + seq_len_kv = tvm.te.var("seq_len_kv") + + index_q_shape = [seq_len * heads, index_dim] + index_k_shape = [seq_len_kv, index_dim] + index_k_scale_shape = [seq_len_kv] + logits_shape = [seq_len, seq_len_kv] + + @T.prim_func + def mqa_attn_return_logits_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: + + index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) + index_k_shared = T.alloc_shared([block_N, index_dim], dtype) + index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) + s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) + s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype) + logits = T.alloc_fragment([block_N, block_Q], accum_dtype) + weights = T.alloc_fragment([block_Q, heads], accum_dtype) + + seq_len_i = bx * block_Q + + cu_k_s_min = T.alloc_local([1], index_dtype) + cu_k_e_max = T.alloc_local([1], index_dtype) + + T.no_set_max_nreg() + + cu_k_s_min[0] = 2147483647 + cu_k_e_max[0] = -2147483648 + + for bq_i in T.serial(block_Q): + cu_k_s_min[0] = T.min( + cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv) + ) + for bq_i in T.serial(block_Q): + cu_k_e_max[0] = T.max( + cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv) + ) + + T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) + T.copy(Weights[seq_len_i, 0], weights) + + for nbn_i in T.Pipelined( + T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages + ): + T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) + T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) + + T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): + s_reshaped[bn_i, bq_i, h_i] = ( + T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i] + ) * index_k_scale_fragment[bn_i] + + T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + + for bq_i, bn_i in T.Parallel(block_Q, block_N): + Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = ( + logits[bn_i, bq_i] + ) + return mqa_attn_return_logits_kernel + + +@tilelang.jit +def clean_logits_( + threads: int = 512, + block_K: int = 4096, +): + seq_len = tvm.te.var("seq_len") + seq_len_kv = tvm.te.var("seq_len_kv") + + dtype = "float" + indices_dtype = "int32" + + @T.prim_func + def clean_logits_kernel( + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + ): + with T.Kernel(seq_len, threads=threads) as bx: + tx = T.thread_binding(0, threads, thread="threadIdx.x") + cu_k_s = T.alloc_local([1], indices_dtype) + cu_k_e = T.alloc_local([1], indices_dtype) + cu_k_s[0] = CuSeqLenKS[bx] + cu_k_e[0] = CuSeqLenKE[bx] + + for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): + for k_i in T.serial(block_K // threads): + idx = n_i * block_K + k_i * threads + tx + if idx < cu_k_s[0] or idx >= cu_k_e[0]: + Logits[bx, idx] = -T.infinity(dtype) + + return clean_logits_kernel + + +def mqa_attn_return_logits_interface( + q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True +): + seq_len, heads, index_dim = q.shape + seq_len_kv = kv.shape[0] + + clean_logits_kernel = clean_logits_() + + mqa_attn_return_logits_kernel = mqa_attn_return_logits( + heads=heads, index_dim=index_dim + ) + logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32) + mqa_attn_return_logits_kernel( + q.view(seq_len * heads, index_dim), + kv, + kv_scales, + logits, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + ) + if clean_logits: + clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke) + return logits + + +def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): + k = kv + q = q.float() + k = k.float() + + seq_len_kv = kv.shape[0] + mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] + mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + + score = torch.einsum('mhd,nd->hmn', q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float('-inf')) + + cost = mask.sum() + return logits, cost + +if __name__ == "__main__": + torch.manual_seed(0) + S, SKV, H, HKV, D, kv_stride = 4096, 8192, 32, 1, 64, 1 + q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to( + torch.bfloat16 + ) + kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to( + torch.bfloat16 + ) + weights = torch.randn(S, H, device="cuda", dtype=torch.float32) + p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) + + def generate_random_cu_seqlens( + per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512 + ): + total_seqlen = per_cp_seqlen * cp_size + + cu_seqlens = torch.randint( + 0, average_q_len * 2, (total_seqlen // average_q_len * 2,) + ).cuda() + last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0] + cu_seqlens = cu_seqlens[:last_seq_id] + + if cu_seqlens.sum() < total_seqlen: + cu_seqlens = torch.cat( + [cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()] + ) + + total_seqlen_k = (cu_seqlens // kv_stride).sum() + + cu_seqlens_cumsum = torch.cumsum(cu_seqlens, dim=0) + cu_seqlens_k_cumsum = torch.cumsum(cu_seqlens // kv_stride, dim=0) + cu_seqlens_qs = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_cumsum[:-1]]) + cu_seqlens_ks = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_k_cumsum[:-1]]) + cu_seqlens_qe = cu_seqlens_cumsum.clone() + cu_seqlens_ke = cu_seqlens_k_cumsum.clone() + + cu_seqlens_ks_for_each_q = cal_cu_seqlen_ks_for_q( + cu_seqlens_qs=cu_seqlens_qs, + cu_seqlens_qe=cu_seqlens_qe, + cu_seqlens_ks=cu_seqlens_ks, + seq_len=total_seqlen, + ) + cu_seqlens_ke_for_each_q = cal_cu_seqlen_ke_for_q( + cu_seqlens_qs=cu_seqlens_qs, + cu_seqlens_qe=cu_seqlens_qe, + cu_seqlens_ks=cu_seqlens_ks, + cu_seqlens_ke=cu_seqlens_ke, + q_start_idxs=torch.zeros_like(cu_seqlens_qs), + seq_len=total_seqlen, + kv_stride=kv_stride, + ) + + assert per_cp_seqlen % 2 == 0 + per_chunk_seqlen = per_cp_seqlen // 2 + slice_short = slice( + cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen + ) + slice_long = slice( + total_seqlen - (cp_rank + 1) * per_chunk_seqlen, + total_seqlen - cp_rank * per_chunk_seqlen, + ) + ks = torch.cat( + [ + cu_seqlens_ks_for_each_q[slice_short], + cu_seqlens_ks_for_each_q[slice_long], + ] + ) + ke = torch.cat( + [ + cu_seqlens_ke_for_each_q[slice_short], + cu_seqlens_ke_for_each_q[slice_long], + ] + ) + assert len(ks) == len(ke) == per_cp_seqlen + return ks, ke + + ks, ke = generate_random_cu_seqlens( + per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048 + ) + + logits_ref, cost_ref = ref_fp8_mqa_logits( + q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke + ) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0, ), False) + + logits_tl = mqa_attn_return_logits_interface( + q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke + ) + diff = assert_similar( + logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False + ) + + original_diff = None + for i in range(10): + logits_tl = mqa_attn_return_logits_interface( + q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke + ) + diff = assert_similar( + logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False + ) + if original_diff is None: + original_diff = diff + else: + assert original_diff == diff + + from tilelang.profiler import do_bench + + + def logits_fn(): + return mqa_attn_return_logits_interface( + q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke + ) + + with torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CUDA] + ) as prof: + logits_fn() + + print( + prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50) + ) + + logits_ms = do_bench(logits_fn, warmup=100, rep=100) + logits_flops = 2 * cost_ref * H * D + logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12 + print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}") + + print(f"cost_ref: {cost_ref}") + + torch.cuda.profiler.start() + logits_fn() + torch.cuda.profiler.stop() diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py new file mode 100644 index 000000000..b68547bd2 --- /dev/null +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -0,0 +1,292 @@ +import torch +import tilelang +from tilelang import language as T +from tilelang import tvm +from utils import print_red_warning, calc_sim, assert_similar + +@tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_attention_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=2, + threads=256, +): + assert dim == tilelang.math.next_power_of_2( + dim + ), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2( + tail_dim + ), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert ( + topk % block_I == 0 + ), "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + batch = T.symbolic("batch") + seq_len = T.symbolic("seq_len") + seq_len_kv = T.symbolic("seq_len_kv") + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = "int32" + dtype = "bfloat16" + accum_dtype = "float" + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert ( + kv_group == 1 + ), "here we solve the H padding automically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automically)" + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + O_shared = T.alloc_shared([H_per_block, D], dtype) + Lse_shared = T.alloc_shared([H_per_block], accum_dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_i, g_i = by, bz + s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + + for bi_i in T.Parallel(BI): + mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[ + b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i + ] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[ + b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i + ] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else( + mask[bi_i], 0, -T.infinity(acc_s.dtype) + ) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2( + acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale + ) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, O_shared) + T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) + T.copy(sumexp, Lse_shared) + T.copy(sumexp, Lse[b_i, s_i, H0:H1]) + + return main + + +def sparse_attention_fwd_interface( + q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512 +): + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + kernel = sparse_attention_fwd( + heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual + ) + out, lse = kernel(q, kv, indices) + return out, lse + + +def ref_sparse_attention_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): + q = q.float() + kv = kv.float() + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + num_kv_per_index = 1 + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view( + -1, 1 + ) >= torch.arange(1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) + + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter( + 3, indices.long(), 1 + ) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : 1 - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + return o.to(torch.float16) + +def test_sparse_attn_mla_fwd(): + B, S, SKV, H, HKV, DQK, DV, topk, dtype = ( + 1, + 4096, + 32768, + 128, + 1, + 576, + 512, + 2048, + torch.bfloat16, + ) + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_( + True + ) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + tl_out, tl_lse = sparse_attention_fwd_interface(q, kv, indices) + + def fn(): + return sparse_attention_fwd_interface(q, kv, indices) + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=100, + warmup=250, + ) + print(f"Average time: {ms:.3f} ms") + print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + + +if __name__ == "__main__": + test_sparse_attn_mla_fwd() diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py new file mode 100644 index 000000000..de823d4a2 --- /dev/null +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -0,0 +1,418 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +from tilelang import tvm + +from tilelang.engine.callback import register_cuda_postproc_callback +import argparse + +@tilelang.jit( + out_idx=[-2, -1], + compile_flags=[ + "-O3", + "-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG"], +) +def sparse_attention_fwd( + batch, + seq_len, + seq_len_kv, + heads, + dim, + tail_dim, + topk, + kv_stride, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=0, + threads=384, +): + ''' + This code implements sparse attn + Note that the first kv_stride - 1 token's out would be nan. since this isn't used, we assume it doesn't matter. (**still, one might have to handle carefully in backward to avoid dout * nan propagated!**) + It might be OK to set these nan to zero, but we assume it might serve as a reminder of taking care of these out in 'delta = out * dout'. + The above feature might be replaced with out being undefined if we fix CP0 logic (this logic is currently wrong due to some bug in compiler) + ''' + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, 'non-casual is not supported' + assert topk % block_I == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = "int32" + dtype = "bfloat16" + accum_dtype = "float" + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, 'here we solve the H padding automically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automically)' + BI = block_I + NI = tilelang.cdiv(topk, block_I) + assert NI % 2 == 0, 'NI should be a multiple of 2' + D = dim + D_tail = tail_dim + KV_stride = kv_stride + if head_kv > 64: + assert head_kv % 64 == 0, 'head_kv should be a multiple of 64' + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz): + Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) + Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype) + K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype) + K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype) + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + is_kv_valid = T.alloc_shared([BI], "bool", scope="shared") + + acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared") + alpha_local = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + indices_local = T.alloc_local([1], indices_dtype) + + # TODO: Multi buffer + bar_q = T.alloc_barrier(arrive_count=384) + bar_k_0_ready = T.alloc_barrier(arrive_count=128) + bar_k_1_ready = T.alloc_barrier(arrive_count=128) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + bar_k_1_free = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256) + bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) + + b_i, g_i = by, bz + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + q_i = q_start_index_s[0] + s_i + max_kv_i = (q_i + 1 - KV_stride) // KV_stride + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + tx = T.get_thread_binding() + + T.copy(Q[b_i, s_i, H0:H1, 0:D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2:D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + T.barrier_arrive(bar_q) + + if tx < 128: + T.set_max_nreg(240, 1) + T.fill(sumexp, 0) + T.fill(m_i, -2**30) # avoid -inf - inf to cause nan + T.fill(acc_o_l, 0) + T.barrier_wait(bar_q, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + + # Buffer 0 + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + if i_i != 0: + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_0_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_0_free[0]) + + + # Buffer 1 + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) + T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) + T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + + T.wait_wgmma(0) + + T.barrier_arrive(bar_sScale_and_sS_free) + T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1) + + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= alpha_local[h_i] + T.copy(alpha_local, alpha_shared) + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared_1_l, acc_o_l) + + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_arrive(bar_k_1_free[0]) + + # Rescale + for h_i in T.Parallel(H_per_block): + sum_exp_shared[h_i] = sumexp[h_i] + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0:D // 2]) + + elif tx >= 128 and tx < 256: + T.set_max_nreg(168, 1) + T.fill(acc_o_r, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1)) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_0_r, acc_o_r) + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_sScale_and_sS_free) + + # Buffer 1 + T.barrier_arrive(bar_sScale_and_sS_ready) + T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1)) + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= alpha_shared[h_i] + T.gemm(S_shared, KV_shared_1_r, acc_o_r) + T.barrier_arrive(bar_k_1_free[0]) + if i_i != T.ceildiv(NI, 2) - 1: + T.barrier_arrive(bar_sScale_and_sS_free) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2:D]) + elif tx >= 256: + # producer + T.set_max_nreg(80, 0) + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + if is_kv_valid[r * 16 + (tx - 256) // 8]: + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v] + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) + for r in T.serial(4): + indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i + if is_kv_valid[r * 16 + (tx - 256) // 8]: + with T.attr("default", "async_scope", 1): + for u in T.serial(4): + for v in T.vectorized(8): + KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v] + with T.attr("default", "async_scope", 1): + for v in T.vectorized(8): + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v] + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + return main + + +def sparse_attention_fwd_interface(q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False): + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, 'you should assign dim otherwise' + dim = 512 + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + if q_start_index_s != 0: + assert q_start_index_s > kv_stride, "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + CP0 = q_start_index_s == 0 + + kernel = sparse_attention_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) + if print_kernel: + print(kernel.get_kernel_source()) + out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + if return_kernel: + return kernel + if q_start_index_s == 0 and kv_stride > 1: + out[:, :kv_stride-1, :, :] = 0 + return out, lse + + +def ref_sparse_attention_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True): + q = q.float() + kv = kv.float() + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + if q_start_index_s is None: + q_start_index_s = sk * kv_stride - sq + + assert kv.shape[-1] == 576, 'you should assign dim otherwise' + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + num_kv_per_index = 1 + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(kv_stride-1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, :kv_stride - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q ** -0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + return o.to(torch.float16) + +def test_sparse_attn_mla_fwd(test_correctness=False): + KV_stride = 1 + if test_correctness: + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16 + q_start_s_index = 1024 + else: + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + q_start_s_index = 4096*64 + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True)/10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True)/10 + q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] + indices[b, t, h, :len(i_i)] = i_i + + kernel = sparse_attention_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + def fn(): + out, lse = kernel(q, kv, indices, q_start_s_index_t) + if q_start_s_index == 0 and kv_stride > 1: + out[:, :kv_stride-1, :, :] = 0 + return out, lse + + tl_out, tl_lse = fn() + if test_correctness: + ref_out = ref_sparse_attention_fwd_interface(q, kv, indices, q_start_s_index, KV_stride) + print(f"tl_out: {tl_out}") + print(f"ref_out: {ref_out}") + assert_similar(tl_out, ref_out) + + from tilelang.profiler import do_bench + ms = do_bench( + fn, + rep=10, + warmup=10, + ) + print(f"Average time: {ms:.3f} ms") + print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--test_correctness", action="store_true") + args = parser.parse_args() + test_sparse_attn_mla_fwd(args.test_correctness) diff --git a/examples/deepseek_v32/utils.py b/examples/deepseek_v32/utils.py new file mode 100644 index 000000000..2129fcb4a --- /dev/null +++ b/examples/deepseek_v32/utils.py @@ -0,0 +1,174 @@ +# -*- coding: utf-8 -*- +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +# -*- coding: utf-8 -*- + +import contextlib +import functools +import logging +import os +import sys +from enum import Enum +from functools import lru_cache +from typing import Any, Callable, Dict, Literal, Optional, Tuple + +from packaging import version + +def _is_equal(a, b): + if isinstance(a, torch.Tensor): + return a is b + # Whitelist of types that are safe to compare by value for caching. + if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))): + return a == b + # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check. + return False + +def tensor_cache( + fn: Callable[..., torch.Tensor] +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: Optional[Tuple] = None + last_kwargs: Optional[Dict] = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if last_args is not None and last_kwargs is not None: + if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): + # For Tensors, check for object identity. For other types, check for equality. + # Python caches small integers, so `is` works for them but not for large integers like 4096. + if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \ + set(kwargs.keys()) == set(last_kwargs.keys()) and \ + all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + +@tensor_cache +def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int): + seq_idx = cu_seqlens.new_zeros(seq_len+1) + seq_idx.scatter_add_(0, cu_seqlens[1:].long(), torch.ones_like(seq_idx)) + seq_idx.cumsum_(0) + return seq_idx[:-1] + +@tensor_cache +def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor: + seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device) + for i in range(len(cu_seqlens_qs)): + seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i + return seq_idx_for_q + +@tensor_cache +def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor: + cu_seqlen_ks_for_each_q = torch.gather(input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) + return cu_seqlen_ks_for_each_q.int() + +@tensor_cache +def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor, q_start_idxs: torch.LongTensor, seq_len: int, kv_stride: int) -> torch.IntTensor: + cu_seqlen_ke_for_each_q = torch.gather(input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) + casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device) + for i in range(len(cu_seqlens_qs)): + casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange(q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i] + cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) + return cu_seqlen_ke_for_each_q.int() + +@tensor_cache +def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, cu_seqlens_k: torch.LongTensor = None, offs_q: torch.LongTensor = None, *, seq_len: int, kv_stride:int=1, cp_rank:int=0, cp_size:int=1, balanced_cp=False): + ''' + seq_len: seq len per cp rank + balanced cp slice assignment: 0 1 2 3 3 2 1 0 + ''' + n_seq = len(cu_seqlens_q) - 1 + assert n_seq > 0 + assert cu_seqlens_q.shape == (n_seq + 1,) + seq_idx = cal_seq_idx_from_cu_seqlens(cu_seqlens_q.long(), seq_len*cp_size) + qs = cu_seqlens_q.gather(0, seq_idx) + pos = torch.arange(len(qs), dtype=qs.dtype, device=qs.device) - qs + if offs_q is not None: + assert offs_q.shape == (n_seq,), offs_q.shape + qoff = offs_q.gather(0, seq_idx) + pos += qoff + if cu_seqlens_k is None or cu_seqlens_k is cu_seqlens_q: + ks = qs + else: + assert cu_seqlens_k.shape == (n_seq + 1,) + ks = cu_seqlens_k.gather(0, seq_idx) + ke = ks + (pos + 1) // kv_stride + + if cp_size == 1: + pass + elif balanced_cp: + assert cp_size % 2 == 0, cp_size + def f(x: torch.Tensor): + chunks = x.chunk(cp_size*2) + return torch.cat([ + chunks[cp_rank], + chunks[cp_size-cp_rank-1], + ]) + ks = f(ks) + ke = f(ke) + else: + ks = ks.chunk(cp_size)[cp_rank] + ke = ke.chunk(cp_size)[cp_rank] + + return ks, ke + + +def print_red_warning(message): + print(f"\033[31mWARNING: {message}\033[0m") + + +def calc_sim(x, y, name="tensor"): + x, y = x.data.double(), y.data.double() + denominator = (x * x + y * y).sum() + if denominator == 0: + print_red_warning(f'{name} all zero') + return 1 + sim = 2 * (x * y).sum() / denominator + return sim + + +def assert_similar(x, y, eps=1e-8, name="tensor"): + sim = calc_sim(x, y, name) + diff = 1. - sim + if not (0 <= diff <= eps): + print_red_warning(f'{name} Error: {diff}') + assert False + + +if __name__ == "__main__": + seq_len = 32768 + cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") + last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] + cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) + cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) + cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) + + from tilelang.profiler import do_bench + + fn = lambda: cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len) + ms = do_bench(fn, warmup=25, rep=100) From 8f00a9e87b685b71f92daf16b82b07b742351a37 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 29 Sep 2025 18:07:48 +0800 Subject: [PATCH 2/3] sparse mla kernels --- examples/deepseek_v32/fp8_mqa_logits.py | 215 +++++++----------- examples/deepseek_v32/sparse_mla_fwd.py | 87 ++++--- .../deepseek_v32/sparse_mla_fwd_pipelined.py | 143 +++++++----- examples/deepseek_v32/utils.py | 91 ++++++-- 4 files changed, 277 insertions(+), 259 deletions(-) diff --git a/examples/deepseek_v32/fp8_mqa_logits.py b/examples/deepseek_v32/fp8_mqa_logits.py index 6273f71a9..60376d6cc 100644 --- a/examples/deepseek_v32/fp8_mqa_logits.py +++ b/examples/deepseek_v32/fp8_mqa_logits.py @@ -1,13 +1,9 @@ import itertools -import math -from einops import rearrange import tilelang from tilelang import language as T import torch -from tilelang.autotuner import autotune from tilelang import tvm -from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q - +from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar from typing import Tuple @@ -15,7 +11,9 @@ def ceil_to_ue8m0(x: torch.Tensor): assert x.view(-1).amax().item() > 0 return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) -def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], + use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) sf = x_amax / 448.0 @@ -23,47 +21,6 @@ def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bo x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) return x_scaled, sf.squeeze() -def print_red_warning(message): - print(f"\033[31mWARNING: {message}\033[0m") - - -def calc_sim(x, y, name="tensor"): - x, y = x.data.double(), y.data.double() - denominator = (x * x + y * y).sum() - if denominator == 0: - print_red_warning(f"{name} all zero") - return 1 - sim = 2 * (x * y).sum() / denominator - return sim - - -def assert_similar(x, y, eps=1e-8, name="tensor", raise_assert=True): - x_mask = torch.isfinite(x) - y_mask = torch.isfinite(y) - if not torch.all(x_mask == y_mask): - print_red_warning(f"{name} Error: isfinite mask mismatch") - if raise_assert: - assert False - if not torch.isclose( - x.masked_fill(x_mask, 0), - y.masked_fill(y_mask, 0), - rtol=0, - atol=0, - equal_nan=True, - ).all(): - print_red_warning(f"{name} Error: nonfinite value mismatch") - if raise_assert: - assert False - x = x.masked_fill(~x_mask, 0) - y = y.masked_fill(~y_mask, 0) - sim = calc_sim(x, y, name) - diff = 1.0 - sim - if not (0 <= diff <= eps): - print_red_warning(f"{name} Error: {diff}") - if raise_assert: - assert False - return diff - def get_configs(): iter_params = dict( @@ -72,13 +29,13 @@ def get_configs(): threads=[128, 256], block_Q=[1, 2, 4], ) - return [ - {k: v for k, v in zip(iter_params, values)} - for values in itertools.product(*iter_params.values()) - ] + return [{ + k: v for k, v in zip(iter_params, values) + } for values in itertools.product(*iter_params.values())] class SupplyProg: + def __init__(self): self.tensors_dict = {} @@ -127,13 +84,13 @@ def mqa_attn_return_logits( @T.prim_func def mqa_attn_return_logits_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore - IndexK: T.Tensor(index_k_shape, dtype), # type: ignore - IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore - Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore - Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore ): with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: @@ -156,20 +113,17 @@ def mqa_attn_return_logits_kernel( cu_k_e_max[0] = -2147483648 for bq_i in T.serial(block_Q): - cu_k_s_min[0] = T.min( - cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv) - ) + cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], + seq_len_kv)) for bq_i in T.serial(block_Q): - cu_k_e_max[0] = T.max( - cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv) - ) + cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], + seq_len_kv)) T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) T.copy(Weights[seq_len_i, 0], weights) for nbn_i in T.Pipelined( - T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages - ): + T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) @@ -183,16 +137,16 @@ def mqa_attn_return_logits_kernel( ) for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): - s_reshaped[bn_i, bq_i, h_i] = ( - T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i] - ) * index_k_scale_fragment[bn_i] + s_reshaped[bn_i, bq_i, + h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) * + weights[bq_i, h_i]) * index_k_scale_fragment[bn_i] T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) for bq_i, bn_i in T.Parallel(block_Q, block_N): Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = ( - logits[bn_i, bq_i] - ) + logits[bn_i, bq_i]) + return mqa_attn_return_logits_kernel @@ -209,9 +163,9 @@ def clean_logits_( @T.prim_func def clean_logits_kernel( - Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore ): with T.Kernel(seq_len, threads=threads) as bx: tx = T.thread_binding(0, threads, thread="threadIdx.x") @@ -229,17 +183,19 @@ def clean_logits_kernel( return clean_logits_kernel -def mqa_attn_return_logits_interface( - q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True -): +def mqa_attn_return_logits_interface(q, + kv, + kv_scales, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + clean_logits=True): seq_len, heads, index_dim = q.shape seq_len_kv = kv.shape[0] clean_logits_kernel = clean_logits_() - mqa_attn_return_logits_kernel = mqa_attn_return_logits( - heads=heads, index_dim=index_dim - ) + mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim) logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32) mqa_attn_return_logits_kernel( q.view(seq_len * heads, index_dim), @@ -273,33 +229,30 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, cost = mask.sum() return logits, cost + if __name__ == "__main__": torch.manual_seed(0) S, SKV, H, HKV, D, kv_stride = 4096, 8192, 32, 1, 64, 1 - q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to( - torch.bfloat16 - ) - kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to( - torch.bfloat16 - ) + q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) weights = torch.randn(S, H, device="cuda", dtype=torch.float32) p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) - def generate_random_cu_seqlens( - per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512 - ): + def generate_random_cu_seqlens(per_cp_seqlen, + cp_size=4, + cp_rank=3, + kv_stride=1, + average_q_len=512): total_seqlen = per_cp_seqlen * cp_size - cu_seqlens = torch.randint( - 0, average_q_len * 2, (total_seqlen // average_q_len * 2,) - ).cuda() + cu_seqlens = torch.randint(0, average_q_len * 2, + (total_seqlen // average_q_len * 2,)).cuda() last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0] cu_seqlens = cu_seqlens[:last_seq_id] if cu_seqlens.sum() < total_seqlen: cu_seqlens = torch.cat( - [cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()] - ) + [cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()]) total_seqlen_k = (cu_seqlens // kv_stride).sum() @@ -328,75 +281,65 @@ def generate_random_cu_seqlens( assert per_cp_seqlen % 2 == 0 per_chunk_seqlen = per_cp_seqlen // 2 - slice_short = slice( - cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen - ) + slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen) slice_long = slice( total_seqlen - (cp_rank + 1) * per_chunk_seqlen, total_seqlen - cp_rank * per_chunk_seqlen, ) - ks = torch.cat( - [ - cu_seqlens_ks_for_each_q[slice_short], - cu_seqlens_ks_for_each_q[slice_long], - ] - ) - ke = torch.cat( - [ - cu_seqlens_ke_for_each_q[slice_short], - cu_seqlens_ke_for_each_q[slice_long], - ] - ) + ks = torch.cat([ + cu_seqlens_ks_for_each_q[slice_short], + cu_seqlens_ks_for_each_q[slice_long], + ]) + ke = torch.cat([ + cu_seqlens_ke_for_each_q[slice_short], + cu_seqlens_ke_for_each_q[slice_long], + ]) assert len(ks) == len(ke) == per_cp_seqlen return ks, ke ks, ke = generate_random_cu_seqlens( - per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048 - ) + per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) logits_ref, cost_ref = ref_fp8_mqa_logits( - q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke - ) - + q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + q_fp8 = q.to(torch.float8_e4m3fn) - kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0, ), False) + kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) logits_tl = mqa_attn_return_logits_interface( - q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke - ) - diff = assert_similar( - logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False - ) + q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = assert_similar(logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False) original_diff = None for i in range(10): logits_tl = mqa_attn_return_logits_interface( - q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke - ) - diff = assert_similar( - logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False - ) + q=q_fp8, + kv=kv_fp8, + kv_scales=kv_scales, + weights=weights, + cu_seqlen_ks=ks, + cu_seqlen_ke=ke) + diff = assert_similar(logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False) if original_diff is None: original_diff = diff else: assert original_diff == diff - from tilelang.profiler import do_bench - + from tilelang.profiler import do_bench def logits_fn(): return mqa_attn_return_logits_interface( - q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke - ) - - with torch.profiler.profile( - activities=[torch.profiler.ProfilerActivity.CUDA] - ) as prof: + q=q_fp8, + kv=kv_fp8, + kv_scales=kv_scales, + weights=weights, + cu_seqlen_ks=ks, + cu_seqlen_ke=ke) + + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: logits_fn() - print( - prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50) - ) + print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)) logits_ms = do_bench(logits_fn, warmup=100, rep=100) logits_flops = 2 * cost_ref * H * D diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index b68547bd2..cd7659551 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -4,6 +4,7 @@ from tilelang import tvm from utils import print_red_warning, calc_sim, assert_similar + @tilelang.jit( out_idx=[-2, -1], pass_configs={ @@ -25,17 +26,14 @@ def sparse_attention_fwd( threads=256, ): assert dim == tilelang.math.next_power_of_2( - dim - ), f"haven't check padding correctness yet, dim={dim}" + dim), f"haven't check padding correctness yet, dim={dim}" assert tail_dim == tilelang.math.next_power_of_2( - tail_dim - ), f"haven't check padding correctness yet, dim={tail_dim}" + tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" assert is_causal == True, "non-casual is not supported" - assert ( - topk % block_I == 0 - ), "otherwise will load some index=0 thus causing wrong kv to be loaded" + assert (topk % + block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) else: sm_scale = sm_scale * 1.44269504 # log2(e) @@ -59,7 +57,7 @@ def sparse_attention_fwd( if padded_H != H: assert ( kv_group == 1 - ), "here we solve the H padding automically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automically)" + ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" BI = block_I NI = tilelang.cdiv(topk, block_I) D = dim @@ -75,17 +73,18 @@ def sparse_attention_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( - bx, - by, - bz, - ): + with T.Kernel( + seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): Q_shared = T.alloc_shared([H_per_block, D], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) KV_shared = T.alloc_shared([BI, D], dtype) @@ -124,18 +123,14 @@ def main( mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[ - b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i - ] + KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, + d_i] for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[ - b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i - ] + K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, + D + d_i] for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else( - mask[bi_i], 0, -T.infinity(acc_s.dtype) - ) + acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) T.gemm( Q_shared, KV_shared, @@ -155,9 +150,7 @@ def main( for h_i in T.Parallel(H_per_block): alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.exp2( - acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale - ) + acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? for h_i in T.Parallel(H_per_block): sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] @@ -181,9 +174,12 @@ def main( return main -def sparse_attention_fwd_interface( - q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512 -): +def sparse_attention_fwd_interface(q, + kv, + indices, + sm_scale=None, + return_p_sum: bool = False, + d_v=512): is_casual = True assert return_p_sum == False, "This kernel file is for fwd only" assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() @@ -199,9 +195,7 @@ def sparse_attention_fwd_interface( _, _, _, topk = indices.shape assert indices.shape == (batch, seq_len, kv_group, topk) - kernel = sparse_attention_fwd( - heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual - ) + kernel = sparse_attention_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual) out, lse = kernel(q, kv, indices) return out, lse @@ -222,16 +216,14 @@ def ref_sparse_attention_fwd_interface(q, kv, indices, sm_scale=None, is_casual= num_kv_per_index = 1 g_index = g h_index = h // g - compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32, device="cuda").view( - -1, 1 - ) >= torch.arange(1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange( + 0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( + 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) - mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter( - 3, indices.long(), 1 - ) + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, : 1 - 1, 0] = True + mask[:, :, :1 - 1, 0] = True mask = mask.view(b, g_index, 1, sq, sk) q = q.view(b, sq, g, -1, dim_q) @@ -245,6 +237,7 @@ def ref_sparse_attention_fwd_interface(q, kv, indices, sm_scale=None, is_casual= o = o.reshape(b, sq, h, dim_v) return o.to(torch.float16) + def test_sparse_attn_mla_fwd(): B, S, SKV, H, HKV, DQK, DV, topk, dtype = ( 1, @@ -260,16 +253,14 @@ def test_sparse_attn_mla_fwd(): torch.random.manual_seed(0) q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) - kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_( - True - ) + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") for b in range(B): for t in range(S): for h in range(HKV): i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, : len(i_i)] = i_i + indices[b, t, h, :len(i_i)] = i_i tl_out, tl_lse = sparse_attention_fwd_interface(q, kv, indices) diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index de823d4a2..cacad4afd 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -7,19 +7,15 @@ from tilelang.engine.callback import register_cuda_postproc_callback import argparse + @tilelang.jit( out_idx=[-2, -1], compile_flags=[ - "-O3", - "-Wno-deprecated-declarations", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", - "-DNDEBUG"], + "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", + "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" + ], ) def sparse_attention_fwd( batch, @@ -38,14 +34,10 @@ def sparse_attention_fwd( num_stages=0, threads=384, ): - ''' - This code implements sparse attn - Note that the first kv_stride - 1 token's out would be nan. since this isn't used, we assume it doesn't matter. (**still, one might have to handle carefully in backward to avoid dout * nan propagated!**) - It might be OK to set these nan to zero, but we assume it might serve as a reminder of taking care of these out in 'delta = out * dout'. - The above feature might be replaced with out being undefined if we fix CP0 logic (this logic is currently wrong due to some bug in compiler) - ''' - assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert dim == tilelang.math.next_power_of_2( + dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2( + tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" assert is_causal == True, 'non-casual is not supported' assert topk % block_I == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' if sm_scale is None: @@ -67,7 +59,7 @@ def sparse_attention_fwd( H = head_kv padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) if padded_H != H: - assert kv_group == 1, 'here we solve the H padding automically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automically)' + assert kv_group == 1, 'here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)' BI = block_I NI = tilelang.cdiv(topk, block_I) assert NI % 2 == 0, 'NI should be a multiple of 2' @@ -84,14 +76,18 @@ def sparse_attention_fwd( @T.prim_func def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - q_start_index_s: T.Tensor(1, indices_dtype), - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore ): - with T.Kernel((seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, batch, kv_group, threads=threads) as (bx, by, bz): + with T.Kernel( + (seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, + batch, + kv_group, + threads=threads) as (bx, by, bz): Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) @@ -128,7 +124,8 @@ def main( bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) b_i, g_i = by, bz - s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else ( + bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) q_i = q_start_index_s[0] + s_i max_kv_i = (q_i + 1 - KV_stride) // KV_stride @@ -155,13 +152,14 @@ def main( T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, + -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) T.wait_wgmma(0) - + if i_i != 0: T.barrier_arrive(bar_sScale_and_sS_free) T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) @@ -185,12 +183,12 @@ def main( T.barrier_arrive(bar_sScale_and_sS_ready) T.barrier_arrive(bar_k_0_free[0]) - # Buffer 1 T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) + acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, + -T.infinity(acc_s.dtype)) T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) @@ -265,39 +263,65 @@ def main( # Buffer 0 T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] + indices_local[0] = Indices[b_i, s_i, g_i, + (i_i * 2) * BI + r * 16 + (tx - 256) // 8] is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_l[r * 16 + (tx - 256) // 8, + 64 * u + (tx - 256) % 8 * 8 + + v] = KV[b_i, indices_local[0], g_i, + 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_0_r[r * 16 + (tx - 256) // 8, + 64 * u + (tx - 256) % 8 * 8 + + v] = KV[b_i, indices_local[0], g_i, D // 2 + + 64 * u + (tx - 256) % 8 * 8 + v] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v] + K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + + v] = KV[b_i, indices_local[0], g_i, + D + (tx - 256) % 8 * 8 + v] T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + indices_local[0] = Indices[b_i, s_i, g_i, + (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: with T.attr("default", "async_scope", 1): for u in T.serial(4): for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_l[r * 16 + (tx - 256) // 8, + 64 * u + (tx - 256) % 8 * 8 + + v] = KV[b_i, indices_local[0], g_i, + 64 * u + (tx - 256) % 8 * 8 + v] + KV_shared_1_r[r * 16 + (tx - 256) // 8, + 64 * u + (tx - 256) % 8 * 8 + + v] = KV[b_i, indices_local[0], g_i, D // 2 + + 64 * u + (tx - 256) % 8 * 8 + v] with T.attr("default", "async_scope", 1): for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[b_i, indices_local[0], g_i, D + (tx - 256) % 8 * 8 + v] + K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + + v] = KV[b_i, indices_local[0], g_i, + D + (tx - 256) % 8 * 8 + v] T.cp_async_barrier_noinc(bar_k_1_ready[0]) return main -def sparse_attention_fwd_interface(q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False): +def sparse_attention_fwd_interface(q, + kv, + indices, + q_start_index_s, + kv_stride, + sm_scale=None, + is_casual=True, + return_kernel=False, + print_kernel=False): assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() batch, seq_len, heads, dim_plus_tail_dim = q.shape _, seq_len_kv, kv_group, _ = kv.shape @@ -315,18 +339,26 @@ def sparse_attention_fwd_interface(q, kv, indices, q_start_index_s, kv_stride, s assert q_start_index_s > kv_stride, "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" CP0 = q_start_index_s == 0 - kernel = sparse_attention_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) + kernel = sparse_attention_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, + kv_group, sm_scale, is_casual, CP0) if print_kernel: print(kernel.get_kernel_source()) - out, lse = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + out, lse = kernel(q, kv, indices, + torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) if return_kernel: return kernel if q_start_index_s == 0 and kv_stride > 1: - out[:, :kv_stride-1, :, :] = 0 + out[:, :kv_stride - 1, :, :] = 0 return out, lse -def ref_sparse_attention_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=4, sm_scale=None, is_casual=True): +def ref_sparse_attention_fwd_interface(q, + kv, + indices, + q_start_index_s, + kv_stride=4, + sm_scale=None, + is_casual=True): q = q.float() kv = kv.float() indices = indices.transpose(1, 2) @@ -344,7 +376,10 @@ def ref_sparse_attention_fwd_interface(q, kv, indices, q_start_index_s, kv_strid num_kv_per_index = 1 g_index = g h_index = h // g - compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange(kv_stride-1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + compressed_casual_mask = torch.arange( + q_start_index_s, sq + q_start_index_s, dtype=torch.int32, + device="cuda").view(-1, 1) >= torch.arange( + kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) mask = mask[..., :-1] @@ -354,7 +389,7 @@ def ref_sparse_attention_fwd_interface(q, kv, indices, q_start_index_s, kv_strid q = q.view(b, sq, g, -1, dim_q) score = torch.einsum("bmghd,bngd->bghmn", q, k) - sm_scale = dim_q ** -0.5 if sm_scale is None else sm_scale + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) p = score.softmax(dim=-1) p = p.view(b, g_index, h_index, -1, sq, sk) @@ -363,6 +398,7 @@ def ref_sparse_attention_fwd_interface(q, kv, indices, q_start_index_s, kv_strid o = o.reshape(b, sq, h, dim_v) return o.to(torch.float16) + def test_sparse_attn_mla_fwd(test_correctness=False): KV_stride = 1 if test_correctness: @@ -370,11 +406,11 @@ def test_sparse_attn_mla_fwd(test_correctness=False): q_start_s_index = 1024 else: B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 - q_start_s_index = 4096*64 + q_start_s_index = 4096 * 64 torch.random.manual_seed(0) - q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True)/10 - kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True)/10 + q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") q.clamp_(-10, 10) @@ -387,11 +423,13 @@ def test_sparse_attn_mla_fwd(test_correctness=False): i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] indices[b, t, h, :len(i_i)] = i_i - kernel = sparse_attention_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) - def fn(): + kernel = sparse_attention_fwd_interface( + q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + + def fn(): out, lse = kernel(q, kv, indices, q_start_s_index_t) if q_start_s_index == 0 and kv_stride > 1: - out[:, :kv_stride-1, :, :] = 0 + out[:, :kv_stride - 1, :, :] = 0 return out, lse tl_out, tl_lse = fn() @@ -411,6 +449,7 @@ def fn(): print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--test_correctness", action="store_true") diff --git a/examples/deepseek_v32/utils.py b/examples/deepseek_v32/utils.py index 2129fcb4a..201e13405 100644 --- a/examples/deepseek_v32/utils.py +++ b/examples/deepseek_v32/utils.py @@ -1,11 +1,7 @@ -# -*- coding: utf-8 -*- -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang - import torch import torch.nn.functional as F import triton import triton.language as tl -# -*- coding: utf-8 -*- import contextlib import functools @@ -18,18 +14,19 @@ from packaging import version + def _is_equal(a, b): if isinstance(a, torch.Tensor): return a is b # Whitelist of types that are safe to compare by value for caching. - if isinstance(a, (int, float, str, bool, type(None))) and isinstance(b, (int, float, str, bool, type(None))): + if isinstance(a, (int, float, str, bool, type(None))) and isinstance( + b, (int, float, str, bool, type(None))): return a == b # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check. return False -def tensor_cache( - fn: Callable[..., torch.Tensor] -) -> Callable[..., torch.Tensor]: + +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: """ A decorator that caches the most recent result of a function with tensor inputs. @@ -68,36 +65,79 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper + @tensor_cache def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int): - seq_idx = cu_seqlens.new_zeros(seq_len+1) + seq_idx = cu_seqlens.new_zeros(seq_len + 1) seq_idx.scatter_add_(0, cu_seqlens[1:].long(), torch.ones_like(seq_idx)) seq_idx.cumsum_(0) return seq_idx[:-1] + @tensor_cache -def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor: - seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device) +def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, + seq_len: int) -> torch.IntTensor: + seq_idx_for_q = torch.full((seq_len,), + len(cu_seqlens_qs), + dtype=torch.int32, + device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i return seq_idx_for_q + @tensor_cache -def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor: - cu_seqlen_ks_for_each_q = torch.gather(input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) +def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, + cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor: + cu_seqlen_ks_for_each_q = torch.gather( + input=torch.cat([ + cu_seqlens_ks, + torch.full((1,), + torch.iinfo(torch.int32).max, + dtype=torch.int32, + device=cu_seqlens_qs.device) + ]), + dim=0, + index=cal_seq_idx_for_q( + cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) return cu_seqlen_ks_for_each_q.int() + @tensor_cache -def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor, q_start_idxs: torch.LongTensor, seq_len: int, kv_stride: int) -> torch.IntTensor: - cu_seqlen_ke_for_each_q = torch.gather(input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), dim=0, index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) - casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device) +def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, + cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor, + q_start_idxs: torch.LongTensor, seq_len: int, + kv_stride: int) -> torch.IntTensor: + cu_seqlen_ke_for_each_q = torch.gather( + input=torch.cat( + [cu_seqlens_ke, + torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), + dim=0, + index=cal_seq_idx_for_q( + cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) + casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), + dtype=torch.int32, + device=cu_seqlens_qs.device) for i in range(len(cu_seqlens_qs)): - casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange(q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i] + casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange( + q_start_idxs[i], + q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], + dtype=torch.int32, + device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i] cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) return cu_seqlen_ke_for_each_q.int() + @tensor_cache -def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, cu_seqlens_k: torch.LongTensor = None, offs_q: torch.LongTensor = None, *, seq_len: int, kv_stride:int=1, cp_rank:int=0, cp_size:int=1, balanced_cp=False): +def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, + cu_seqlens_k: torch.LongTensor = None, + offs_q: torch.LongTensor = None, + *, + seq_len: int, + kv_stride: int = 1, + cp_rank: int = 0, + cp_size: int = 1, + balanced_cp=False): ''' seq_len: seq len per cp rank balanced cp slice assignment: 0 1 2 3 3 2 1 0 @@ -105,7 +145,7 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, cu_seqlens_k: to n_seq = len(cu_seqlens_q) - 1 assert n_seq > 0 assert cu_seqlens_q.shape == (n_seq + 1,) - seq_idx = cal_seq_idx_from_cu_seqlens(cu_seqlens_q.long(), seq_len*cp_size) + seq_idx = cal_seq_idx_from_cu_seqlens(cu_seqlens_q.long(), seq_len * cp_size) qs = cu_seqlens_q.gather(0, seq_idx) pos = torch.arange(len(qs), dtype=qs.dtype, device=qs.device) - qs if offs_q is not None: @@ -123,12 +163,14 @@ def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, cu_seqlens_k: to pass elif balanced_cp: assert cp_size % 2 == 0, cp_size + def f(x: torch.Tensor): - chunks = x.chunk(cp_size*2) + chunks = x.chunk(cp_size * 2) return torch.cat([ chunks[cp_rank], - chunks[cp_size-cp_rank-1], + chunks[cp_size - cp_rank - 1], ]) + ks = f(ks) ke = f(ke) else: @@ -165,8 +207,11 @@ def assert_similar(x, y, eps=1e-8, name="tensor"): cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) - cu_seqlens_qs = torch.cat([torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) - cu_seqlens_qe = torch.cat([cu_seqlens_cumsum, torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) + cu_seqlens_qs = torch.cat( + [torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) + cu_seqlens_qe = torch.cat( + [cu_seqlens_cumsum, + torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) from tilelang.profiler import do_bench From ff9e5986becfdfa42f07a942af4d91f83163a72a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 29 Sep 2025 18:14:22 +0800 Subject: [PATCH 3/3] Remove deprecated sparse MLA and utility files to streamline the codebase. --- examples/deepseek_v32/README.md | 1 + examples/deepseek_v32/fp8_mqa_logits.py | 353 -------------- examples/deepseek_v32/sparse_mla_fwd.py | 283 ----------- .../deepseek_v32/sparse_mla_fwd_pipelined.py | 457 ------------------ examples/deepseek_v32/utils.py | 219 --------- 5 files changed, 1 insertion(+), 1312 deletions(-) create mode 100644 examples/deepseek_v32/README.md delete mode 100644 examples/deepseek_v32/fp8_mqa_logits.py delete mode 100644 examples/deepseek_v32/sparse_mla_fwd.py delete mode 100644 examples/deepseek_v32/sparse_mla_fwd_pipelined.py delete mode 100644 examples/deepseek_v32/utils.py diff --git a/examples/deepseek_v32/README.md b/examples/deepseek_v32/README.md new file mode 100644 index 000000000..cbbbc981f --- /dev/null +++ b/examples/deepseek_v32/README.md @@ -0,0 +1 @@ +Comming Soon. diff --git a/examples/deepseek_v32/fp8_mqa_logits.py b/examples/deepseek_v32/fp8_mqa_logits.py deleted file mode 100644 index 60376d6cc..000000000 --- a/examples/deepseek_v32/fp8_mqa_logits.py +++ /dev/null @@ -1,353 +0,0 @@ -import itertools -import tilelang -from tilelang import language as T -import torch -from tilelang import tvm -from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar -from typing import Tuple - - -def ceil_to_ue8m0(x: torch.Tensor): - assert x.view(-1).amax().item() > 0 - return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) - - -def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], - use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: - excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) - x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) - sf = x_amax / 448.0 - sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf - x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) - return x_scaled, sf.squeeze() - - -def get_configs(): - iter_params = dict( - block_N=[32, 64, 128], - num_stages=[0, 1, 2], - threads=[128, 256], - block_Q=[1, 2, 4], - ) - return [{ - k: v for k, v in zip(iter_params, values) - } for values in itertools.product(*iter_params.values())] - - -class SupplyProg: - - def __init__(self): - self.tensors_dict = {} - - def get_key(self, shape, dtype) -> str: - return f"{shape}-{dtype}" - - def supply_prog(self, params): - shapes = [p.shape for p in params] - dtypes = [p.dtype for p in params] - tensor_list = [] - for shape, dtype in zip(shapes, dtypes): - key = self.get_key(shape, dtype) - if key not in self.tensors_dict: - self.tensors_dict[key] = torch.randn(shape, dtype=dtype, device="cuda") - tensor_list.append(self.tensors_dict[key]) - else: - tensor_list.append(self.tensors_dict[key]) - return tensor_list - - -supply_prog = SupplyProg() - - -@tilelang.jit -def mqa_attn_return_logits( - heads, - index_dim, - block_N=256, - num_stages=3, - threads=512, - block_Q=None, -): - if block_Q is None: - block_Q = 128 // heads - dtype = "float8_e4m3" - accum_dtype = "float" - index_dtype = "int32" - - seq_len = tvm.te.var("seq_len") - seq_len_kv = tvm.te.var("seq_len_kv") - - index_q_shape = [seq_len * heads, index_dim] - index_k_shape = [seq_len_kv, index_dim] - index_k_scale_shape = [seq_len_kv] - logits_shape = [seq_len, seq_len_kv] - - @T.prim_func - def mqa_attn_return_logits_kernel( - IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore - IndexK: T.Tensor(index_k_shape, dtype), # type: ignore - IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype), # type: ignore - Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore - Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore - ): - with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: - - index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) - index_k_shared = T.alloc_shared([block_N, index_dim], dtype) - index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) - s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) - s_reshaped = T.alloc_fragment([block_N, block_Q, heads], accum_dtype) - logits = T.alloc_fragment([block_N, block_Q], accum_dtype) - weights = T.alloc_fragment([block_Q, heads], accum_dtype) - - seq_len_i = bx * block_Q - - cu_k_s_min = T.alloc_local([1], index_dtype) - cu_k_e_max = T.alloc_local([1], index_dtype) - - T.no_set_max_nreg() - - cu_k_s_min[0] = 2147483647 - cu_k_e_max[0] = -2147483648 - - for bq_i in T.serial(block_Q): - cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], - seq_len_kv)) - for bq_i in T.serial(block_Q): - cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], - seq_len_kv)) - - T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) - T.copy(Weights[seq_len_i, 0], weights) - - for nbn_i in T.Pipelined( - T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages): - T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared) - T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment) - - T.gemm( - index_k_shared, - index_q_shared, - s, - transpose_B=True, - clear_accum=True, - policy=T.GemmWarpPolicy.FullCol, - ) - - for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): - s_reshaped[bn_i, bq_i, - h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) * - weights[bq_i, h_i]) * index_k_scale_fragment[bn_i] - - T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) - - for bq_i, bn_i in T.Parallel(block_Q, block_N): - Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = ( - logits[bn_i, bq_i]) - - return mqa_attn_return_logits_kernel - - -@tilelang.jit -def clean_logits_( - threads: int = 512, - block_K: int = 4096, -): - seq_len = tvm.te.var("seq_len") - seq_len_kv = tvm.te.var("seq_len_kv") - - dtype = "float" - indices_dtype = "int32" - - @T.prim_func - def clean_logits_kernel( - Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore - CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore - CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore - ): - with T.Kernel(seq_len, threads=threads) as bx: - tx = T.thread_binding(0, threads, thread="threadIdx.x") - cu_k_s = T.alloc_local([1], indices_dtype) - cu_k_e = T.alloc_local([1], indices_dtype) - cu_k_s[0] = CuSeqLenKS[bx] - cu_k_e[0] = CuSeqLenKE[bx] - - for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): - for k_i in T.serial(block_K // threads): - idx = n_i * block_K + k_i * threads + tx - if idx < cu_k_s[0] or idx >= cu_k_e[0]: - Logits[bx, idx] = -T.infinity(dtype) - - return clean_logits_kernel - - -def mqa_attn_return_logits_interface(q, - kv, - kv_scales, - weights, - cu_seqlen_ks, - cu_seqlen_ke, - clean_logits=True): - seq_len, heads, index_dim = q.shape - seq_len_kv = kv.shape[0] - - clean_logits_kernel = clean_logits_() - - mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim) - logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32) - mqa_attn_return_logits_kernel( - q.view(seq_len * heads, index_dim), - kv, - kv_scales, - logits, - weights, - cu_seqlen_ks, - cu_seqlen_ke, - ) - if clean_logits: - clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke) - return logits - - -def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor, - cu_seqlen_ks: torch.Tensor, cu_seqlen_ke: torch.Tensor): - k = kv - q = q.float() - k = k.float() - - seq_len_kv = kv.shape[0] - mask_lo = torch.arange(0, seq_len_kv, device='cuda')[None, :] >= cu_seqlen_ks[:, None] - mask_hi = torch.arange(0, seq_len_kv, device='cuda')[None, :] < cu_seqlen_ke[:, None] - mask = mask_lo & mask_hi - - score = torch.einsum('mhd,nd->hmn', q, k) - logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) - logits = logits.masked_fill(~mask, float('-inf')) - - cost = mask.sum() - return logits, cost - - -if __name__ == "__main__": - torch.manual_seed(0) - S, SKV, H, HKV, D, kv_stride = 4096, 8192, 32, 1, 64, 1 - q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) - kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) - weights = torch.randn(S, H, device="cuda", dtype=torch.float32) - p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) - - def generate_random_cu_seqlens(per_cp_seqlen, - cp_size=4, - cp_rank=3, - kv_stride=1, - average_q_len=512): - total_seqlen = per_cp_seqlen * cp_size - - cu_seqlens = torch.randint(0, average_q_len * 2, - (total_seqlen // average_q_len * 2,)).cuda() - last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0] - cu_seqlens = cu_seqlens[:last_seq_id] - - if cu_seqlens.sum() < total_seqlen: - cu_seqlens = torch.cat( - [cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()]) - - total_seqlen_k = (cu_seqlens // kv_stride).sum() - - cu_seqlens_cumsum = torch.cumsum(cu_seqlens, dim=0) - cu_seqlens_k_cumsum = torch.cumsum(cu_seqlens // kv_stride, dim=0) - cu_seqlens_qs = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_cumsum[:-1]]) - cu_seqlens_ks = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_k_cumsum[:-1]]) - cu_seqlens_qe = cu_seqlens_cumsum.clone() - cu_seqlens_ke = cu_seqlens_k_cumsum.clone() - - cu_seqlens_ks_for_each_q = cal_cu_seqlen_ks_for_q( - cu_seqlens_qs=cu_seqlens_qs, - cu_seqlens_qe=cu_seqlens_qe, - cu_seqlens_ks=cu_seqlens_ks, - seq_len=total_seqlen, - ) - cu_seqlens_ke_for_each_q = cal_cu_seqlen_ke_for_q( - cu_seqlens_qs=cu_seqlens_qs, - cu_seqlens_qe=cu_seqlens_qe, - cu_seqlens_ks=cu_seqlens_ks, - cu_seqlens_ke=cu_seqlens_ke, - q_start_idxs=torch.zeros_like(cu_seqlens_qs), - seq_len=total_seqlen, - kv_stride=kv_stride, - ) - - assert per_cp_seqlen % 2 == 0 - per_chunk_seqlen = per_cp_seqlen // 2 - slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen) - slice_long = slice( - total_seqlen - (cp_rank + 1) * per_chunk_seqlen, - total_seqlen - cp_rank * per_chunk_seqlen, - ) - ks = torch.cat([ - cu_seqlens_ks_for_each_q[slice_short], - cu_seqlens_ks_for_each_q[slice_long], - ]) - ke = torch.cat([ - cu_seqlens_ke_for_each_q[slice_short], - cu_seqlens_ke_for_each_q[slice_long], - ]) - assert len(ks) == len(ke) == per_cp_seqlen - return ks, ke - - ks, ke = generate_random_cu_seqlens( - per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) - - logits_ref, cost_ref = ref_fp8_mqa_logits( - q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - - q_fp8 = q.to(torch.float8_e4m3fn) - kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) - - logits_tl = mqa_attn_return_logits_interface( - q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - diff = assert_similar(logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False) - - original_diff = None - for i in range(10): - logits_tl = mqa_attn_return_logits_interface( - q=q_fp8, - kv=kv_fp8, - kv_scales=kv_scales, - weights=weights, - cu_seqlen_ks=ks, - cu_seqlen_ke=ke) - diff = assert_similar(logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False) - if original_diff is None: - original_diff = diff - else: - assert original_diff == diff - - from tilelang.profiler import do_bench - - def logits_fn(): - return mqa_attn_return_logits_interface( - q=q_fp8, - kv=kv_fp8, - kv_scales=kv_scales, - weights=weights, - cu_seqlen_ks=ks, - cu_seqlen_ke=ke) - - with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: - logits_fn() - - print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)) - - logits_ms = do_bench(logits_fn, warmup=100, rep=100) - logits_flops = 2 * cost_ref * H * D - logits_tflops = logits_flops / (logits_ms * 1e-3) / 1e12 - print(f"logits_tflops: {logits_tflops}, logits_ms: {logits_ms}") - - print(f"cost_ref: {cost_ref}") - - torch.cuda.profiler.start() - logits_fn() - torch.cuda.profiler.stop() diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py deleted file mode 100644 index cd7659551..000000000 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ /dev/null @@ -1,283 +0,0 @@ -import torch -import tilelang -from tilelang import language as T -from tilelang import tvm -from utils import print_red_warning, calc_sim, assert_similar - - -@tilelang.jit( - out_idx=[-2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, -) -def sparse_attention_fwd( - heads, - dim, - tail_dim, - topk, - kv_group=1, - sm_scale=None, - is_causal=True, - CP0=True, - block_I=64, - num_stages=2, - threads=256, -): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" - assert is_causal == True, "non-casual is not supported" - assert (topk % - block_I == 0), "otherwise will load some index=0 thus causing wrong kv to be loaded" - if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) - else: - sm_scale = sm_scale * 1.44269504 # log2(e) - - batch = T.symbolic("batch") - seq_len = T.symbolic("seq_len") - seq_len_kv = T.symbolic("seq_len_kv") - - head_kv = heads // kv_group - q_shape = [batch, seq_len, heads, dim + tail_dim] - kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] - o_shape = [batch, seq_len, heads, dim] - indices_shape = [batch, seq_len, kv_group, topk] - lse_shape = [batch, seq_len, heads] - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" - - G = kv_group - H = head_kv - padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) - if padded_H != H: - assert ( - kv_group == 1 - ), "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" - BI = block_I - NI = tilelang.cdiv(topk, block_I) - D = dim - D_tail = tail_dim - - if head_kv > 64: - assert head_kv % 64 == 0, "head_kv should be a multiple of 64" - REPLICATE_H = head_kv // 64 - else: - REPLICATE_H = 1 - - H_per_block = padded_H if REPLICATE_H == 1 else 64 - - @T.prim_func - def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore - ): - with T.Kernel( - seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( - bx, - by, - bz, - ): - Q_shared = T.alloc_shared([H_per_block, D], dtype) - Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) - KV_shared = T.alloc_shared([BI, D], dtype) - K_tail_shared = T.alloc_shared([BI, D_tail], dtype) - O_shared = T.alloc_shared([H_per_block, D], dtype) - Lse_shared = T.alloc_shared([H_per_block], accum_dtype) - mask = T.alloc_fragment([BI], "bool") - - acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) - acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) - S_shared = T.alloc_shared([H_per_block, BI], dtype) - sumexp = T.alloc_fragment([H_per_block], accum_dtype) - sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) - alpha = T.alloc_fragment([H_per_block], accum_dtype) - m_i = T.alloc_fragment([H_per_block], accum_dtype) - m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) - - T.fill(acc_o, 0) - T.fill(sumexp, 0) - T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan - - b_i, g_i = by, bz - s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) - q_i = s_i - max_kv_i = q_i - - H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) - H1 = H0 + H_per_block - - T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared) - T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) - - for i_i in T.Pipelined(NI, num_stages=num_stages): - - for bi_i in T.Parallel(BI): - mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] <= max_kv_i - - for bi_i, d_i in T.Parallel(BI, D): - KV_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - d_i] - for bi_i, d_i in T.Parallel(BI, D_tail): - K_tail_shared[bi_i, d_i] = KV[b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, - D + d_i] - - for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(mask[bi_i], 0, -T.infinity(acc_s.dtype)) - T.gemm( - Q_shared, - KV_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - ) - T.gemm( - Q_tail_shared, - K_tail_shared, - acc_s, - transpose_B=True, - policy=T.GemmWarpPolicy.FullCol, - ) - T.copy(m_i, m_i_prev) - T.reduce_max(acc_s, m_i, dim=1, clear=False) - for h_i in T.Parallel(H_per_block): - alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) - for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) - T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? - for h_i in T.Parallel(H_per_block): - sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] - for h_i, d_i in T.Parallel(H_per_block, D): - acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] - - T.copy(acc_s, S_shared) - T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) - - # Rescale - for h_i, d_i in T.Parallel(H_per_block, D): - acc_o[h_i, d_i] /= sumexp[h_i] - for h_i in T.Parallel(H_per_block): - sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale - - T.copy(acc_o, O_shared) - T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) - T.copy(sumexp, Lse_shared) - T.copy(sumexp, Lse[b_i, s_i, H0:H1]) - - return main - - -def sparse_attention_fwd_interface(q, - kv, - indices, - sm_scale=None, - return_p_sum: bool = False, - d_v=512): - is_casual = True - assert return_p_sum == False, "This kernel file is for fwd only" - assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() - batch, seq_len, heads, dim_plus_tail_dim = q.shape - _, seq_len_kv, kv_group, _ = kv.shape - - assert dim_plus_tail_dim == 576, "you should assign dim otherwise" - dim = d_v - - assert kv.shape[-1] == dim_plus_tail_dim - tail_dim = dim_plus_tail_dim - dim - assert kv.shape[0] == batch - _, _, _, topk = indices.shape - assert indices.shape == (batch, seq_len, kv_group, topk) - - kernel = sparse_attention_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual) - out, lse = kernel(q, kv, indices) - return out, lse - - -def ref_sparse_attention_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True): - q = q.float() - kv = kv.float() - indices = indices.transpose(1, 2) - b, sq, h, dim_q = q.shape - b, sk, g, _ = kv.shape - - assert kv.shape[-1] == 576, "you should assign dim otherwise" - dim = 512 - k = kv - v = kv[..., :dim] - - b, _, _, dim_v = v.shape - num_kv_per_index = 1 - g_index = g - h_index = h // g - compressed_casual_mask = torch.arange( - 0, sq, dtype=torch.int32, device="cuda").view(-1, 1) >= torch.arange( - 1 - 1, sk * 1, 1, dtype=torch.int32, device="cuda").view(1, -1) - - mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) - mask = mask[..., :-1] - mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :1 - 1, 0] = True - mask = mask.view(b, g_index, 1, sq, sk) - - q = q.view(b, sq, g, -1, dim_q) - score = torch.einsum("bmghd,bngd->bghmn", q, k) - sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale - score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) - p = score.softmax(dim=-1) - p = p.view(b, g_index, h_index, -1, sq, sk) - p = p.view(b, g, -1, sq, sk) - o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) - o = o.reshape(b, sq, h, dim_v) - return o.to(torch.float16) - - -def test_sparse_attn_mla_fwd(): - B, S, SKV, H, HKV, DQK, DV, topk, dtype = ( - 1, - 4096, - 32768, - 128, - 1, - 576, - 512, - 2048, - torch.bfloat16, - ) - - torch.random.manual_seed(0) - q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) - kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) - - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") - for b in range(B): - for t in range(S): - for h in range(HKV): - i_i = torch.randperm(max(1, t))[:topk] - indices[b, t, h, :len(i_i)] = i_i - - tl_out, tl_lse = sparse_attention_fwd_interface(q, kv, indices) - - def fn(): - return sparse_attention_fwd_interface(q, kv, indices) - - from tilelang.profiler import do_bench - - ms = do_bench( - fn, - rep=100, - warmup=250, - ) - print(f"Average time: {ms:.3f} ms") - print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) - print(f"fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) - - -if __name__ == "__main__": - test_sparse_attn_mla_fwd() diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py deleted file mode 100644 index cacad4afd..000000000 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ /dev/null @@ -1,457 +0,0 @@ -# ruff: noqa -import torch -import tilelang -from tilelang import language as T -from tilelang import tvm - -from tilelang.engine.callback import register_cuda_postproc_callback -import argparse - - -@tilelang.jit( - out_idx=[-2, -1], - compile_flags=[ - "-O3", "-Wno-deprecated-declarations", "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", "--expt-relaxed-constexpr", "--expt-extended-lambda", - "--ptxas-options=-v,--register-usage-level=10", "-DNDEBUG" - ], -) -def sparse_attention_fwd( - batch, - seq_len, - seq_len_kv, - heads, - dim, - tail_dim, - topk, - kv_stride, - kv_group=1, - sm_scale=None, - is_causal=True, - CP0=True, - block_I=64, - num_stages=0, - threads=384, -): - assert dim == tilelang.math.next_power_of_2( - dim), f"haven't check padding correctness yet, dim={dim}" - assert tail_dim == tilelang.math.next_power_of_2( - tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" - assert is_causal == True, 'non-casual is not supported' - assert topk % block_I == 0, 'otherwise will load some index=0 thus causing wrong kv to be loaded' - if sm_scale is None: - sm_scale = (1.0 / (dim + tail_dim))**0.5 * 1.44269504 # log2(e) - else: - sm_scale = sm_scale * 1.44269504 # log2(e) - - head_kv = heads // kv_group - q_shape = [batch, seq_len, heads, dim + tail_dim] - kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] - o_shape = [batch, seq_len, heads, dim] - indices_shape = [batch, seq_len, kv_group, topk] - lse_shape = [batch, seq_len, heads] - indices_dtype = "int32" - dtype = "bfloat16" - accum_dtype = "float" - - G = kv_group - H = head_kv - padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) - if padded_H != H: - assert kv_group == 1, 'here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)' - BI = block_I - NI = tilelang.cdiv(topk, block_I) - assert NI % 2 == 0, 'NI should be a multiple of 2' - D = dim - D_tail = tail_dim - KV_stride = kv_stride - if head_kv > 64: - assert head_kv % 64 == 0, 'head_kv should be a multiple of 64' - REPLICATE_H = head_kv // 64 - else: - REPLICATE_H = 1 - - H_per_block = padded_H if REPLICATE_H == 1 else 64 - - @T.prim_func - def main( - Q: T.Tensor(q_shape, dtype), # type: ignore - KV: T.Tensor(kv_shape, dtype), # type: ignore - Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore - q_start_index_s: T.Tensor(1, indices_dtype), - Output: T.Tensor(o_shape, dtype), # type: ignore - Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore - ): - with T.Kernel( - (seq_len - kv_stride + 1 if CP0 else seq_len) * REPLICATE_H, - batch, - kv_group, - threads=threads) as (bx, by, bz): - Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) - Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) - Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) - KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype) - KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype) - KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype) - KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype) - K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype) - K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype) - O_shared_l = Q_shared_l - O_shared_r = Q_shared_r - is_kv_valid = T.alloc_shared([BI], "bool", scope="shared") - - acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype) - acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype) - acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) - S_shared = T.alloc_shared([H_per_block, BI], dtype) - sumexp = T.alloc_fragment([H_per_block], accum_dtype) - sum_exp_shared = T.alloc_shared([H_per_block], accum_dtype) - sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) - alpha_shared = T.alloc_shared([H_per_block], accum_dtype, scope="shared") - alpha_local = T.alloc_fragment([H_per_block], accum_dtype) - m_i = T.alloc_fragment([H_per_block], accum_dtype) - m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) - indices_local = T.alloc_local([1], indices_dtype) - - # TODO: Multi buffer - bar_q = T.alloc_barrier(arrive_count=384) - bar_k_0_ready = T.alloc_barrier(arrive_count=128) - bar_k_1_ready = T.alloc_barrier(arrive_count=128) - bar_k_0_free = T.alloc_barrier(arrive_count=256) - bar_k_1_free = T.alloc_barrier(arrive_count=256) - bar_sScale_and_sS_ready = T.alloc_barrier(arrive_count=256) - bar_sScale_and_sS_free = T.alloc_barrier(arrive_count=256) - - b_i, g_i = by, bz - s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else ( - bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) - q_i = q_start_index_s[0] + s_i - max_kv_i = (q_i + 1 - KV_stride) // KV_stride - - H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) - H1 = H0 + H_per_block - - tx = T.get_thread_binding() - - T.copy(Q[b_i, s_i, H0:H1, 0:D // 2], Q_shared_l) - T.copy(Q[b_i, s_i, H0:H1, D // 2:D], Q_shared_r) - T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) - T.barrier_arrive(bar_q) - - if tx < 128: - T.set_max_nreg(240, 1) - T.fill(sumexp, 0) - T.fill(m_i, -2**30) # avoid -inf - inf to cause nan - T.fill(acc_o_l, 0) - T.barrier_wait(bar_q, 0) - - for i_i in T.serial(T.ceildiv(NI, 2)): - - # Buffer 0 - T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) - - for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - -T.infinity(acc_s.dtype)) - T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) - - T.wait_wgmma(0) - - if i_i != 0: - T.barrier_arrive(bar_sScale_and_sS_free) - T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2) & 1) ^ 1) - - T.copy(m_i, m_i_prev) - T.reduce_max(acc_s, m_i, dim=1, clear=False) - for h_i in T.Parallel(H_per_block): - alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) - for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) - T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? - for h_i in T.Parallel(H_per_block): - sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] - for h_i, d_i in T.Parallel(H_per_block, D // 2): - acc_o_l[h_i, d_i] *= alpha_local[h_i] - T.copy(alpha_local, alpha_shared) - - T.copy(acc_s, S_shared) - T.gemm(S_shared, KV_shared_0_l, acc_o_l) - - T.barrier_arrive(bar_sScale_and_sS_ready) - T.barrier_arrive(bar_k_0_free[0]) - - # Buffer 1 - T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) - - for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, - -T.infinity(acc_s.dtype)) - T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) - - T.wait_wgmma(0) - - T.barrier_arrive(bar_sScale_and_sS_free) - T.barrier_wait(bar_sScale_and_sS_free, ((i_i * 2 + 1) & 1) ^ 1) - - T.copy(m_i, m_i_prev) - T.reduce_max(acc_s, m_i, dim=1, clear=False) - for h_i in T.Parallel(H_per_block): - alpha_local[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) - for h_i, bi_i in T.Parallel(H_per_block, BI): - acc_s[h_i, bi_i] = T.exp2(acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale) - T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? - for h_i in T.Parallel(H_per_block): - sumexp[h_i] = sumexp[h_i] * alpha_local[h_i] + sumexp_i[h_i] - for h_i, d_i in T.Parallel(H_per_block, D // 2): - acc_o_l[h_i, d_i] *= alpha_local[h_i] - T.copy(alpha_local, alpha_shared) - - T.copy(acc_s, S_shared) - T.gemm(S_shared, KV_shared_1_l, acc_o_l) - - T.barrier_arrive(bar_sScale_and_sS_ready) - T.barrier_arrive(bar_k_1_free[0]) - - # Rescale - for h_i in T.Parallel(H_per_block): - sum_exp_shared[h_i] = sumexp[h_i] - for h_i, d_i in T.Parallel(H_per_block, D // 2): - acc_o_l[h_i, d_i] /= sumexp[h_i] - for h_i in T.Parallel(H_per_block): - sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale - T.copy(acc_o_l, O_shared_l) - T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0:D // 2]) - - elif tx >= 128 and tx < 256: - T.set_max_nreg(168, 1) - T.fill(acc_o_r, 0) - for i_i in T.serial(T.ceildiv(NI, 2)): - # Buffer 0 - T.barrier_arrive(bar_sScale_and_sS_ready) - T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2) & 1)) - for h_i, d_i in T.Parallel(H_per_block, D // 2): - acc_o_r[h_i, d_i] *= alpha_shared[h_i] - T.gemm(S_shared, KV_shared_0_r, acc_o_r) - T.barrier_arrive(bar_k_0_free[0]) - T.barrier_arrive(bar_sScale_and_sS_free) - - # Buffer 1 - T.barrier_arrive(bar_sScale_and_sS_ready) - T.barrier_wait(bar_sScale_and_sS_ready, ((i_i * 2 + 1) & 1)) - for h_i, d_i in T.Parallel(H_per_block, D // 2): - acc_o_r[h_i, d_i] *= alpha_shared[h_i] - T.gemm(S_shared, KV_shared_1_r, acc_o_r) - T.barrier_arrive(bar_k_1_free[0]) - if i_i != T.ceildiv(NI, 2) - 1: - T.barrier_arrive(bar_sScale_and_sS_free) - - # Rescale - for h_i, d_i in T.Parallel(H_per_block, D // 2): - acc_o_r[h_i, d_i] /= sum_exp_shared[h_i] - - T.copy(acc_o_r, O_shared_r) - T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2:D]) - elif tx >= 256: - # producer - T.set_max_nreg(80, 0) - for i_i in T.serial(T.ceildiv(NI, 2)): - # Buffer 0 - T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) - for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2) * BI + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i - if is_kv_valid[r * 16 + (tx - 256) // 8]: - with T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_0_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] - T.cp_async_barrier_noinc(bar_k_0_ready[0]) - - # Buffer 1 - T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) - for r in T.serial(4): - indices_local[0] = Indices[b_i, s_i, g_i, - (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] - is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local[0] <= max_kv_i - if is_kv_valid[r * 16 + (tx - 256) // 8]: - with T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - 64 * u + (tx - 256) % 8 * 8 + v] - KV_shared_1_r[r * 16 + (tx - 256) // 8, - 64 * u + (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, D // 2 + - 64 * u + (tx - 256) % 8 * 8 + v] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + - v] = KV[b_i, indices_local[0], g_i, - D + (tx - 256) % 8 * 8 + v] - T.cp_async_barrier_noinc(bar_k_1_ready[0]) - - return main - - -def sparse_attention_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride, - sm_scale=None, - is_casual=True, - return_kernel=False, - print_kernel=False): - assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() - batch, seq_len, heads, dim_plus_tail_dim = q.shape - _, seq_len_kv, kv_group, _ = kv.shape - - assert dim_plus_tail_dim == 576, 'you should assign dim otherwise' - dim = 512 - - assert kv.shape[-1] == dim_plus_tail_dim - tail_dim = dim_plus_tail_dim - dim - assert kv.shape[0] == batch - _, _, _, topk = indices.shape - assert indices.shape == (batch, seq_len, kv_group, topk) - - if q_start_index_s != 0: - assert q_start_index_s > kv_stride, "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" - CP0 = q_start_index_s == 0 - - kernel = sparse_attention_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, - kv_group, sm_scale, is_casual, CP0) - if print_kernel: - print(kernel.get_kernel_source()) - out, lse = kernel(q, kv, indices, - torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) - if return_kernel: - return kernel - if q_start_index_s == 0 and kv_stride > 1: - out[:, :kv_stride - 1, :, :] = 0 - return out, lse - - -def ref_sparse_attention_fwd_interface(q, - kv, - indices, - q_start_index_s, - kv_stride=4, - sm_scale=None, - is_casual=True): - q = q.float() - kv = kv.float() - indices = indices.transpose(1, 2) - b, sq, h, dim_q = q.shape - b, sk, g, _ = kv.shape - if q_start_index_s is None: - q_start_index_s = sk * kv_stride - sq - - assert kv.shape[-1] == 576, 'you should assign dim otherwise' - dim = 512 - k = kv - v = kv[..., :dim] - - b, _, _, dim_v = v.shape - num_kv_per_index = 1 - g_index = g - h_index = h // g - compressed_casual_mask = torch.arange( - q_start_index_s, sq + q_start_index_s, dtype=torch.int32, - device="cuda").view(-1, 1) >= torch.arange( - kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) - - mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) - mask = mask[..., :-1] - mask = mask & compressed_casual_mask.view(1, 1, sq, sk) - mask[:, :, :kv_stride - 1, 0] = True - mask = mask.view(b, g_index, 1, sq, sk) - - q = q.view(b, sq, g, -1, dim_q) - score = torch.einsum("bmghd,bngd->bghmn", q, k) - sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale - score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) - p = score.softmax(dim=-1) - p = p.view(b, g_index, h_index, -1, sq, sk) - p = p.view(b, g, -1, sq, sk) - o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) - o = o.reshape(b, sq, h, dim_v) - return o.to(torch.float16) - - -def test_sparse_attn_mla_fwd(test_correctness=False): - KV_stride = 1 - if test_correctness: - B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 2048, 128, 1, 576, 512, 2048, torch.bfloat16 - q_start_s_index = 1024 - else: - B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 - q_start_s_index = 4096 * 64 - - torch.random.manual_seed(0) - q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 - kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 - q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") - - q.clamp_(-10, 10) - kv.clamp_(-10, 10) - - indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') - for b in range(B): - for t in range(S): - for h in range(HKV): - i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] - indices[b, t, h, :len(i_i)] = i_i - - kernel = sparse_attention_fwd_interface( - q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) - - def fn(): - out, lse = kernel(q, kv, indices, q_start_s_index_t) - if q_start_s_index == 0 and kv_stride > 1: - out[:, :kv_stride - 1, :, :] = 0 - return out, lse - - tl_out, tl_lse = fn() - if test_correctness: - ref_out = ref_sparse_attention_fwd_interface(q, kv, indices, q_start_s_index, KV_stride) - print(f"tl_out: {tl_out}") - print(f"ref_out: {ref_out}") - assert_similar(tl_out, ref_out) - - from tilelang.profiler import do_bench - ms = do_bench( - fn, - rep=10, - warmup=10, - ) - print(f"Average time: {ms:.3f} ms") - print(f'fwd io bandwidth = ', (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) - print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--test_correctness", action="store_true") - args = parser.parse_args() - test_sparse_attn_mla_fwd(args.test_correctness) diff --git a/examples/deepseek_v32/utils.py b/examples/deepseek_v32/utils.py deleted file mode 100644 index 201e13405..000000000 --- a/examples/deepseek_v32/utils.py +++ /dev/null @@ -1,219 +0,0 @@ -import torch -import torch.nn.functional as F -import triton -import triton.language as tl - -import contextlib -import functools -import logging -import os -import sys -from enum import Enum -from functools import lru_cache -from typing import Any, Callable, Dict, Literal, Optional, Tuple - -from packaging import version - - -def _is_equal(a, b): - if isinstance(a, torch.Tensor): - return a is b - # Whitelist of types that are safe to compare by value for caching. - if isinstance(a, (int, float, str, bool, type(None))) and isinstance( - b, (int, float, str, bool, type(None))): - return a == b - # For other types, we cannot guarantee a cheap and safe comparison, so we fail the cache check. - return False - - -def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: - """ - A decorator that caches the most recent result of a function with tensor inputs. - - This decorator will store the output of the decorated function for the most recent set of input tensors. - If the function is called again with the same input tensors, it will return the cached result. - - - Args: - fn (Callable[..., torch.Tensor]): - The function to be decorated. It should take tensor inputs and return tensor outputs. - - Returns: - Callable[..., torch.Tensor]: - A wrapped version of the input function with single-entry caching. - """ - last_args: Optional[Tuple] = None - last_kwargs: Optional[Dict] = None - last_result: Any = None - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> Any: - nonlocal last_args, last_kwargs, last_result - - if last_args is not None and last_kwargs is not None: - if len(args) == len(last_args) and len(kwargs) == len(last_kwargs): - # For Tensors, check for object identity. For other types, check for equality. - # Python caches small integers, so `is` works for them but not for large integers like 4096. - if all(_is_equal(a, b) for a, b in zip(args, last_args)) and \ - set(kwargs.keys()) == set(last_kwargs.keys()) and \ - all(_is_equal(v, last_kwargs[k]) for k, v in kwargs.items()): - return last_result - - result = fn(*args, **kwargs) - last_args, last_kwargs, last_result = args, kwargs, result - return result - - return wrapper - - -@tensor_cache -def cal_seq_idx_from_cu_seqlens(cu_seqlens: torch.LongTensor, seq_len: int): - seq_idx = cu_seqlens.new_zeros(seq_len + 1) - seq_idx.scatter_add_(0, cu_seqlens[1:].long(), torch.ones_like(seq_idx)) - seq_idx.cumsum_(0) - return seq_idx[:-1] - - -@tensor_cache -def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - seq_len: int) -> torch.IntTensor: - seq_idx_for_q = torch.full((seq_len,), - len(cu_seqlens_qs), - dtype=torch.int32, - device=cu_seqlens_qs.device) - for i in range(len(cu_seqlens_qs)): - seq_idx_for_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = i - return seq_idx_for_q - - -@tensor_cache -def cal_cu_seqlen_ks_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, seq_len: int) -> torch.IntTensor: - cu_seqlen_ks_for_each_q = torch.gather( - input=torch.cat([ - cu_seqlens_ks, - torch.full((1,), - torch.iinfo(torch.int32).max, - dtype=torch.int32, - device=cu_seqlens_qs.device) - ]), - dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) - return cu_seqlen_ks_for_each_q.int() - - -@tensor_cache -def cal_cu_seqlen_ke_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, - cu_seqlens_ks: torch.LongTensor, cu_seqlens_ke: torch.LongTensor, - q_start_idxs: torch.LongTensor, seq_len: int, - kv_stride: int) -> torch.IntTensor: - cu_seqlen_ke_for_each_q = torch.gather( - input=torch.cat( - [cu_seqlens_ke, - torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]), - dim=0, - index=cal_seq_idx_for_q( - cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long()) - casual_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), - dtype=torch.int32, - device=cu_seqlens_qs.device) - for i in range(len(cu_seqlens_qs)): - casual_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i]:cu_seqlens_qe[i]] = (torch.arange( - q_start_idxs[i], - q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], - dtype=torch.int32, - device=cu_seqlens_qs.device) + 1) // kv_stride + cu_seqlens_ks[i] - cu_seqlen_ke_for_each_q = torch.minimum(casual_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q) - return cu_seqlen_ke_for_each_q.int() - - -@tensor_cache -def cal_ks_ke_from_cu_seqlen_qk(cu_seqlens_q: torch.LongTensor, - cu_seqlens_k: torch.LongTensor = None, - offs_q: torch.LongTensor = None, - *, - seq_len: int, - kv_stride: int = 1, - cp_rank: int = 0, - cp_size: int = 1, - balanced_cp=False): - ''' - seq_len: seq len per cp rank - balanced cp slice assignment: 0 1 2 3 3 2 1 0 - ''' - n_seq = len(cu_seqlens_q) - 1 - assert n_seq > 0 - assert cu_seqlens_q.shape == (n_seq + 1,) - seq_idx = cal_seq_idx_from_cu_seqlens(cu_seqlens_q.long(), seq_len * cp_size) - qs = cu_seqlens_q.gather(0, seq_idx) - pos = torch.arange(len(qs), dtype=qs.dtype, device=qs.device) - qs - if offs_q is not None: - assert offs_q.shape == (n_seq,), offs_q.shape - qoff = offs_q.gather(0, seq_idx) - pos += qoff - if cu_seqlens_k is None or cu_seqlens_k is cu_seqlens_q: - ks = qs - else: - assert cu_seqlens_k.shape == (n_seq + 1,) - ks = cu_seqlens_k.gather(0, seq_idx) - ke = ks + (pos + 1) // kv_stride - - if cp_size == 1: - pass - elif balanced_cp: - assert cp_size % 2 == 0, cp_size - - def f(x: torch.Tensor): - chunks = x.chunk(cp_size * 2) - return torch.cat([ - chunks[cp_rank], - chunks[cp_size - cp_rank - 1], - ]) - - ks = f(ks) - ke = f(ke) - else: - ks = ks.chunk(cp_size)[cp_rank] - ke = ke.chunk(cp_size)[cp_rank] - - return ks, ke - - -def print_red_warning(message): - print(f"\033[31mWARNING: {message}\033[0m") - - -def calc_sim(x, y, name="tensor"): - x, y = x.data.double(), y.data.double() - denominator = (x * x + y * y).sum() - if denominator == 0: - print_red_warning(f'{name} all zero') - return 1 - sim = 2 * (x * y).sum() / denominator - return sim - - -def assert_similar(x, y, eps=1e-8, name="tensor"): - sim = calc_sim(x, y, name) - diff = 1. - sim - if not (0 <= diff <= eps): - print_red_warning(f'{name} Error: {diff}') - assert False - - -if __name__ == "__main__": - seq_len = 32768 - cu_seqlens = torch.randint(128, 4096, (1000,), dtype=torch.int32, device="cuda") - last_idx = torch.where(cu_seqlens.cumsum(dim=0) >= seq_len)[0][0] - cu_seqlens_cumsum = cu_seqlens[:last_idx].cumsum(dim=0) - cu_seqlens_qs = torch.cat( - [torch.zeros(1, dtype=torch.int32, device=cu_seqlens.device), cu_seqlens_cumsum]) - cu_seqlens_qe = torch.cat( - [cu_seqlens_cumsum, - torch.ones(1, dtype=torch.int32, device=cu_seqlens.device) * seq_len]) - - from tilelang.profiler import do_bench - - fn = lambda: cal_seq_idx_for_q(cu_seqlens_qs, cu_seqlens_qe, seq_len) - ms = do_bench(fn, warmup=25, rep=100)