diff --git a/maint/gemm_v2/correctness_evaluation.py b/maint/gemm_v2/correctness_evaluation.py index 9029fcd67..b7b56a00e 100644 --- a/maint/gemm_v2/correctness_evaluation.py +++ b/maint/gemm_v2/correctness_evaluation.py @@ -1,4 +1,4 @@ -# pytest gemm_ss_wgmma.py -n 32 +# pytest correctness_evaluation.py -n 32 import pytest from tilelang import tvm as tvm import tilelang.testing @@ -384,7 +384,7 @@ def run_gemm_rr( M_VALUES = [64, 128, 256] -N_VALUES = [16, 32, 64, 128] +N_VALUES = [16, 32, 64, 128, 256, 512] K_VALUES = [16, 32, 64, 128] K_VALUES_8Bit = [32, 64, 128] FALSE_TRUE_CASES = ([ diff --git a/maint/gemm_v2/latency_gemm.py b/maint/gemm_v2/latency_gemm.py new file mode 100644 index 000000000..13392dec7 --- /dev/null +++ b/maint/gemm_v2/latency_gemm.py @@ -0,0 +1,99 @@ +import tilelang +import tilelang.language as T +import argparse + +parser = argparse.ArgumentParser() +parser.add_argument("--use_v2", action="store_true") +args = parser.parse_args() + +use_v2 = args.use_v2 + + +# @tilelang.jit(target="cuda") +# target currently can be "cuda" or "hip" or "cpu". 
+# if not specified, it will be inferred from the input tensors during compile time +@tilelang.jit +def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"): + + @T.prim_func + def matmul_relu_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + # Initialize Kernel Context + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Enable rasterization for better L2 cache locality (Optional) + # T.use_swizzle(panel_size=10, enable=True) + + # Clear local accumulation + T.clear(C_local) + + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy tile of A + # This is a sugar syntax for parallelized copy + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy tile of B + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # Perform a tile-level GEMM on the shared buffers + # Currently we dispatch to the cute/hip on Nvidia/AMD GPUs + if use_v2: + T.gemm_v2(A_shared, B_shared, C_local) + else: + T.gemm_v1(A_shared, B_shared, C_local) + + # relu + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Copy result back to global memory + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_relu_kernel + + +M = 16384 # M = T.dynamic("m") if you want to use dynamic shape +N = 16384 +K = 16384 +block_M = 128 +block_N = 128 +block_K = 64 + +# 1. Define the kernel (matmul) and compile/lower it into an executable module +matmul_relu_kernel = matmul(M, N, K, block_M, block_N, block_K) + +# 3. 
Test the kernel in Python with PyTorch data +import torch + +# Create random input tensors on the GPU +a = torch.randn(M, K, device="cuda", dtype=torch.float16) +b = torch.randn(K, N, device="cuda", dtype=torch.float16) +c = torch.empty(M, N, device="cuda", dtype=torch.float16) + +# Run the kernel directly on the input tensors +matmul_relu_kernel(a, b, c) + +print(c) +# Reference multiplication using PyTorch +ref_c = torch.relu(a @ b) + +# Validate correctness +torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) +print("Kernel output matches PyTorch reference.") + +# 4. Retrieve and inspect the generated CUDA source (optional) +# cuda_source = matmul_relu_kernel.get_kernel_source() +# print("Generated CUDA kernel:\n", cuda_source) + +# 5. Profile latency with the kernel's profiler +profiler = matmul_relu_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + +latency = profiler.do_bench() + +print(f"Latency: {latency} ms") diff --git a/maint/gemm_v2/latency_mha_fwd_bhsd.py b/maint/gemm_v2/latency_mha_fwd_bhsd.py new file mode 100644 index 000000000..cbe93bf69 --- /dev/null +++ b/maint/gemm_v2/latency_mha_fwd_bhsd.py @@ -0,0 +1,246 @@ +import torch +import torch.nn.functional as F +import tilelang +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + +parser = argparse.ArgumentParser() +parser.add_argument('--batch', type=int, default=128, help='batch size') +parser.add_argument('--heads', type=int, default=16, help='heads') +parser.add_argument('--seq_q', type=int, default=1024, help='query sequence length') +parser.add_argument('--seq_kv', type=int, default=1024, help='key/value sequence length') +parser.add_argument('--dim', type=int, default=512, help='dim') +parser.add_argument('--is_causal', action='store_true', help='causal') +parser.add_argument('--tune', action='store_true', help='tune configs') +parser.add_argument("--use_v2", action="store_true") + +args = parser.parse_args() + 
+use_v2 = args.use_v2 + + +def get_configs(): + iter_params = dict(block_M=[128], block_N=[128], num_stages=[2], threads=[256]) + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit( + out_idx=[3], pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }) +def flashattn(batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=64, + block_N=64, + num_stages=0, + threads=128): + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + q_shape = [batch, heads, seq_q, dim] + kv_shape = [batch, heads, seq_kv, dim] + dtype = "float16" + accum_dtype = "float" + + past_len = seq_kv - seq_q + assert past_len >= 0, "seq_kv must be greater than or equal to seq_q" + + @T.macro + def MMA0( + K: T.Tensor(kv_shape, dtype), + Q_shared: T.SharedBuffer([block_M, dim], dtype), + K_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, by, k * block_N:(k + 1) * block_N, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + past_len + k_idx = k * block_N + j + acc_s[i, j] = T.if_then_else(q_idx >= k_idx, 0, -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + if use_v2: + T.gemm_v2(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + else: + T.gemm_v1(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Tensor(kv_shape, dtype), + V_shared: T.SharedBuffer([block_N, dim], dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, by, k * block_N:(k + 1) * block_N, :], V_shared) + # T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if use_v2: + 
T.gemm_v2(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + else: + T.gemm_v1(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.FragmentBuffer([block_M, block_N], accum_dtype), + acc_s_cast: T.FragmentBuffer([block_M, block_N], dtype), + scores_max: T.FragmentBuffer([block_M], accum_dtype), + scores_max_prev: T.FragmentBuffer([block_M], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + scores_sum: T.FragmentBuffer([block_M], accum_dtype), + logsum: T.FragmentBuffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf. + # This process is called Check_inf in FlashAttention3 code, and it only needs to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)). This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.FragmentBuffer([block_M, dim], accum_dtype), + scores_scale: T.FragmentBuffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_q, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, by, bx * block_M:(bx + 1) * block_M, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min( + T.ceildiv(seq_kv, block_N), T.ceildiv( + (bx + 1) * block_M + + past_len, block_N)) if is_causal else T.ceildiv(seq_kv, block_N)) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, scores_sum, + logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in 
T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, by, bx * block_M:(bx + 1) * block_M, :]) + + return main + + +def ref_program(Q, K, V, is_causal): + dim = Q.size(-1) + scores = torch.einsum('bhqd,bhkd->bhqk', Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_q = Q.size(2) + seq_kv = K.size(2) + mask = torch.tril(torch.ones(seq_q, seq_kv, device=scores.device), seq_kv - seq_q) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float('-inf')) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum('bhqk,bhkd->bhqd', attention_weights, V) + return output + + +def main( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 64, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + if (not tune): + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=64, + block_N=64, + num_stages=0, + threads=128) + print(kernel.get_kernel_source()) + ref_program_processed = partial(ref_program, is_causal=is_causal) + + profiler = kernel.get_profiler() + profiler.assert_allclose(ref_program_processed, rtol=0.01, atol=0.01) + print("All checks pass.") + latency = profiler.do_bench(ref_program_processed, warmup=500) + print(f"Ref: {latency:.2f} ms") + print(f"Ref: {total_flops / latency * 1e-9:.2f} TFlops") + latency = profiler.do_bench(warmup=500) + print(f"Tile-lang: {latency:.2f} ms") + print(f"Tile-lang: {total_flops / latency * 1e-9:.2f} TFlops") + else: + kernel = flashattn(batch, heads, seq_q, seq_kv, dim, is_causal) + best_latency = kernel.latency + best_config = kernel.config + ref_latency = kernel.ref_latency + print(f"Best latency: {best_latency}") + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") + 
print(f"Best config: {best_config}") + print(f"Ref latency: {ref_latency}") + + +if __name__ == "__main__": + tilelang.disable_cache() + main(args.batch, args.heads, args.seq_q, args.seq_kv, args.dim, args.is_causal, args.tune) diff --git a/src/transform/lower_tile_op.cc b/src/transform/lower_tile_op.cc index 96ae34e3f..9759c9bbc 100644 --- a/src/transform/lower_tile_op.cc +++ b/src/transform/lower_tile_op.cc @@ -10,6 +10,7 @@ #include #include #include +#include #include "../layout/layout.h" #include "../layout/utils.h" @@ -301,6 +302,9 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { layout_map_.Set(buffer, layout); } } + // Begin a new workspace collection frame for this block scope + workspace_stack_.emplace_back(); + auto block = Downcast(arith::IRMutatorWithAnalyzer::VisitStmt_(op)); auto block_ptr = block.CopyOnWrite(); for (size_t i = 0; i < block->alloc_buffers.size(); i++) { @@ -309,9 +313,13 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { block_ptr->alloc_buffers.Set(i, buffer_remap_[buffer]); } } - for (const auto &buffer : workspaces_) - block_ptr->alloc_buffers.push_back(buffer); - workspaces_.clear(); + // Attach any workspaces requested within this block to its alloc_buffers + if (!workspace_stack_.empty()) { + for (const auto &buffer : workspace_stack_.back()) { + block_ptr->alloc_buffers.push_back(buffer); + } + workspace_stack_.pop_back(); + } return block; } @@ -659,7 +667,15 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { AddWorkspaceCallback callback = [this](int num_elem, DataType dtype) { auto workspace = decl_buffer({PrimExpr(num_elem)}, dtype, "workspace", "shared.dyn"); - workspaces_.push_back(workspace); + // Record workspace under the innermost block scope so its lifetime + // covers the statements that requested it and does not sink into + // subsequently created inner blocks (e.g., GEMM macro blocks). 
+ if (!workspace_stack_.empty()) { + workspace_stack_.back().push_back(workspace); + } else { + // Fallback: create a temporary frame (should be rare) + workspace_stack_.emplace_back(Array{workspace}); + } return workspace.access_ptr(2); // write }; @@ -707,7 +723,8 @@ class LowerTileOpPass : arith::IRMutatorWithAnalyzer { IterVar thread_var_ = IterVar(Range::FromMinExtent(0, 1), Var("v_thread"), IterVarType::kDataPar); size_t thread_block_size_ = 0; - Array workspaces_; + // Stack of per-Block workspace buffers gathered while visiting children + std::vector> workspace_stack_; // For ptx Node, we need to remap the buffer and indices // By access CallNode instead of BufferLoad Node. bool is_ptx_{false}; diff --git a/tilelang/intrinsics/wgmma_macro_generator.py b/tilelang/intrinsics/wgmma_macro_generator.py index b6d45cc1e..69ef750b5 100644 --- a/tilelang/intrinsics/wgmma_macro_generator.py +++ b/tilelang/intrinsics/wgmma_macro_generator.py @@ -6,6 +6,7 @@ from tvm import DataType from tvm.tir import PrimExpr, Buffer, Var, IndexMap from tilelang.utils import is_fragment +from math import gcd from tilelang.layout import ( Layout, make_full_bank_swizzled_layout, @@ -70,6 +71,11 @@ class TensorCoreIntrinEmitter(MMAIntrinEmitter): # should be rewritten to support dynamic k_dim wgmma_prefix: str + # wgmma instruction M dimension + wgmma_inst_m: int + # wgmma instruction N dimension + wgmma_inst_n: int + a_shared_layout: Layout = None b_shared_layout: Layout = None @@ -104,9 +110,18 @@ def _assign_b_shared_layout(self, layout: Layout): return self def _initialize_wgmma_prefix(self, n_dim: int = 16): - inst_m, inst_n = 64, self.block_col_warps * self.warp_col_tiles + inst_m, inst_n = 64, gcd(self.warp_col_tiles, 256) + assert inst_n % 8 == 0, ( + f"inst_n must be a multiple of 8, got {inst_n} " + f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})") + # Validate inst_n: Hopper WGMMA supports n in [8, 256] and multiple of 8 + assert 8 <= inst_n 
<= 256, ( + f"inst_n must be within [8, 256], got {inst_n} " + f"(block_col_warps={self.block_col_warps}, warp_col_tiles={self.warp_col_tiles})") # 256 bits per instruction inst_k = 256 // DataType(self.a_dtype).bits + self.wgmma_inst_m = inst_m + self.wgmma_inst_n = inst_n self.wgmma_prefix = f"m{inst_m}n{inst_n}k{inst_k}" def _initialize_micro_size(self, m_dim: int = 16, k_dim: int = 16): @@ -149,10 +164,11 @@ def wgmma(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: Buffer, - clear_accum: PrimExpr = False): + clear_accum: PrimExpr = False, + wg_wait: int = 0): if is_fragment(A_buf): - return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum) + return self.wgmma_rs(A_buf, B_buf, C_local_buf, clear_accum, wg_wait) local_size_out = self.local_size_out a_dtype_abbrv = self.a_dtype_abbrv @@ -241,9 +257,16 @@ def wgmma(self, # where max specially handles the case when n_dim is 8. ak_atom_size = max(a_swizzle_atom_elems // micro_size_k, 1) bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) + wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n + num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m + num_inst_n = self.warp_col_tiles // wgmma_inst_n + + thread_binding = self.get_thread_binding() @T.macro def _warp_mma(A_buf, B_buf, C_local_buf): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + desc_a = T.alloc_wgmma_desc() desc_b = T.alloc_wgmma_desc() T.initialize_wgmma_descriptor(desc_a, A_buf.access_ptr("r"), a_swizzle_mode, @@ -254,23 +277,29 @@ def _warp_mma(A_buf, B_buf, C_local_buf): int(b_stride_byte_offset >> 4)) T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) T.warpgroup_arrive() - for ki in T.serial(0, (k_dim // micro_size_k)): - scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1)) - for i in T.serial(m_dim // 64): - A_offset = (ki % ak_atom_size) * micro_size_k + i * 64 * a_swizzle_atom_elems + ( - ki // ak_atom_size - ) * m_dim * a_swizzle_atom_elems if a_is_k_major else i * 64 * 
k_dim + ki * a_swizzle_atom_elems * micro_size_k - B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( - ki % bk_atom_size - ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k - C_offset = i * warp_cols * local_size_out # 4 warps as an unit - T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major, - a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data, - (A_offset * elems_in_bytes) >> 4, desc_b.data, - (B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset, - scale_out, scale_in_a, scale_in_b) + for j in T.serial(num_inst_n): + for i in T.serial(num_inst_m): + for ki in T.serial(k_dim // micro_size_k): + warp_i = (warp_m // 4) * num_inst_m + i + warp_j = warp_n * num_inst_n + j + scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1)) + A_offset = ( + ki % ak_atom_size + ) * micro_size_k + warp_i * 64 * a_swizzle_atom_elems + ( + ki // ak_atom_size + ) * m_dim * a_swizzle_atom_elems if a_is_k_major else warp_i * 64 * k_dim + ki * a_swizzle_atom_elems * micro_size_k + B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( + ki % bk_atom_size + ) * micro_size_k + warp_j * wgmma_inst_n * b_swizzle_atom_elems if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k + warp_j * k_dim * wgmma_inst_n + C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit + T.ptx_wgmma_ss(accum_dtype, wgmma_prefix, a_is_k_major, b_is_k_major, + a_dtype_abbrv, b_dtype_abbrv, accum_dtype_abbrv, desc_a.data, + (A_offset * elems_in_bytes) >> 4, desc_b.data, + (B_offset * elems_in_bytes) >> 4, C_local_buf.data, C_offset, + scale_out, scale_in_a, scale_in_b) T.warpgroup_commit_batch() - T.warpgroup_wait(0) + if wg_wait >= 0: + T.warpgroup_wait(wg_wait) T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) return _warp_mma(A_buf, B_buf, C_local_buf) @@ -279,7 +308,8 @@ def wgmma_rs(self, A_buf: Buffer, B_buf: Buffer, C_local_buf: 
Buffer, - clear_accum: PrimExpr = False): + clear_accum: PrimExpr = False, + wg_wait: int = 0): local_size_a = self.local_size_a local_size_out = self.local_size_out a_dtype_abbrv = self.a_dtype_abbrv @@ -333,9 +363,16 @@ def wgmma_rs(self, b_stride_byte_offset = 8 * elems_in_bytes * b_swizzle_atom_elems bk_atom_size = max(b_swizzle_atom_elems // micro_size_k, 1) + wgmma_inst_m, wgmma_inst_n = self.wgmma_inst_m, self.wgmma_inst_n + num_inst_m = 4 * self.warp_row_tiles // wgmma_inst_m + num_inst_n = self.warp_col_tiles // wgmma_inst_n + + thread_binding = self.get_thread_binding() @T.macro def _warp_mma(A_buf, B_buf, C_local_buf): + tx, warp_n, warp_m = self.extract_thread_binding(thread_binding) + desc_b = T.alloc_wgmma_desc() T.initialize_wgmma_descriptor(desc_b, B_buf.access_ptr("r"), b_swizzle_mode, int(b_leading_byte_offset >> 4), @@ -343,33 +380,39 @@ def _warp_mma(A_buf, B_buf, C_local_buf): T.warpgroup_fence_operand(A_buf, num_regs=a_regs) T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) T.warpgroup_arrive() - for ki in T.serial(0, (k_dim // micro_size_k)): - scale_out = T.if_then_else(ki != 0, 1, T.if_then_else(clear_accum, 0, 1)) - for i in T.serial(m_dim // 64): - A_offset = ki * warp_rows * local_size_a + i * local_size_a - B_offset = (ki // bk_atom_size) * n_dim * b_swizzle_atom_elems + ( - ki % bk_atom_size - ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k - C_offset = i * warp_cols * local_size_out # 4 warps as an unit - T.ptx_wgmma_rs( - accum_dtype, - wgmma_prefix, - self.b_transposed, - a_dtype_abbrv, - b_dtype_abbrv, - accum_dtype_abbrv, - A_buf.data, - A_offset, - desc_b.data, - (B_offset * elems_in_bytes) >> 4, - C_local_buf.data, - C_offset, - scale_out, - scale_in_a, - scale_in_b, - ) + + for j in T.serial(0, num_inst_n): + for i in T.serial(num_inst_m): + for ki in T.serial(0, (k_dim // micro_size_k)): + warp_j = warp_n * num_inst_n + j + scale_out = T.if_then_else(ki != 0, 1, 
T.if_then_else(clear_accum, 0, 1)) + A_offset = ki * warp_rows * local_size_a + i * local_size_a + B_offset = ( + ki // bk_atom_size + ) * n_dim * b_swizzle_atom_elems + warp_j * wgmma_inst_n * b_swizzle_atom_elems + ( + ki % bk_atom_size + ) * micro_size_k if b_is_k_major else ki * b_swizzle_atom_elems * micro_size_k + warp_j * k_dim * wgmma_inst_n + C_offset = i * warp_cols * local_size_out + j * warp_cols * local_size_out // num_inst_n # 4 warps as an unit + T.ptx_wgmma_rs( + accum_dtype, + wgmma_prefix, + self.b_transposed, + a_dtype_abbrv, + b_dtype_abbrv, + accum_dtype_abbrv, + A_buf.data, + A_offset, + desc_b.data, + (B_offset * elems_in_bytes) >> 4, + C_local_buf.data, + C_offset, + scale_out, + scale_in_a, + scale_in_b, + ) T.warpgroup_commit_batch() - T.warpgroup_wait(0) + if wg_wait >= 0: + T.warpgroup_wait(wg_wait) T.warpgroup_fence_operand(C_local_buf, num_regs=accum_regs) T.warpgroup_fence_operand(A_buf, num_regs=a_regs) diff --git a/tilelang/tileop/gemm/gemm_wgmma.py b/tilelang/tileop/gemm/gemm_wgmma.py index 39be65921..1e9607cdf 100644 --- a/tilelang/tileop/gemm/gemm_wgmma.py +++ b/tilelang/tileop/gemm/gemm_wgmma.py @@ -91,6 +91,7 @@ def lower(self, layout_map: dict, target: Target, thread_nums: int, thread_var: B_shared = self.B C_local = self.C clear_accum = self.clear_accum + wg_wait = self.wg_wait if self.is_gemm_ss(): @@ -102,7 +103,7 @@ def _gemm_ssr() -> None: accumulating into C_local. """ # Perform Matrix Multiplication - mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum) + mma_emitter.wgmma(A_shared, B_shared, C_local, clear_accum, wg_wait) # Simplify to optimize the index computing # Must inline let statements to simplify the analysis @@ -117,7 +118,7 @@ def _gemm_rsr() -> None: B_shared into local fragments, then issues Tensor Core mma ops, accumulating into C_local. 
""" - mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum) + mma_emitter.wgmma(A_local, B_shared, C_local, clear_accum, wg_wait) # Simplify to optimize the index computing # Must inline let statements to simplify the analysis