Skip to content

Commit 60f16a1

Browse files
committed
[Dev] Update SUMMA example
1 parent a499425 commit 60f16a1

File tree

1 file changed

+292
-0
lines changed

1 file changed

+292
-0
lines changed
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
# Copyright (c) Tile-AI Corporation.
2+
# Licensed under the MIT License.
3+
4+
import torch
5+
import torch.distributed as dist
6+
import pynvshmem
7+
import tilelang
8+
import tilelang.language as T
9+
from tilelang.distributed.utils import init_distributed, dtype_map, dsize_map
10+
import math
11+
import argparse
12+
13+
tilelang.disable_cache()  # always recompile: convenient while this example is under development
14+
15+
16+
def summa(MESH, M, N, K, block_M, block_N, block_K, dtype="float16"):
    """Build a SUMMA-style distributed GEMM kernel for a MESH x MESH PE grid.

    Each PE holds an (M_local, K_local) shard of A and an (N_local, K_local)
    shard of B (B stored K-major, so the local GEMM uses transpose_B=True).
    The leading dimension of size 2 on A and B is a double buffer: iteration
    ``ko`` computes on slot ``ko % 2`` while NVSHMEM broadcasts fill slot
    ``(ko + 1) % 2`` for the next iteration.

    Args:
        MESH: side length of the square PE mesh (WORLD_SIZE == MESH * MESH).
        M, N, K: global problem sizes.
        block_M, block_N, block_K: tile sizes of the local GEMM.
        dtype: element dtype name; accumulation is always float32.

    Returns:
        A ``T.prim_func`` implementing the broadcast + local-GEMM pipeline.
    """
    M_local = T.ceildiv(M, MESH)
    N_local = T.ceildiv(N, MESH)
    K_local = T.ceildiv(K, MESH)
    accum_dtype = "float32"

    sm_num = 132  # 132 SMs for H100
    # Number of (block_M x block_N) output tiles each PE must produce.
    total_tiles = T.ceildiv(M_local, block_M) * T.ceildiv(N_local, block_N)

    @T.prim_func
    def main(
            A: T.Tensor((2, M_local, K_local), dtype),  # double-buffered A shard / receive slot
            B: T.Tensor((2, N_local, K_local), dtype),  # double-buffered B shard / receive slot
            A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"),  # "A panel arrived" counters
            A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"),  # "receivers done with A" counters
            B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"),
            B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"),
            C: T.Tensor((M_local, N_local), dtype),  # local output shard (read-modify-write)
    ):
        grid_size = T.min(sm_num, total_tiles)
        # Each CTA pushes a contiguous strip of rows of A / cols of B.
        A_rows_per_block = T.ceildiv(M_local, grid_size)
        B_cols_per_block = T.ceildiv(N_local, grid_size)
        # How many passes the fixed-size grid needs to cover all output tiles.
        waves = T.ceildiv(total_tiles, sm_num)
        with T.Kernel(grid_size, threads=256) as (block_id):
            mype = T.alloc_local([1], "int32")
            mype[0] = T.get_pe()

            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            tx = T.get_thread_binding(0)

            # PE coordinates: pe_mn = mesh row, pe_k = mesh column.
            pe_mn = mype[0] // MESH
            pe_k = mype[0] % MESH

            T.clear(C_local)
            for ko in T.serial(MESH):
                # broadcast A
                # The PE whose mesh column matches ko is this iteration's sender.
                if pe_k == ko:
                    if tx == 0:
                        # Wait until all receivers have consumed the buffer slot
                        # we are about to overwrite (flow control for the double buffer).
                        T.signal_wait_until(
                            T.address_of(A_signal_from[0]),
                            T.NVSHMEM_CMP_GE,
                            total_tiles * MESH * ko,
                        )
                    if block_id < T.ceildiv(M_local, A_rows_per_block):
                        # Push this CTA's row strip of A[ko % 2] into every mesh-row
                        # peer's *next* slot, bumping their A_signal_to on arrival.
                        for peer_k in T.serial(MESH):
                            T.putmem_signal_nbi_block(
                                T.address_of(A[(ko + 1) % 2, A_rows_per_block * block_id, 0]),
                                T.address_of(A[ko % 2, A_rows_per_block * block_id,
                                               0]), A_rows_per_block * K_local * dsize_map[dtype],
                                T.address_of(A_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD,
                                pe_mn * MESH + peer_k)

                # broadcast B
                if pe_k == ko:
                    if tx == 0:
                        T.signal_wait_until(
                            T.address_of(B_signal_from[0]),
                            T.NVSHMEM_CMP_GE,
                            total_tiles * MESH * ko,
                        )
                    if block_id < T.ceildiv(N_local, B_cols_per_block):
                        for peer_k in T.serial(MESH):
                            T.putmem_signal_nbi_block(
                                T.address_of(B[(ko + 1) % 2, B_cols_per_block * block_id, 0]),
                                T.address_of(B[ko % 2, B_cols_per_block * block_id,
                                               0]), B_cols_per_block * K_local * dsize_map[dtype],
                                T.address_of(B_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD,
                                pe_mn * MESH + peer_k)

                # TODO: check if __syncthreads() is needed
                # Block until all strips of this iteration's A and B panels have
                # landed (ko + 1 rounds of strip arrivals so far).
                T.signal_wait_until(
                    T.address_of(A_signal_to[0]),
                    T.NVSHMEM_CMP_GE,
                    (ko + 1) * T.ceildiv(M_local, A_rows_per_block),
                )
                T.signal_wait_until(
                    T.address_of(B_signal_to[0]),
                    T.NVSHMEM_CMP_GE,
                    (ko + 1) * T.ceildiv(N_local, B_cols_per_block),
                )

                for w in T.serial(waves):
                    # Linear tile id -> (bx, by) output-tile coordinates.
                    bx = (grid_size * w + block_id) // T.ceildiv(N_local, block_N)
                    by = (grid_size * w + block_id) % T.ceildiv(N_local, block_N)

                    if bx < T.ceildiv(M_local, block_M) and by < T.ceildiv(N_local, block_N):
                        # Accumulate across ko iterations through global C:
                        # load the running partial, add this panel's product, store back.
                        # NOTE(review): this reads C on ko == 0 too — C must be
                        # zero-initialized by the host before launch; confirm.
                        T.copy(C[bx * block_M, by * block_N], C_local)
                        for ki in T.Pipelined(T.ceildiv(K_local, block_K), num_stages=4):
                            T.copy(A[ko % 2, bx * block_M, ki * block_K], A_shared)
                            T.copy(B[ko % 2, by * block_N, ki * block_K], B_shared)
                            T.gemm(A_shared, B_shared, C_local, transpose_B=True)

                        T.copy(C_local, C[bx * block_M, by * block_N])
                        if tx == 0:
                            # Tell next A sender
                            a_sender = pe_mn * MESH + (ko + 1) % MESH
                            T.signal_op(
                                T.address_of(A_signal_from[0]),
                                1,
                                T.NVSHMEM_SIGNAL_ADD,
                                a_sender,
                            )
                            # Tell next B sender
                            b_sender = pe_mn * MESH + (ko + 1) % MESH
                            T.signal_op(
                                T.address_of(B_signal_from[0]),
                                1,
                                T.NVSHMEM_SIGNAL_ADD,
                                b_sender,
                            )

    return main
133+
134+
135+
def parse_args(argv=None):
    """Parse command-line options for the SUMMA example.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) keeps
            the original behavior of reading ``sys.argv[1:]``; passing an
            explicit list makes the function usable (and testable) without
            touching process-global state.

    Returns:
        argparse.Namespace with fields M, N, K, warmup, iters, dtype.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--M", default=16384, type=int)
    parser.add_argument("--N", default=16384, type=int)
    parser.add_argument("--K", default=16384, type=int)
    parser.add_argument("--warmup", default=20, type=int, help="warmup iterations")
    parser.add_argument("--iters", default=100, type=int, help="perf iterations")
    parser.add_argument("--dtype", default="float16", type=str, help="data type")
    return parser.parse_args(argv)
144+
145+
146+
if __name__ == "__main__":
    # init
    args = parse_args()

    WORLD_SIZE, RANK, LOCAL_RANK = init_distributed()

    # SUMMA needs a square PE mesh: WORLD_SIZE must be a perfect square.
    MESH = math.ceil(math.sqrt(WORLD_SIZE))
    assert MESH * MESH == WORLD_SIZE, "Mesh size must match world size"

    M, N, K = args.M, args.N, args.K
    block_M, block_N, block_K = 128, 256, 64
    dtype = dtype_map[args.dtype]  # dtype name -> torch dtype (tilelang utils)

    # Per-PE shard sizes (host-side mirror of the kernel's T.ceildiv).
    M_local = math.ceil(M / MESH)
    N_local = math.ceil(N / MESH)
    K_local = math.ceil(K / MESH)

    func = summa(MESH, M, N, K, block_M, block_N, block_K, args.dtype)
    # TMA lowering and warp specialization are disabled for this kernel;
    # presumably they conflict with the NVSHMEM intrinsics — TODO confirm.
    kernel = tilelang.compile(
        func, pass_configs={
            "tl.disable_tma_lower": True,
            "tl.disable_warp_specialized": True
        })

    # Get CUDA Source
    if RANK == 0:
        print(kernel.get_kernel_source())

    device = torch.device(f"cuda:{RANK}")
    # Per-rank destination buffers for the scattered reference tiles.
    ref = torch.empty((M_local, N_local), dtype=dtype, device=device)
    A_ref = torch.empty((M_local, K_local), dtype=dtype, device=device)
    B_ref = torch.empty((N_local, K_local), dtype=dtype, device=device)

    # Rank 0 builds the full problem and the golden result, then scatters
    # the 2D-block-partitioned tiles: rank r owns mesh cell (r // MESH, r % MESH).
    if RANK == 0:
        A = torch.randn(M, K, dtype=dtype, device=device)
        B = torch.randn(N, K, dtype=dtype, device=device)
        C = A @ B.T  # golden reference (B stored K-major, hence the transpose)

        c_scatter_list = []
        a_scatter_list = []
        b_scatter_list = []
        for r in range(WORLD_SIZE):
            rr, cc = divmod(r, MESH)
            c_tile = C[M_local * rr:M_local * (rr + 1), N_local * cc:N_local * (cc + 1)]
            a_tile = A[M_local * rr:M_local * (rr + 1), K_local * cc:K_local * (cc + 1)]
            # B tile indexing swaps (rr, cc): the B shard is partitioned along
            # (N by mesh column, K by mesh row).
            b_tile = B[N_local * cc:N_local * (cc + 1), K_local * rr:K_local * (rr + 1)]

            c_scatter_list.append(c_tile.contiguous())
            a_scatter_list.append(a_tile.contiguous())
            b_scatter_list.append(b_tile.contiguous())
    else:
        c_scatter_list = None
        a_scatter_list = None
        b_scatter_list = None

    dist.scatter(tensor=ref, scatter_list=c_scatter_list, src=0)
    dist.scatter(tensor=A_ref, scatter_list=a_scatter_list, src=0)
    dist.scatter(tensor=B_ref, scatter_list=b_scatter_list, src=0)
    dist.barrier()

    # NVSHMEM symmetric buffers; slot 0 holds the local shard, slot 1 is the
    # double-buffer receive slot used by the kernel's broadcasts.
    A = pynvshmem.nvshmem_create_tensor([2, M_local, K_local], dtype)
    B = pynvshmem.nvshmem_create_tensor([2, N_local, K_local], dtype)
    A[0, :, :].copy_(A_ref)
    B[0, :, :].copy_(B_ref)
    A_signal_to = pynvshmem.nvshmem_create_tensor([math.ceil(M / block_M)], torch.uint64)
    A_signal_from = pynvshmem.nvshmem_create_tensor([math.ceil(M / block_M)], torch.uint64)
    B_signal_to = pynvshmem.nvshmem_create_tensor([math.ceil(N / block_N)], torch.uint64)
    B_signal_from = pynvshmem.nvshmem_create_tensor([math.ceil(N / block_N)], torch.uint64)
    A_signal_to.fill_(0)
    A_signal_from.fill_(0)
    B_signal_to.fill_(0)
    B_signal_from.fill_(0)
    # NOTE(review): C_tilelang is never zero-filled here, yet the kernel reads
    # C before accumulating on every ko iteration (including ko == 0). Unless
    # nvshmem_create_tensor guarantees zeroed memory, the output contains
    # whatever garbage was in the allocation — confirm and add fill_(0) if not.
    C_tilelang = pynvshmem.nvshmem_create_tensor([M_local, N_local], dtype)

    kernel(A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang)

    # Rank-serialized correctness report (barriers keep the output ordered).
    for r in range(WORLD_SIZE):
        dist.barrier()
        if r == RANK:
            if torch.allclose(C_tilelang, ref, rtol=1e-2, atol=1e-2):
                print('-' * 100)
                print(f"[Rank {RANK}] ✅ Tilelang and Torch match")
            else:
                abs_error = torch.abs(C_tilelang - ref)
                rel_error = abs_error / (torch.abs(ref) + 1e-8)

                max_abs_error = abs_error.max().item()
                max_rel_error = rel_error.max().item()
                # Same tolerance shape as torch.allclose: atol + rtol * |ref|.
                mismatch_ratio = (abs_error > (1e-2 + 1e-2 * torch.abs(ref))).float().mean().item()

                print('-' * 100)
                print(f"[Rank {RANK}] ❌ Tilelang and Torch mismatch")
                print(f"[Rank {RANK}] ref:\n{ref}")
                print(f"[Rank {RANK}] tilelang:\n{C_tilelang}")
                print(f"[Rank {RANK}] Mismatch ratio: {mismatch_ratio:.4f}")
                print(f"[Rank {RANK}] Max absolute error: {max_abs_error:.6f}")
                print(f"[Rank {RANK}] Max relative error: {max_rel_error:.6f}")
    dist.barrier()

    def bench(func, *args):
        """Time `func(*args)` on CUDA events; returns average ms per call.

        NOTE(review): the hard-coded bench_iters=10 and warmup of 20 ignore
        args.iters / args.warmup from parse_args — confirm whether intended.
        """
        bench_iters = 10
        torch.cuda._sleep(1000000000)  # private API: keeps the GPU busy so queued work launches back-to-back

        def preprocess():
            # clear signals
            # (args[2:6] are the four uint64 signal tensors; the kernel's
            # wait thresholds assume they start from zero each run)
            args[2].fill_(0)
            args[3].fill_(0)
            args[4].fill_(0)
            args[5].fill_(0)

        # warmup
        for _ in range(20):
            preprocess()
            _ = func(*args)

        st = torch.cuda.Event(enable_timing=True)
        ed = torch.cuda.Event(enable_timing=True)
        # bench
        # NOTE(review): preprocess() runs inside the timed region, so the
        # signal-reset fills are included in the reported time.
        st.record()
        for _ in range(bench_iters):
            preprocess()
            _ = func(*args)
        ed.record()
        torch.cuda.synchronize()
        avg_time = st.elapsed_time(ed) / bench_iters  # elapsed_time is in ms

        return avg_time

    def reduce_local_time(local_time):
        """Average a per-rank time across all ranks; returns the mean on rank 0, None elsewhere."""
        tensor = torch.tensor([local_time], dtype=torch.float32).to("cuda")
        dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
        if dist.get_rank() == 0:
            world_size = dist.get_world_size()
            mean_time = (tensor / world_size).item()
            return mean_time
        return None

    total_flops = 2 * M * N * K
    avg_time = reduce_local_time(
        bench(kernel, A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang))

    if RANK == 0:
        # NOTE(review): avg_time is the mean across ranks, though the message
        # names only this rank.
        print(f"avg time of RANK {RANK}: {avg_time} ms")
        # avg_time is in ms: flops/ms * 1e-9 == TFLOP/s.
        print(f"TFlops: {total_flops / avg_time * 1e-9} TFlops")

0 commit comments

Comments
 (0)