|
| 1 | +# Copyright (c) Tile-AI Corporation. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +import torch |
| 5 | +import torch.distributed as dist |
| 6 | +import pynvshmem |
| 7 | +import tilelang |
| 8 | +import tilelang.language as T |
| 9 | +from tilelang.distributed.utils import init_distributed, dtype_map, dsize_map |
| 10 | +import math |
| 11 | +import argparse |
| 12 | + |
| 13 | +tilelang.disable_cache() |
| 14 | + |
| 15 | + |
| 16 | +def summa(MESH, M, N, K, block_M, block_N, block_K, dtype="float16"): |
| 17 | + |
| 18 | + M_local = T.ceildiv(M, MESH) |
| 19 | + N_local = T.ceildiv(N, MESH) |
| 20 | + K_local = T.ceildiv(K, MESH) |
| 21 | + accum_dtype = "float32" |
| 22 | + |
| 23 | + sm_num = 132 # 132 SMs for H100 |
| 24 | + total_tiles = T.ceildiv(M_local, block_M) * T.ceildiv(N_local, block_N) |
| 25 | + |
| 26 | + @T.prim_func |
| 27 | + def main( |
| 28 | + A: T.Tensor((2, M_local, K_local), dtype), |
| 29 | + B: T.Tensor((2, N_local, K_local), dtype), |
| 30 | + A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"), |
| 31 | + A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"), |
| 32 | + B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"), |
| 33 | + B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"), |
| 34 | + C: T.Tensor((M_local, N_local), dtype), |
| 35 | + ): |
| 36 | + grid_size = T.min(sm_num, total_tiles) |
| 37 | + A_rows_per_block = T.ceildiv(M_local, grid_size) |
| 38 | + B_cols_per_block = T.ceildiv(N_local, grid_size) |
| 39 | + waves = T.ceildiv(total_tiles, sm_num) |
| 40 | + with T.Kernel(grid_size, threads=256) as (block_id): |
| 41 | + mype = T.alloc_local([1], "int32") |
| 42 | + mype[0] = T.get_pe() |
| 43 | + |
| 44 | + A_shared = T.alloc_shared((block_M, block_K), dtype) |
| 45 | + B_shared = T.alloc_shared((block_N, block_K), dtype) |
| 46 | + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) |
| 47 | + |
| 48 | + tx = T.get_thread_binding(0) |
| 49 | + |
| 50 | + pe_mn = mype[0] // MESH |
| 51 | + pe_k = mype[0] % MESH |
| 52 | + |
| 53 | + T.clear(C_local) |
| 54 | + for ko in T.serial(MESH): |
| 55 | + # broadcast A |
| 56 | + if pe_k == ko: |
| 57 | + if tx == 0: |
| 58 | + T.signal_wait_until( |
| 59 | + T.address_of(A_signal_from[0]), |
| 60 | + T.NVSHMEM_CMP_GE, |
| 61 | + total_tiles * MESH * ko, |
| 62 | + ) |
| 63 | + if block_id < T.ceildiv(M_local, A_rows_per_block): |
| 64 | + for peer_k in T.serial(MESH): |
| 65 | + T.putmem_signal_nbi_block( |
| 66 | + T.address_of(A[(ko + 1) % 2, A_rows_per_block * block_id, 0]), |
| 67 | + T.address_of(A[ko % 2, A_rows_per_block * block_id, |
| 68 | + 0]), A_rows_per_block * K_local * dsize_map[dtype], |
| 69 | + T.address_of(A_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, |
| 70 | + pe_mn * MESH + peer_k) |
| 71 | + |
| 72 | + # broadcast B |
| 73 | + if pe_k == ko: |
| 74 | + if tx == 0: |
| 75 | + T.signal_wait_until( |
| 76 | + T.address_of(B_signal_from[0]), |
| 77 | + T.NVSHMEM_CMP_GE, |
| 78 | + total_tiles * MESH * ko, |
| 79 | + ) |
| 80 | + if block_id < T.ceildiv(N_local, B_cols_per_block): |
| 81 | + for peer_k in T.serial(MESH): |
| 82 | + T.putmem_signal_nbi_block( |
| 83 | + T.address_of(B[(ko + 1) % 2, B_cols_per_block * block_id, 0]), |
| 84 | + T.address_of(B[ko % 2, B_cols_per_block * block_id, |
| 85 | + 0]), B_cols_per_block * K_local * dsize_map[dtype], |
| 86 | + T.address_of(B_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, |
| 87 | + pe_mn * MESH + peer_k) |
| 88 | + |
| 89 | + # TODO: check if __syncthreads() is needed |
| 90 | + T.signal_wait_until( |
| 91 | + T.address_of(A_signal_to[0]), |
| 92 | + T.NVSHMEM_CMP_GE, |
| 93 | + (ko + 1) * T.ceildiv(M_local, A_rows_per_block), |
| 94 | + ) |
| 95 | + T.signal_wait_until( |
| 96 | + T.address_of(B_signal_to[0]), |
| 97 | + T.NVSHMEM_CMP_GE, |
| 98 | + (ko + 1) * T.ceildiv(N_local, B_cols_per_block), |
| 99 | + ) |
| 100 | + |
| 101 | + for w in T.serial(waves): |
| 102 | + |
| 103 | + bx = (grid_size * w + block_id) // T.ceildiv(N_local, block_N) |
| 104 | + by = (grid_size * w + block_id) % T.ceildiv(N_local, block_N) |
| 105 | + |
| 106 | + if bx < T.ceildiv(M_local, block_M) and by < T.ceildiv(N_local, block_N): |
| 107 | + T.copy(C[bx * block_M, by * block_N], C_local) |
| 108 | + for ki in T.Pipelined(T.ceildiv(K_local, block_K), num_stages=4): |
| 109 | + T.copy(A[ko % 2, bx * block_M, ki * block_K], A_shared) |
| 110 | + T.copy(B[ko % 2, by * block_N, ki * block_K], B_shared) |
| 111 | + T.gemm(A_shared, B_shared, C_local, transpose_B=True) |
| 112 | + |
| 113 | + T.copy(C_local, C[bx * block_M, by * block_N]) |
| 114 | + if tx == 0: |
| 115 | + # Tell next A sender |
| 116 | + a_sender = pe_mn * MESH + (ko + 1) % MESH |
| 117 | + T.signal_op( |
| 118 | + T.address_of(A_signal_from[0]), |
| 119 | + 1, |
| 120 | + T.NVSHMEM_SIGNAL_ADD, |
| 121 | + a_sender, |
| 122 | + ) |
| 123 | + # Tell next B sender |
| 124 | + b_sender = pe_mn * MESH + (ko + 1) % MESH |
| 125 | + T.signal_op( |
| 126 | + T.address_of(B_signal_from[0]), |
| 127 | + 1, |
| 128 | + T.NVSHMEM_SIGNAL_ADD, |
| 129 | + b_sender, |
| 130 | + ) |
| 131 | + |
| 132 | + return main |
| 133 | + |
| 134 | + |
def parse_args(argv=None):
    """Parse command-line options for the SUMMA benchmark.

    Args:
        argv: optional list of argument strings. ``None`` (the default) keeps
            the original behavior of reading ``sys.argv[1:]``; passing an
            explicit list makes the function usable from tests and other code.

    Returns:
        argparse.Namespace with fields M, N, K, warmup, iters, dtype.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--M", default=16384, type=int)
    parser.add_argument("--N", default=16384, type=int)
    parser.add_argument("--K", default=16384, type=int)
    # NOTE(review): --warmup/--iters are parsed but currently ignored by
    # bench(), which has its own defaults — wire them through or drop them.
    parser.add_argument("--warmup", default=20, type=int, help="warmup iterations")
    parser.add_argument("--iters", default=100, type=int, help="perf iterations")
    parser.add_argument("--dtype", default="float16", type=str, help="data type")
    return parser.parse_args(argv)
| 144 | + |
| 145 | + |
| 146 | +if __name__ == "__main__": |
| 147 | + # init |
| 148 | + args = parse_args() |
| 149 | + |
| 150 | + WORLD_SIZE, RANK, LOCAL_RANK = init_distributed() |
| 151 | + |
| 152 | + MESH = math.ceil(math.sqrt(WORLD_SIZE)) |
| 153 | + assert MESH * MESH == WORLD_SIZE, "Mesh size must match world size" |
| 154 | + |
| 155 | + M, N, K = args.M, args.N, args.K |
| 156 | + block_M, block_N, block_K = 128, 256, 64 |
| 157 | + dtype = dtype_map[args.dtype] |
| 158 | + |
| 159 | + M_local = math.ceil(M / MESH) |
| 160 | + N_local = math.ceil(N / MESH) |
| 161 | + K_local = math.ceil(K / MESH) |
| 162 | + |
| 163 | + func = summa(MESH, M, N, K, block_M, block_N, block_K, args.dtype) |
| 164 | + kernel = tilelang.compile( |
| 165 | + func, pass_configs={ |
| 166 | + "tl.disable_tma_lower": True, |
| 167 | + "tl.disable_warp_specialized": True |
| 168 | + }) |
| 169 | + |
| 170 | + # Get CUDA Source |
| 171 | + if RANK == 0: |
| 172 | + print(kernel.get_kernel_source()) |
| 173 | + |
| 174 | + device = torch.device(f"cuda:{RANK}") |
| 175 | + ref = torch.empty((M_local, N_local), dtype=dtype, device=device) |
| 176 | + A_ref = torch.empty((M_local, K_local), dtype=dtype, device=device) |
| 177 | + B_ref = torch.empty((N_local, K_local), dtype=dtype, device=device) |
| 178 | + |
| 179 | + if RANK == 0: |
| 180 | + A = torch.randn(M, K, dtype=dtype, device=device) |
| 181 | + B = torch.randn(N, K, dtype=dtype, device=device) |
| 182 | + C = A @ B.T |
| 183 | + |
| 184 | + c_scatter_list = [] |
| 185 | + a_scatter_list = [] |
| 186 | + b_scatter_list = [] |
| 187 | + for r in range(WORLD_SIZE): |
| 188 | + rr, cc = divmod(r, MESH) |
| 189 | + c_tile = C[M_local * rr:M_local * (rr + 1), N_local * cc:N_local * (cc + 1)] |
| 190 | + a_tile = A[M_local * rr:M_local * (rr + 1), K_local * cc:K_local * (cc + 1)] |
| 191 | + b_tile = B[N_local * cc:N_local * (cc + 1), K_local * rr:K_local * (rr + 1)] |
| 192 | + |
| 193 | + c_scatter_list.append(c_tile.contiguous()) |
| 194 | + a_scatter_list.append(a_tile.contiguous()) |
| 195 | + b_scatter_list.append(b_tile.contiguous()) |
| 196 | + else: |
| 197 | + c_scatter_list = None |
| 198 | + a_scatter_list = None |
| 199 | + b_scatter_list = None |
| 200 | + |
| 201 | + dist.scatter(tensor=ref, scatter_list=c_scatter_list, src=0) |
| 202 | + dist.scatter(tensor=A_ref, scatter_list=a_scatter_list, src=0) |
| 203 | + dist.scatter(tensor=B_ref, scatter_list=b_scatter_list, src=0) |
| 204 | + dist.barrier() |
| 205 | + |
| 206 | + A = pynvshmem.nvshmem_create_tensor([2, M_local, K_local], dtype) |
| 207 | + B = pynvshmem.nvshmem_create_tensor([2, N_local, K_local], dtype) |
| 208 | + A[0, :, :].copy_(A_ref) |
| 209 | + B[0, :, :].copy_(B_ref) |
| 210 | + A_signal_to = pynvshmem.nvshmem_create_tensor([math.ceil(M / block_M)], torch.uint64) |
| 211 | + A_signal_from = pynvshmem.nvshmem_create_tensor([math.ceil(M / block_M)], torch.uint64) |
| 212 | + B_signal_to = pynvshmem.nvshmem_create_tensor([math.ceil(N / block_N)], torch.uint64) |
| 213 | + B_signal_from = pynvshmem.nvshmem_create_tensor([math.ceil(N / block_N)], torch.uint64) |
| 214 | + A_signal_to.fill_(0) |
| 215 | + A_signal_from.fill_(0) |
| 216 | + B_signal_to.fill_(0) |
| 217 | + B_signal_from.fill_(0) |
| 218 | + C_tilelang = pynvshmem.nvshmem_create_tensor([M_local, N_local], dtype) |
| 219 | + |
| 220 | + kernel(A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang) |
| 221 | + |
| 222 | + for r in range(WORLD_SIZE): |
| 223 | + dist.barrier() |
| 224 | + if r == RANK: |
| 225 | + if torch.allclose(C_tilelang, ref, rtol=1e-2, atol=1e-2): |
| 226 | + print('-' * 100) |
| 227 | + print(f"[Rank {RANK}] ✅ Tilelang and Torch match") |
| 228 | + else: |
| 229 | + abs_error = torch.abs(C_tilelang - ref) |
| 230 | + rel_error = abs_error / (torch.abs(ref) + 1e-8) |
| 231 | + |
| 232 | + max_abs_error = abs_error.max().item() |
| 233 | + max_rel_error = rel_error.max().item() |
| 234 | + mismatch_ratio = (abs_error > (1e-2 + 1e-2 * torch.abs(ref))).float().mean().item() |
| 235 | + |
| 236 | + print('-' * 100) |
| 237 | + print(f"[Rank {RANK}] ❌ Tilelang and Torch mismatch") |
| 238 | + print(f"[Rank {RANK}] ref:\n{ref}") |
| 239 | + print(f"[Rank {RANK}] tilelang:\n{C_tilelang}") |
| 240 | + print(f"[Rank {RANK}] Mismatch ratio: {mismatch_ratio:.4f}") |
| 241 | + print(f"[Rank {RANK}] Max absolute error: {max_abs_error:.6f}") |
| 242 | + print(f"[Rank {RANK}] Max relative error: {max_rel_error:.6f}") |
| 243 | + dist.barrier() |
| 244 | + |
| 245 | + |
def bench(func, *args, warmup=20, bench_iters=10):
    """Benchmark ``func(*args)`` on the GPU and return the average time per call (ms).

    Args:
        func: callable to benchmark (e.g. the compiled TileLang kernel).
        *args: positional arguments forwarded to ``func``; positions 2-5 must
            be the four uint64 signal tensors, which are reset to zero before
            every call so each iteration starts from a clean handshake state.
        warmup: number of untimed warmup iterations (keyword-only; default 20,
            matching the previous hard-coded value).
        bench_iters: number of timed iterations (keyword-only; default 10,
            matching the previous hard-coded value).

    Returns:
        Average wall time per timed iteration in milliseconds, measured with
        CUDA events (GPU stream time, not host time).
    """
    # Queue a long device-side spin so subsequent launches pile up behind it
    # and launch overhead is hidden.  NOTE(review): torch.cuda._sleep is a
    # private API — confirm it exists on the targeted PyTorch version.
    torch.cuda._sleep(1000000000)

    def preprocess():
        # clear signals: the kernel's progress counters accumulate across
        # calls, so they must be zeroed for every fresh invocation.
        args[2].fill_(0)
        args[3].fill_(0)
        args[4].fill_(0)
        args[5].fill_(0)

    # warmup (untimed)
    for _ in range(warmup):
        preprocess()
        func(*args)

    st = torch.cuda.Event(enable_timing=True)
    ed = torch.cuda.Event(enable_timing=True)
    # bench: events are recorded on the stream, so elapsed_time measures the
    # GPU-side duration between them once synchronize() has drained the stream.
    st.record()
    for _ in range(bench_iters):
        preprocess()
        func(*args)
    ed.record()
    torch.cuda.synchronize()
    avg_time = st.elapsed_time(ed) / bench_iters

    return avg_time
| 274 | + |
| 275 | + |
def reduce_local_time(local_time):
    """Average a per-rank timing across all ranks.

    Sums ``local_time`` over the process group onto rank 0 and divides by the
    world size.  Returns the mean (float, same units as the input) on rank 0
    and ``None`` on every other rank.
    """
    timing = torch.tensor([local_time], dtype=torch.float32).to("cuda")
    dist.reduce(timing, dst=0, op=dist.ReduceOp.SUM)
    if dist.get_rank() != 0:
        # Only the destination rank holds the reduced value.
        return None
    return (timing / dist.get_world_size()).item()
| 284 | + |
| 285 | + |
# Benchmark entry point.  Guarded because M, N, K, kernel, A, B, the signal
# tensors and C_tilelang are all created inside the `if __name__ == "__main__"`
# block above: running these statements at import time raised NameError.
if __name__ == "__main__":
    total_flops = 2 * M * N * K  # multiply-add per output element
    avg_time = reduce_local_time(
        bench(kernel, A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang))

    # reduce_local_time returns the mean only on rank 0 (None elsewhere).
    if RANK == 0:
        print(f"avg time of RANK {RANK}: {avg_time} ms")
        print(f"TFlops: {total_flops / avg_time * 1e-9} TFlops")