

 @tilelang.jit(pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True})
-def copy_and_barrier_all_intra_node_kernel(local_rank,
-                                           rank,
-                                           num_ranks,
-                                           M,
-                                           K,
-                                           block_M,
-                                           block_K,
-                                           threads,
-                                           dtype="float16"):
-
-    M_per_rank = T.ceildiv(M, num_ranks)
-    sm_num = driver.get_num_sms()
-    m_blocks = T.ceildiv(M_per_rank, block_M)
-    k_blocks = T.ceildiv(K, block_K)
-    waves = T.ceildiv(m_blocks * k_blocks, sm_num)
-
-    @T.macro
-    def copy_kernel(src: T.Tensor((M_per_rank, K), dtype), dst: T.Tensor((M, K), dtype),
-                    data_shared: T.Tensor((block_M, block_K), dtype), block_id):
-        for w in T.serial(waves):
-            tile_id = sm_num * w + block_id
-            bx = tile_id % m_blocks
-            by = tile_id // m_blocks
-
-            if by < k_blocks:
-                T.copy(src[bx * block_M, by * block_K], data_shared)
-                T.copy(data_shared, dst[rank * M_per_rank + bx * block_M, by * block_K])
-
-    @T.macro
-    def barrier_all_intra_node_non_atomic(
-            sync_buffer: T.Tensor((3 * num_ranks), "uint32"), block_id):
-        if block_id == 0:
-            T.barrier_all_blocks_sys(sync_buffer)
-        # barrier all CTAs
-        T.sync_grid(sync_buffer[2 * num_ranks])
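+# The new kernel only initializes the per-rank readiness flags: this rank's slot is
+# marked ready (1), every peer slot pending (0).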
+def set_signal_kernel(local_rank, num_local_ranks, threads):

     @T.prim_func
-    def local_copy(
-            A: T.Tensor((M_per_rank, K), dtype),
-            ag_buffer: T.Tensor((M, K), dtype),
-            signal_buffer: T.Tensor((num_ranks), "uint32"),
-            sync_buffer: T.Tensor((3 * num_ranks), "uint32"),
-    ):
-        with T.Kernel(sm_num, threads=threads) as (block_id):
-            data_shared = T.alloc_shared((block_M, block_K), dtype)
-            T.annotate_layout({data_shared: tilelang.layout.make_swizzled_layout(data_shared)})
-
-            barrier_all_intra_node_non_atomic(sync_buffer, block_id)
-            copy_kernel(A, ag_buffer, data_shared, block_id)
+    def _set_signal_kernel(signal_buffer: T.Tensor((num_local_ranks), "uint32"),):
+        with T.Kernel(1, threads=threads):
             tx = T.get_thread_binding(0)
-            if block_id == 0 and tx < num_ranks:  # set symm barrier
-                if tx == rank:
+            if tx < num_local_ranks:
+                if tx == local_rank:
                     signal_buffer[tx] = 1
                 else:
                     signal_buffer[tx] = 0
-            barrier_all_intra_node_non_atomic(sync_buffer, block_id)

-    return local_copy
+    return _set_signal_kernel


 @tilelang.jit
 def gemm_kernel(M,
                 N,
                 K,
-                num_rank,
                 local_rank,
+                num_local_rank,
                 block_M,
                 block_N,
                 block_K,
                 threads,
+                persistent=False,
                 dtype="float16",
                 accum_dtype="float"):

-    M_per_rank = T.ceildiv(M, num_rank)
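+    # persistent schedule: size the tile grid and count the waves needed to cover it on sm_num SMs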
+    sm_num = driver.get_num_sms()
+    m_blocks = T.ceildiv(M, block_M)
+    n_blocks = T.ceildiv(N // num_local_rank, block_N)
+    waves = T.ceildiv(m_blocks * n_blocks, sm_num)
+    M_per_rank = T.ceildiv(M, num_local_rank)
     GROUP_SIZE_M = 8

     @T.prim_func
     def main(
             A: T.Tensor((M, K), dtype),
-            B: T.Tensor((K, N // num_rank), dtype),
-            signal_buffer: T.Tensor((num_rank), "uint32"),
-            C: T.Tensor((M, N // num_rank), dtype),
+            B: T.Tensor((K, N // num_local_rank), dtype),
+            signal_buffer: T.Tensor((num_local_rank), "uint32"),
+            C: T.Tensor((M, N // num_local_rank), dtype),
     ):
         with T.Kernel(
-                T.ceildiv(M, block_M) * T.ceildiv(N // num_rank, block_N),
+                T.ceildiv(M, block_M) * T.ceildiv(N // num_local_rank, block_N),
                 threads=threads) as (bid):
             A_shared = T.alloc_shared((block_M, block_K), dtype)
             B_shared = T.alloc_shared((block_K, block_N), dtype)
             C_shared = T.alloc_shared((block_M, block_N), dtype)
             C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

             num_pid_m = T.ceildiv(M, block_M)
-            num_pid_n = T.ceildiv(N // num_rank, block_N)
+            num_pid_n = T.ceildiv(N // num_local_rank, block_N)
             num_pid_in_group = GROUP_SIZE_M * num_pid_n
             group_id = bid // num_pid_in_group
             first_pid_m = group_id * GROUP_SIZE_M
@@ -140,55 +100,94 @@ def main(
             T.copy(C_local, C_shared)
             T.copy(C_shared, C[pid_m * block_M, pid_n * block_N])

-    return main
+    @T.prim_func
+    def main_persistent(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((K, N // num_local_rank), dtype),
+            signal_buffer: T.Tensor((num_local_rank), "uint32"),
+            C: T.Tensor((M, N // num_local_rank), dtype),
+    ):
+        with T.Kernel(sm_num, threads=threads) as (bid):
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_K, block_N), dtype)
+            C_shared = T.alloc_shared((block_M, block_N), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            for w in T.serial(waves):
+                tile_id = bid + w * sm_num
+                num_pid_m = T.ceildiv(M, block_M)
+                num_pid_n = T.ceildiv(N // num_local_rank, block_N)
+                num_pid_in_group = GROUP_SIZE_M * num_pid_n
+                group_id = tile_id // num_pid_in_group
+                first_pid_m = group_id * GROUP_SIZE_M
+                group_size_m = T.min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+                pid_m_ = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m)
+                pid_n_ = (tile_id % num_pid_in_group) // group_size_m
+
+                # threadblock swizzle
+                # no stream-k support. only split by m x n
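+                # rotate the m index so each rank starts on its own (already local) rows,
+                # letting remote tiles overlap with the in-flight all-gather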
+                m_offset = M_per_rank * local_rank
+                pid_m_offset = T.ceildiv(m_offset, block_M)
+                pid_m = (pid_m_ + pid_m_offset) % num_pid_m
+                pid_n = pid_n_
+
+                if pid_n_ * block_N < (N // num_local_rank) and pid_m_ * block_M < M:
+                    tid = T.get_thread_binding(0)
+                    T.clear(C_local)
+                    if tid == 0:
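+                        # spin until the rows this tile reads have landed in ag_buffer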
+                        T.wait_eq(signal_buffer[pid_m * block_M // M_per_rank], 1)
+                    for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
+                        T.copy(A[pid_m * block_M, k * block_K], A_shared)
+                        T.copy(B[k * block_K, pid_n * block_N], B_shared)
+                        T.gemm(A_shared, B_shared, C_local)
+                    T.copy(C_local, C_shared)
+                    T.copy(C_shared, C[pid_m * block_M, pid_n * block_N])
+
+    return main if not persistent else main_persistent


 def cp_engine_producer_all_gather_full_mesh_pull(
-        local_tensor,
         ag_buffer,
         signal_buffer,
         M_per_rank,
-        N,
         signal_target,
-        rank,
+        local_rank,
         local_world_size,
-        world_size,
         intranode_ag_stream,
 ):
-    rank_orders = [(rank + i) % local_world_size for i in range(local_world_size)]
+    rank_orders = [(local_rank + i) % local_world_size for i in range(local_world_size)]
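+    # staggered visiting order so the ranks don't all pull from the same source at once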

     with torch.cuda.stream(intranode_ag_stream):
         for src_rank in rank_orders:
-            if src_rank == rank:
+            if src_rank == local_rank:
                 continue
-            dst = ag_buffer[rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
+            dst = ag_buffer[local_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
             src = ag_buffer[src_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
             dst.copy_(src)

             (err,) = cuda.cuStreamWriteValue32(
                 intranode_ag_stream.cuda_stream,
-                signal_buffer[rank][src_rank].data_ptr(),
+                signal_buffer[local_rank][src_rank].data_ptr(),
                 signal_target,
                 cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
             )


-def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, sync_buffer, M_per_rank, N, signal_target, rank,
-               group, local_world_size, world_size, local_copy_kernel, gemm_kernel, gemm_stream,
-               ag_stream):
+def ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, N, signal_target, local_rank,
+               local_world_size, set_signal_kernel, gemm_kernel, gemm_stream, ag_stream):

     with torch.cuda.stream(gemm_stream):
-        local_copy_kernel(
-            A, ag_buffer[rank], signal_buffer[rank], sync_buffer, stream=gemm_stream.cuda_stream)
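+        # mark our own shard ready (1) and all peer shards pending (0)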
+        set_signal_kernel(signal_buffer[local_rank], stream=gemm_stream.cuda_stream)

     ag_stream.wait_stream(gemm_stream)

-    cp_engine_producer_all_gather_full_mesh_pull(A, ag_buffer, signal_buffer, M_per_rank, N,
-                                                 signal_target, rank, local_world_size, world_size,
+    cp_engine_producer_all_gather_full_mesh_pull(ag_buffer, signal_buffer, M_per_rank,
+                                                 signal_target, local_rank, local_world_size,
                                                  ag_stream)

     with torch.cuda.stream(gemm_stream):
-        gemm_kernel(ag_buffer[rank], B, signal_buffer[rank], C, stream=gemm_stream.cuda_stream)
+        gemm_kernel(
+            ag_buffer[local_rank], B, signal_buffer[local_rank], C, stream=gemm_stream.cuda_stream)

     gemm_stream.wait_stream(ag_stream)
     current_stream = torch.cuda.current_stream()
@@ -212,6 +211,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     M = args.M if args else 8192
     N = args.N if args else 8192
     K = args.K if args else 8192
+    persistent = args.persistent
     M_per_rank = M // num_local_ranks
     N_per_rank = N // num_local_ranks

@@ -221,48 +221,45 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     threads = 256

     rank, num_ranks, group = init_dist(local_rank, num_local_ranks)
+    assert rank == local_rank and num_ranks == num_local_ranks, "only support single node for now"
     allocator = tilelang.get_allocator(
         size=2**30,
         device="cuda",
         is_distributed=True,
         local_rank=local_rank,
         num_local_ranks=num_local_ranks,
         group=group)
-    kernel = gemm_kernel(M, N, K, num_ranks, rank, BLOCK_M, BLOCK_N, BLOCK_K, threads)
-    local_copy_kernel = copy_and_barrier_all_intra_node_kernel(
+    gemm_func = gemm_kernel(M, N, K, local_rank, num_local_ranks, BLOCK_M, BLOCK_N, BLOCK_K,
+                            threads, persistent)
+    set_signal_func = set_signal_kernel(
         local_rank=local_rank,
-        rank=local_rank,
-        num_ranks=num_ranks,
-        M=M,
-        K=K,
-        block_M=64,
-        block_K=64,
-        threads=128,
+        num_local_ranks=num_local_ranks,
+        threads=32,
     )
-    kernel.initialize(allocator=allocator)
-    local_copy_kernel.initialize(allocator=allocator)
+    gemm_func.initialize(allocator=allocator)
+    set_signal_func.initialize(allocator=allocator)
     if local_rank == 1:
-        print(kernel.get_kernel_source())
-        print(local_copy_kernel.get_kernel_source())
+        print(gemm_func.get_kernel_source())
+        print(set_signal_func.get_kernel_source())

-    A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_()
     B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_()
     C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator)
     ag_buffer = tilelang.tensor((M, K), dtype, allocator=allocator, return_peers=True)
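+    # A is a view of this rank's shard inside the symmetric ag_buffer, so no local copy kernel is needed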
+    A = ag_buffer[local_rank][M_per_rank * local_rank:M_per_rank * (local_rank + 1), :].normal_()
     signal_buffer = tilelang.tensor((num_local_ranks,),
                                     torch.uint32,
                                     allocator=allocator,
                                     return_peers=True)
-    signal_buffer[rank].fill_(0)  # check if needed
-    sync_buffer = tilelang.tensor((3 * num_ranks,), torch.uint32, allocator=allocator)

     gemm_stream = torch.cuda.Stream()
     ag_stream = torch.cuda.Stream(priority=-1)
     signal_target = 1

-    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, sync_buffer, M_per_rank, K,
-                            signal_target, rank, group, num_local_ranks, num_local_ranks,
-                            local_copy_kernel, kernel, gemm_stream, ag_stream)
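+    # sync all ranks so every peer's buffers exist before the first pull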
+    dist.barrier()
+
+    tilelang_C = ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target,
+                            local_rank, num_local_ranks, set_signal_func, gemm_func, gemm_stream,
+                            ag_stream)

     torch_ag_buffer = torch.empty([M, K], dtype=dtype, device="cuda")
     torch_C = torch_ag_gemm(group, A, B, torch_ag_buffer)
@@ -273,10 +270,10 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
         print(f"rank {local_rank} check failed.❌")
         print(f"torch_C: {torch_C}, tilelang_C: {tilelang_C}")

-    tl_out, tl_t = perf_fn(
-        lambda: ag_gemm_op(A, B, C, ag_buffer, signal_buffer, sync_buffer, M_per_rank, K,
-                           signal_target, rank, group, num_local_ranks, num_local_ranks,
-                           local_copy_kernel, kernel, gemm_stream, ag_stream),
+    _, tl_t = perf_fn(
+        lambda:
+        ag_gemm_op(A, B, C, ag_buffer, signal_buffer, M_per_rank, K, signal_target, local_rank,
+                   num_local_ranks, set_signal_func, gemm_func, gemm_stream, ag_stream),
         warmup=5,
         rep=10)

@@ -294,6 +291,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
     parser.add_argument('--M', type=int, default=8192, help='M dimension')
     parser.add_argument('--N', type=int, default=28672, help='N dimension')
     parser.add_argument('--K', type=int, default=8192, help='K dimension')
+    parser.add_argument('--persistent', action='store_true', help='Use persistent kernel')
     args = parser.parse_args()
     num_processes = args.num_processes