diff --git a/examples/deepseek_mla/benchmark_mla.py b/examples/deepseek_mla/benchmark_mla.py new file mode 100644 index 000000000..4b743dc02 --- /dev/null +++ b/examples/deepseek_mla/benchmark_mla.py @@ -0,0 +1,553 @@ +# This benchmark script is modified based on: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py + +import argparse +import math +import random + +import flashinfer +import torch +import triton +import triton.language as tl + +# pip install flashinfer-python +from flash_mla import flash_mla_with_kvcache, get_mla_metadata + +import tilelang +from tilelang.profiler import do_bench +from example_mla_decode_paged import mla_decode_tilelang + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + +@torch.inference_mode() +def run_flash_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv) + + def flash_mla(): + return flash_mla_with_kvcache( + q, blocked_k, block_table, cache_seqlens, dv, + tile_scheduler_metadata, num_splits, causal=causal, + ) + + out_flash, lse_flash = flash_mla() + t = triton.testing.do_bench(flash_mla) + return out_flash, lse_flash, t + + +@torch.inference_mode() +def run_flash_infer(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + + kv_indptr = [0] + kv_indices = [] + for i in range(b): + seq_len = cache_seqlens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_table[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + for seq_len in cache_seqlens[1:]: + 
kv_indptr.append((seq_len + block_size - 1) // block_size + kv_indptr[-1]) + + q_indptr = torch.arange(0, b + 1).int() * s_q + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + + mla_wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper( + torch.empty(128 * 1024 * 1024, dtype=torch.int8), + backend="fa3" + ) + mla_wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + cache_seqlens, + h_q, + dv, + d-dv, + block_size, + causal, + 1 / math.sqrt(d), + q.dtype, + blocked_k.dtype, + ) + + def flash_infer(): + output, lse = mla_wrapper.run(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope, blocked_k_pe, return_lse=True) + return output.view(b, -1, h_q, dv), lse.view(b, h_q, 1) + + out_flash, lse_flash = flash_infer() + t = triton.testing.do_bench(flash_infer) + return out_flash, lse_flash, t + + +@triton.jit +def _mla_attn_kernel( + Q_nope, + Q_pe, + Kv_c_cache, + K_pe_cache, + Req_to_tokens, + B_seq_len, + O, + sm_scale, + stride_q_nope_bs, + stride_q_nope_h, + stride_q_pe_bs, + stride_q_pe_h, + stride_kv_c_bs, + stride_k_pe_bs, + stride_req_to_tokens_bs, + stride_o_b, + stride_o_h, + stride_o_s, + BLOCK_H: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, + HEAD_DIM_KPE: tl.constexpr, +): + cur_batch = tl.program_id(1) + cur_head_id = tl.program_id(0) + split_kv_id = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + cur_head = cur_head_id * BLOCK_H + tl.arange(0, BLOCK_H) + offs_q_nope = cur_batch * stride_q_nope_bs + cur_head[:, None] * stride_q_nope_h + offs_d_ckv[None, :] + q_nope = tl.load(Q_nope + offs_q_nope) + + offs_d_kpe = tl.arange(0, HEAD_DIM_KPE) + offs_q_pe = cur_batch * stride_q_pe_bs + cur_head[:, None] * stride_q_pe_h + offs_d_kpe[None, :] + q_pe = tl.load(Q_pe + offs_q_pe) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, HEAD_DIM_CKV], dtype=tl.float32) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_page_number = tl.load( + Req_to_tokens + stride_req_to_tokens_bs * cur_batch + offs_n // PAGE_SIZE, + mask=offs_n < split_kv_end, + other=0, + ) + kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE + offs_k_c = kv_loc[None, :] * stride_kv_c_bs + offs_d_ckv[:, None] + k_c = tl.load(Kv_c_cache + offs_k_c, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk = tl.dot(q_nope, k_c.to(q_nope.dtype)) + + offs_k_pe = kv_loc[None, :] * stride_k_pe_bs + offs_d_kpe[:, None] + k_pe = tl.load(K_pe_cache + offs_k_pe, mask=offs_n[None, :] < split_kv_end, other=0.0) + + qk += tl.dot(q_pe, k_pe.to(q_pe.dtype)) + qk *= sm_scale + + qk = tl.where(offs_n[None, :] < split_kv_end, qk, float("-inf")) + + v_c = tl.trans(k_c) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v_c.dtype), v_c) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + offs_o = cur_batch * stride_o_b + cur_head[:, None] * stride_o_h + split_kv_id * stride_o_s + offs_d_ckv[None, :] + tl.store(O + offs_o, acc / e_sum[:, None]) 
+ offs_o_1 = cur_batch * stride_o_b + cur_head * stride_o_h + split_kv_id * stride_o_s + HEAD_DIM_CKV + tl.store(O + offs_o_1, e_max + tl.log(e_sum)) + + +def _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, +): + batch_size, head_num = q_nope.shape[0], q_nope.shape[1] + head_dim_ckv = q_nope.shape[-1] + head_dim_kpe = q_pe.shape[-1] + + BLOCK_H = 16 + BLOCK_N = 64 + grid = ( + triton.cdiv(head_num, BLOCK_H), + batch_size, + num_kv_splits, + ) + _mla_attn_kernel[grid]( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + req_to_tokens, + b_seq_len, + attn_logits, + sm_scale, + # stride + q_nope.stride(0), + q_nope.stride(1), + q_pe.stride(0), + q_pe.stride(1), + kv_c_cache.stride(-2), + k_pe_cache.stride(-2), + req_to_tokens.stride(0), + attn_logits.stride(0), + attn_logits.stride(1), + attn_logits.stride(2), + BLOCK_H=BLOCK_H, + BLOCK_N=BLOCK_N, + NUM_KV_SPLITS=num_kv_splits, + PAGE_SIZE=page_size, + HEAD_DIM_CKV=head_dim_ckv, + HEAD_DIM_KPE=head_dim_kpe, + ) + +@triton.jit +def _mla_softmax_reducev_kernel( + Logits, + B_seq_len, + O, + stride_l_b, + stride_l_h, + stride_l_s, + stride_o_b, + stride_o_h, + NUM_KV_SPLITS: tl.constexpr, + HEAD_DIM_CKV: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + cur_batch_seq_len = tl.load(B_seq_len + cur_batch) + + offs_d_ckv = tl.arange(0, HEAD_DIM_CKV) + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([HEAD_DIM_CKV], dtype=tl.float32) + + offs_l = cur_batch * stride_l_b + cur_head * stride_l_h + offs_d_ckv + offs_l_1 = cur_batch * stride_l_b + cur_head * stride_l_h + HEAD_DIM_CKV + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + logits = tl.load(Logits + offs_l + split_kv_id * stride_l_s) + logits_1 = tl.load(Logits + offs_l_1 + split_kv_id * stride_l_s) + + n_e_max = tl.maximum(logits_1, e_max) + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(logits_1 - n_e_max) + acc += exp_logic * logits + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_o_b + cur_head * stride_o_h + offs_d_ckv, + acc / e_sum, + ) + + +def _mla_softmax_reducev( + logits, + o, + b_seq_len, + num_kv_splits, +): + batch_size, head_num, head_dim_ckv = o.shape[0], o.shape[1], o.shape[2] + grid = (batch_size, head_num) + _mla_softmax_reducev_kernel[grid]( + logits, + b_seq_len, + o, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=num_kv_splits, + HEAD_DIM_CKV=head_dim_ckv, + num_warps=4, + num_stages=2, + ) + +def mla_decode_triton( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + o, + req_to_tokens, + b_seq_len, + attn_logits, + num_kv_splits, + sm_scale, + page_size, +): + assert num_kv_splits == attn_logits.shape[2] + _mla_attn( + q_nope, + q_pe, + kv_c_cache, + k_pe_cache, + attn_logits, + req_to_tokens, + b_seq_len, + num_kv_splits, + sm_scale, + page_size, + ) + _mla_softmax_reducev( + attn_logits, + o, + b_seq_len, + num_kv_splits, + ) + + +@torch.inference_mode() +def run_flash_mla_triton(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + blocked_v = blocked_k[..., :dv] + + assert d > dv, "mla with rope dim should be larger than 
no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + def flash_mla_triton(): + num_kv_splits = 32 + o = torch.empty([b * s_q, h_q, dv]) + attn_logits = torch.empty([b * s_q, h_q, num_kv_splits, dv + 1]) + mla_decode_triton(q_nope.view(-1, h_q, dv), q_pe.view(-1, h_q, d-dv), blocked_k_nope.view(-1, dv), blocked_k_pe.view(-1, d-dv), o, block_table, cache_seqlens, attn_logits, num_kv_splits, 1 / math.sqrt(d), block_size) + return o.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_triton() + t = triton.testing.do_bench(flash_mla_triton) + return out_flash, None, t + + +@torch.inference_mode() +def run_flash_mla_tilelang(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + BLOCK_N = 64 + BLOCK_H = 64 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + out = torch.empty(b, h_q, dv, dtype=dtype, device=q.device) + program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [8], tilelang.TensorSupplyType.Randn) + + def flash_mla_tilelang(): + out = mod.func( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, dpe), + blocked_k_nope.view(-1, h_kv, dv), + blocked_k_pe.view(-1, h_kv, dpe), + block_table, + cache_seqlens, + glse, + out_partial, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_tilelang() + t = do_bench(flash_mla_tilelang) + return out_flash, None, t + +FUNC_TABLE = { + "torch": run_torch_mla, + "tilelang": run_flash_mla_tilelang, + "flash_mla": run_flash_mla, + "flash_infer": run_flash_infer, + "flash_mla_triton": run_flash_mla_triton, +} + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert baseline in FUNC_TABLE + assert target in FUNC_TABLE + baseline_func = FUNC_TABLE[baseline] + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_a, lse_a, perf_a = baseline_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, 
causal, dtype) + + torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2), "out" + if target not in ["flash_infer", "flash_mla_triton", "flash_mla_tilelang"]: + # flash_infer has a different lse return value + # flash_mla_triton and flash_mla_tilelang doesn't return lse + torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2), "lse" + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10 ** 9 / perf_a:.0f} TFLOPS, {bytes / 10 ** 6 / perf_a:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") + return bytes / 10 ** 6 / perf_a, bytes / 10 ** 6 / perf_b + + +def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}") + torch.set_default_dtype(dtype) + device = torch.device("cuda:0") + torch.set_default_device(device) + torch.cuda.set_device(device) + torch.manual_seed(0) + random.seed(0) + assert target in FUNC_TABLE + target_func = FUNC_TABLE[target] + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256 + # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}") + + q = torch.randn(b, s_q, h_q, d) + block_size = 64 + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d) + + out_b, lse_b, perf_b = target_func(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 + bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10 ** 9 / perf_b:.0f} TFLOPS, {bytes / 10 ** 6 / perf_b:.0f} GB/s") + return bytes / 10 ** 6 / perf_b + + +available_targets = [ + "torch", + "tilelang", + "flash_mla", + "flash_infer", + "flash_mla_triton", +] + +shape_configs = [ + {"b": batch, "s_q": 1, "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"), "h_q": head, "h_kv": 1, "d": 512+64, "dv": 512, "causal": True, "dtype": torch.float16} + for batch in [128] for seqlen in [1024, 2048, 4096, 8192, 16384, 32768] for head in [128] +] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--baseline", type=str, default="torch") + parser.add_argument("--target", type=str, default="tilelang") + parser.add_argument("--all", action="store_true") + parser.add_argument("--one", action="store_true") + parser.add_argument("--compare", action="store_true") + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = get_args() + benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target + with open(f"{benchmark_type}_perf.csv", "w") as fout: + fout.write("name,batch,seqlen,head,bw\n") + for shape in shape_configs: + if args.all: + for target in available_targets: + perf = compare_a(target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], 
shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + fout.write(f'{target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') + elif args.compare: + perfa, prefb = compare_ab(args.baseline, args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + fout.write(f'{args.baseline},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perfa:.0f}\n') + fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{prefb:.0f}\n') + elif args.one: + perf = compare_a(args.target, shape["b"], shape["s_q"], shape["cache_seqlens"], shape["h_q"], shape["h_kv"], shape["d"], shape["dv"], shape["causal"], shape["dtype"]) + fout.write(f'{args.target},{shape["b"]},{shape["cache_seqlens"].float().mean().cpu().item():.0f},{shape["h_q"]},{perf:.0f}\n') \ No newline at end of file diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index d3168480f..a5c49757f 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -182,6 +182,7 @@ def combine( T.clear(lse_logsum_local) T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) for k in T.serial(num_split): lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) for k in T.Pipelined(num_split, num_stages=1): diff --git a/examples/deepseek_mla/example_mla_decode_paged.py b/examples/deepseek_mla/example_mla_decode_paged.py new file mode 100644 index 000000000..24c2b68ba --- /dev/null +++ b/examples/deepseek_mla/example_mla_decode_paged.py @@ -0,0 +1,372 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +from einops import rearrange, einsum +import argparse +from tilelang.profiler import do_bench +import math + +def mla_decode_tilelang(batch, h_q, h_kv, max_seqlen_pad, dv, dpe, block_N, block_H, num_split, block_size): + scale = (1.0 / (dv + dpe))**0.5 * 1.44269504 # log2(e) + dtype = "float16" + accum_dtype = "float" + kv_group_num = h_q // h_kv + VALID_BLOCK_H = min(block_H, kv_group_num) + assert h_kv == 1, "h_kv must be 1" + assert block_size >= block_N and block_size % block_N == 0, "block_size must be larger than block_N and a multiple of block_N" + + @T.macro + def flash_mla_kernel( + Q: T.Buffer([batch, h_q, dv], dtype), + Q_pe: T.Buffer([batch, h_q, dpe], dtype), + KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Buffer([batch, max_seqlen_pad // block_size], "int32"), + CACHE_SEQLENS: T.Buffer([batch], "int32"), + Output: T.Buffer([batch, h_q, dv], dtype), + ): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), threads=256) as (bx, by): + Q_shared = T.alloc_shared([block_H, dv], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) + KV_shared = T.alloc_shared([block_N, dv], dtype) + K_pe_shared = T.alloc_shared([block_N, dpe], dtype) + O_shared = T.alloc_shared([block_H, dv], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dv], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + 
scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout({ + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + }) + + T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = T.ceildiv(CACHE_SEQLENS[bx], block_N) + for kr in T.Pipelined(loop_range, num_stages=2): + k = loop_range - 1 - kr + kv_start = BLOCK_TABLE[bx, (k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm( + Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm( + Q_pe_shared, + K_pe_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + with T.If(kr == 0), T.Then(): + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :]) + + @T.macro + def flash_mla_split_kv_kernel( + Q: T.Buffer([batch, h_q, dv], dtype), + Q_pe: T.Buffer([batch, h_q, dpe], dtype), + KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype), + BLOCK_TABLE: T.Buffer([batch, max_seqlen_pad // block_size], "int32"), + CACHE_SEQLENS: T.Buffer([batch], "int32"), + glse: T.Buffer([batch, h_q, num_split], dtype), + Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype), + ): + with T.Kernel(batch, h_q // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dv], dtype) + S_shared = T.alloc_shared([block_H, block_N], dtype) + Q_pe_shared = T.alloc_shared([block_H, dpe], dtype) + KV_shared = T.alloc_shared([block_N, dv], dtype) + K_pe_shared = T.alloc_shared([block_N, dpe], dtype) + O_shared = T.alloc_shared([block_H, dv], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dv], accum_dtype) + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], 
accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + + cur_kv_head = by // (kv_group_num // block_H) + T.use_swizzle(10) + T.annotate_layout({ + O_shared: tilelang.layout.make_swizzled_layout(O_shared), + S_shared: tilelang.layout.make_swizzled_layout(S_shared), + }) + + T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared) + T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + total_blocks = T.ceildiv(CACHE_SEQLENS[bx], block_N) + blocks_per_split = T.floordiv(total_blocks, num_split) + remaining_blocks = T.floormod(total_blocks, num_split) + loop_range = (blocks_per_split + T.if_then_else(bz < remaining_blocks, 1, 0)) + start = (blocks_per_split * bz + T.min(bz, remaining_blocks)) * block_N + + for k in T.Pipelined(loop_range, num_stages=2): + kv_start = BLOCK_TABLE[bx, (start + k * block_N) // block_size] * block_size + (k * block_N) % block_size + T.copy(KV[kv_start:kv_start + block_N, cur_kv_head, :], KV_shared) + T.copy(K_pe[kv_start:kv_start + block_N, cur_kv_head, :], K_pe_shared) + T.clear(acc_s) + T.gemm( + Q_shared, KV_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullCol) + T.gemm( + Q_pe_shared, + K_pe_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else(start + k * block_N + j >= CACHE_SEQLENS[bx], -T.infinity(accum_dtype), acc_s[i, j]) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_H): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + T.copy(acc_s, S_shared) + T.copy(S_shared, acc_s_cast) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] *= scores_scale[i] + T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol) + for i, j in T.Parallel(block_H, dv): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, glse[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz]) + T.copy(acc_o, O_shared) + T.copy(O_shared, Output_partial[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, bz, :]) + + @T.macro + def combine( + glse: T.Buffer([batch, h_q, num_split], dtype), + Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype), + Output: T.Buffer([batch, h_q, dv], dtype), + ): + with T.Kernel(h_q, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dv], dtype) + o_accum_local = T.alloc_fragment([dv], accum_dtype) + lse_local_split = T.alloc_local([1], accum_dtype) + lse_logsum_local = T.alloc_local([1], accum_dtype) + lse_max_local = T.alloc_local([1], accum_dtype) + scale_local = T.alloc_local([1], accum_dtype) + + T.annotate_layout({ + lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i), + }) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local[0] = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_max_local[0] = T.max(lse_max_local[0], glse[bz, by, k]) + for k in T.Pipelined(num_split, num_stages=1): + lse_local_split[0] = glse[bz, by, k] + lse_logsum_local[0] += 
T.exp2(lse_local_split[0] - lse_max_local[0]) + lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0] + for k in T.serial(num_split): + for i in T.Parallel(dv): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split[0] = glse[bz, by, k] + scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0]) + for i in T.Parallel(dv): + o_accum_local[i] += po_local[i] * scale_local[0] + for i in T.Parallel(dv): + Output[bz, by, i] = o_accum_local[i] + + @T.prim_func + def main_split( + Q: T.Buffer([batch, h_q, dv], dtype), + Q_pe: T.Buffer([batch, h_q, dpe], dtype), + KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Buffer([batch, max_seqlen_pad // block_size], "int32"), + cache_seqlens: T.Buffer([batch], "int32"), + glse: T.Buffer([batch, h_q, num_split], dtype), + Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype), + Output: T.Buffer([batch, h_q, dv], dtype), + ): + flash_mla_split_kv_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, glse, Output_partial) + combine(glse, Output_partial, Output) + + @T.prim_func + def main_no_split( + Q: T.Buffer([batch, h_q, dv], dtype), + Q_pe: T.Buffer([batch, h_q, dpe], dtype), + KV: T.Buffer([batch * max_seqlen_pad, h_kv, dv], dtype), + K_pe: T.Buffer([batch * max_seqlen_pad, h_kv, dpe], dtype), + block_table: T.Buffer([batch, max_seqlen_pad // block_size], "int32"), + cache_seqlens: T.Buffer([batch], "int32"), + glse: T.Buffer([batch, h_q, num_split], dtype), + Output_partial: T.Buffer([batch, h_q, num_split, dv], dtype), + Output: T.Buffer([batch, h_q, dv], dtype), + ): + flash_mla_kernel(Q, Q_pe, KV, K_pe, block_table, cache_seqlens, Output) + + if num_split > 1: + return main_split + else: + return main_no_split + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + # q: [b, s_q, h_q, d] + # block_table: [b, max_seqlen_pad // block_size] + # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] + # cache_seqlens: [b] + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out.to(dtype), lse.to(dtype) + + 
out_torch, _ = ref_mla() + return out_torch + + +def run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + BLOCK_N = 64 + BLOCK_H = 64 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + out = torch.empty(b, h_q, dv, dtype=dtype, device=q.device) + program = mla_decode_tilelang(b, h_q, h_kv, max_seqlen_pad, dv, dpe, BLOCK_N, BLOCK_H, num_kv_splits, block_size) + mod, params = tilelang.lower(program) + mod = tilelang.Profiler(mod, params, [8], tilelang.TensorSupplyType.Randn) + + def flash_mla_tilelang(): + out = mod.func( + q_nope.view(-1, h_q, dv), + q_pe.view(-1, h_q, dpe), + blocked_k_nope.view(-1, h_kv, dv), + blocked_k_pe.view(-1, h_kv, dpe), + block_table, + cache_seqlens, + glse, + out_partial, + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_tilelang() + t = do_bench(flash_mla_tilelang) + out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + torch.testing.assert_close(out_flash, out_ref, rtol=0.01, atol=0.01) + print("All close") + return out_flash, t + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=128, help='batch size') + parser.add_argument('--h_q', type=int, default=128, help='q heads number') + parser.add_argument('--h_kv', type=int, default=1, help='kv heads number') + parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length') + parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe') + parser.add_argument('--dv', type=int, default=512, help='value head dim') + args = parser.parse_args() + b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv + + device = "cuda" + dtype = torch.float16 + + s_q = 1 # for decode, s_q = 1 + block_size = 64 + cache_seqlens = torch.tensor([cache_seqlen + 2 * i for i in range(b)], dtype=torch.int32, device=device) + dpe = d - dv + causal = True + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 + + total_flops = s_q * total_seqlens * h_q * (d + dv) * 2 + + q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) + block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32, device=device).view(b, max_seqlen_pad // block_size) + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) + out_flash, latency = run_tilelang_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) \ No newline at end of file diff --git a/testing/python/autotune/test_tilelang_autotune.py b/testing/python/autotune/test_tilelang_autotune.py index f73c1b5c4..a6d87d0e5 100644 --- a/testing/python/autotune/test_tilelang_autotune.py +++ 
b/testing/python/autotune/test_tilelang_autotune.py @@ -72,7 +72,7 @@ def get_configs(M, N, K, with_roller=False): if roller_hints is None: raise ValueError("No Roller Hints Found for TensorCore Scheduling") - + configs = [] for hint in roller_hints: config = {} diff --git a/tilelang/autotuner/__init__.py b/tilelang/autotuner/__init__.py index 94f9cf02d..b73f26f38 100644 --- a/tilelang/autotuner/__init__.py +++ b/tilelang/autotuner/__init__.py @@ -1,8 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""The auto-tune module for tl programs.""" +"""The auto-tune module for tilelang programs.""" -import tilelang as tl +import tilelang from tilelang import tvm as tvm import inspect from functools import wraps @@ -21,9 +21,9 @@ @dataclass(frozen=True) class JITContext: - mod: tl.Profiler + mod: tilelang.Profiler out_idx: List[int] - supply_type: tl.TensorSupplyType + supply_type: tilelang.TensorSupplyType ref_prog: Callable rtol: float atol: float @@ -144,7 +144,7 @@ def autotune(configs: Any, rep: int = 100, timeout: int = 100) -> Callable: """ - Decorator for tl program + Decorator for tilelang program """ def decorator(fn: Callable) -> Autotuner: @@ -154,7 +154,7 @@ def decorator(fn: Callable) -> Autotuner: def jit(out_idx: List[int], - supply_type: tl.TensorSupplyType = tl.TensorSupplyType.Normal, + supply_type: tilelang.TensorSupplyType = tilelang.TensorSupplyType.Normal, ref_prog: Callable = None, rtol: float = 1e-2, atol: float = 1e-2, @@ -169,9 +169,9 @@ def wrapper(fn: Callable): def decorator(*args, **kwargs) -> float: # Enabling Efficient Fusion with tvm.transform.PassContext(config={"tir.merge_static_smem": True}): - mod, params = tl.lower(fn(*args, **kwargs), target=target) + mod, params = tilelang.lower(fn(*args, **kwargs), target=target) - mod = tl.Profiler(mod, params, out_idx, supply_type) + mod = tilelang.Profiler(mod, params, out_idx, supply_type) return JITContext( mod=mod, diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py index 8720e5ff9..9a8461aa7 100644 --- a/tilelang/engine/phase.py +++ b/tilelang/engine/phase.py @@ -3,7 +3,7 @@ from tvm import tir, IRModule from tvm.target import Target -import tilelang as tl +import tilelang def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: @@ -11,17 +11,17 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: mod = tir.transform.BindTarget(target)(mod) # Legalize the frontend IR to make it compatible with TVM - mod = tl.transform.FrontendLegalize()(mod) + mod = tilelang.transform.FrontendLegalize()(mod) # Simplify the IR expressions mod = tir.transform.Simplify()(mod) # Infer memory layouts for fragments and shared memory - mod = tl.transform.LayoutInference()(mod) + mod = tilelang.transform.LayoutInference()(mod) # Lower high-level tile operations to low-level operations - mod = tl.transform.LowerTileOp()(mod) + mod = tilelang.transform.LowerTileOp()(mod) # Legalize vectorized loops to ensure they are valid - mod = tl.transform.LegalizeVectorizedLoop()(mod) + mod = tilelang.transform.LegalizeVectorizedLoop()(mod) # Add safety checks for memory accesses - mod = tl.transform.LegalizeSafeMemoryAccess()(mod) + mod = tilelang.transform.LegalizeSafeMemoryAccess()(mod) # Simplify again to clean up any duplicated conditions # that may have been introduced by safety checks mod = tir.transform.Simplify()(mod) @@ -32,23 +32,23 @@ def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # which 
may be introduced by the LegalizeSafeMemoryAccess if target.arch == "sm_90": - mod = tl.transform.MultiVersionBuffer()(mod) - mod = tl.transform.WarpSpecialized()(mod) - mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tilelang.transform.MultiVersionBuffer()(mod) + mod = tilelang.transform.WarpSpecialized()(mod) + mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tir.transform.LowerOpaqueBlock()(mod) - mod = tl.transform.RewriteWgmmaSync()(mod) - # mod = tl.transform.WarpSpecializedPipeline()(mod) - mod = tl.transform.InjectFenceProxy()(mod) + mod = tilelang.transform.RewriteWgmmaSync()(mod) + # mod = tilelang.transform.WarpSpecializedPipeline()(mod) + mod = tilelang.transform.InjectFenceProxy()(mod) else: mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tl.transform.PipelinePlanning()(mod) - mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tilelang.transform.PipelinePlanning()(mod) + mod = tilelang.transform.InjectSoftwarePipeline()(mod) mod = tir.transform.LowerOpaqueBlock()(mod) mod = tir.transform.FlattenBuffer()(mod) mod = tir.transform.NarrowDataType(32)(mod) mod = tir.transform.Simplify()(mod) - mod = tl.transform.VectorizeLoop()(mod) + mod = tilelang.transform.VectorizeLoop()(mod) mod = tir.transform.StorageRewrite()(mod) mod = tir.transform.UnrollLoop()(mod) mod = tir.transform.RenormalizeSplitPattern()(mod) @@ -68,19 +68,19 @@ def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: # We can find a way better to create var instead # of putting the LowerThreadAllreduce before # the Legalization. - mod = tl.transform.ThreadPartialSync("shared.dyn")(mod) + mod = tilelang.transform.ThreadPartialSync("shared.dyn")(mod) mod = tir.transform.InferFragment()(mod) mod = tir.transform.LowerThreadAllreduce()(mod) - mod = tl.transform.LowerHopperIntrin()(mod) - mod = tl.transform.ThreadSync("shared")(mod) - mod = tl.transform.ThreadSync("shared.dyn")(mod) + mod = tilelang.transform.LowerHopperIntrin()(mod) + mod = tilelang.transform.ThreadSync("shared")(mod) + mod = tilelang.transform.ThreadSync("shared.dyn")(mod) mod = tir.transform.InjectPTXAsyncCopy()(mod) - mod = tl.transform.AnnotateDeviceRegions()(mod) + mod = tilelang.transform.AnnotateDeviceRegions()(mod) mod = tir.transform.SplitHostDevice()(mod) mod = tir.transform.MergeSharedMemoryAllocations()(mod) - mod = tl.transform.MakePackedAPI()(mod) + mod = tilelang.transform.MakePackedAPI()(mod) mod = tir.transform.LowerDeviceKernelLaunch()(mod) return mod
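
Note on the split-KV reduction: the new `combine` macro in example_mla_decode_paged.py and the one-line fix in example_mla_decode.py (initializing `lse_max_local[0]` to `-T.infinity(accum_dtype)` before the max loop) both perform the standard log-sum-exp merge of per-split partial outputs. Below is a minimal host-side PyTorch sketch of that same reduction, assuming `glse` holds each split's base-2 log-sum-exp exactly as the split kernels above store it; the function name and eager-mode layout are illustrative only and are not part of the patch.

    import torch

    def combine_splits_reference(out_partial: torch.Tensor,  # [batch, h_q, num_split, dv]
                                 glse: torch.Tensor          # [batch, h_q, num_split]
                                 ) -> torch.Tensor:          # [batch, h_q, dv]
        """Merge per-split partial attention outputs using their base-2 LSE values."""
        glse = glse.float()
        # The running maximum must start at -inf; this is what the one-line fix in
        # example_mla_decode.py restores inside the fused `combine` kernel.
        lse_max = glse.max(dim=-1, keepdim=True).values                     # [batch, h_q, 1]
        lse_global = torch.log2(torch.exp2(glse - lse_max).sum(dim=-1, keepdim=True)) + lse_max
        # Per-split weights, then the weighted sum of the partial outputs.
        weights = torch.exp2(glse - lse_global)                             # [batch, h_q, num_split]
        return (out_partial.float() * weights.unsqueeze(-1)).sum(dim=2)

Without the `-inf` initialization, the running maximum starts from stale register contents, which can skew the per-split weights `exp2(glse_k - lse_global)`; the sketch makes the dependence on a correct global maximum explicit.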
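
Paged-cache addressing: both the Triton `_mla_attn_kernel` (`kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE`) and the TileLang `flash_mla_kernel` translate a logical token position into a physical row of the flattened KV cache through the block table. A small host-side sketch of that mapping, with a hypothetical helper name chosen here for illustration:

    def paged_kv_row(block_table, batch_idx: int, token_pos: int, block_size: int) -> int:
        """Row index into the flattened [num_pages * block_size, h_kv, ...] KV cache."""
        page = int(block_table[batch_idx, token_pos // block_size])
        return page * block_size + token_pos % block_size

The TileLang kernel additionally asserts `block_size % block_N == 0`, so every `block_N` tile lies within a single page and its base row can be computed once per pipeline iteration.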