 from einops import rearrange, repeat
 import triton
 import triton.language as tl
-from fla.ops.common.utils import prepare_token_indices
+from fla.ops.utils import prepare_token_indices
 from fla.utils import autocast_custom_fwd, contiguous
 
 
@@ -439,6 +439,20 @@ def naive_nsa(q: torch.Tensor,
     return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
 
 
+def get_configs():
+    import itertools
+    iter_params = dict(
+        block_T=[128, 256, 512],
+        num_stages=[0, 1, 2, 4, 5],
+        threads=[32, 64, 128, 256, 512],
+    )
+    return [{
+        k: v for k, v in zip(iter_params, values)
+    } for values in itertools.product(*iter_params.values())]
+
+
+@tilelang.autotune(configs=get_configs(),)
+@tilelang.jit
 def tilelang_sparse_attention(batch,
                               heads,
                               seq_len,
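For reference, the get_configs() helper added above enumerates the cartesian product of the three tuning axes, so @tilelang.autotune sweeps 3 x 5 x 5 = 75 candidate configurations. A minimal standalone sketch of the same enumeration (all names taken from the diff):

    import itertools
    iter_params = dict(block_T=[128, 256, 512], num_stages=[0, 1, 2, 4, 5], threads=[32, 64, 128, 256, 512])
    configs = [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]
    assert len(configs) == 75
    # e.g. configs[0] == {'block_T': 128, 'num_stages': 0, 'threads': 32}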
@@ -447,7 +461,10 @@ def tilelang_sparse_attention(batch,
                               scale=None,
                               block_size=64,
                               groups=1,
-                              selected_blocks=16):
+                              selected_blocks=16,
+                              block_T=128,
+                              num_stages=2,
+                              threads=32):
     if scale is None:
         scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     else:
@@ -461,7 +478,7 @@ def tilelang_sparse_attention(batch,
     dtype = "float16"
     accum_dtype = "float"
     block_S = block_size
-    block_T = min(128, tilelang.math.next_power_of_2(dim))
+    block_T = min(block_T, tilelang.math.next_power_of_2(dim))
 
     NK = tilelang.cdiv(dim, block_T)
     NV = tilelang.cdiv(dim, block_T)
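A quick worked example of the clamping above, assuming tilelang.math.next_power_of_2 rounds up to the nearest power of two and tilelang.cdiv is ceiling division:

    # dim = 128, tuned block_T = 256:
    #   block_T = min(256, next_power_of_2(128)) = min(256, 128) = 128
    #   NK = NV = cdiv(128, 128) = 1
    # dim = 128, tuned block_T = 128 leaves block_T unchanged, again giving NK = NV = 1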
@@ -471,8 +488,6 @@ def tilelang_sparse_attention(batch,
     G = groups
     BS = block_S
     BK = BV = block_T
-    num_stages = 2
-    threads = 32
 
     @T.prim_func
     def tilelang_sparse_attention(
@@ -489,19 +504,15 @@ def tilelang_sparse_attention(
             O_shared = T.alloc_shared([G, BV], dtype)
 
             acc_s = T.alloc_fragment([G, BS], accum_dtype)
-            acc_s_cast = T.alloc_fragment([G, BS], dtype)
+            acc_s_cast = T.alloc_shared([G, BS], dtype)
             acc_o = T.alloc_fragment([G, BV], accum_dtype)
             scores_max = T.alloc_fragment([G], accum_dtype)
             scores_max_prev = T.alloc_fragment([G], accum_dtype)
             scores_scale = T.alloc_fragment([G], accum_dtype)
             scores_sum = T.alloc_fragment([G], accum_dtype)
             logsum = T.alloc_fragment([G], accum_dtype)
 
-            # T.use_swizzle(10)
-
-            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
-            T.annotate_layout({K_shared: tilelang.layout.make_swizzled_layout(K_shared)})
-            T.annotate_layout({V_shared: tilelang.layout.make_swizzled_layout(V_shared)})
+            T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)})
 
             i_t, i_v, i_bh = bx, by, bz
             i_b, i_h = i_bh // head_kv, i_bh % head_kv
@@ -597,7 +608,7 @@ def benchmark_nsa(batch_size,
     torch.random.manual_seed(0)
 
     # Compile the NSA kernel
-    program = tilelang_sparse_attention(
+    kernel = tilelang_sparse_attention(
         batch=batch_size,
         heads=head_query,
         seq_len=seq_len,
@@ -608,9 +619,6 @@ def benchmark_nsa(batch_size,
         selected_blocks=selected_blocks,
         scale=scale,
     )
-    print(program)
-    kernel = tilelang.compile(program, out_idx=None, execution_backend="cython")
-    print(kernel.get_kernel_source())
 
     profiler = kernel.get_profiler()
 
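End to end, the new entry point compiles and autotunes in a single call, roughly as in the sketch below (problem sizes are hypothetical placeholders; it assumes, as the benchmark change above relies on, that the @tilelang.autotune/@tilelang.jit-decorated function returns a compiled kernel exposing get_profiler()):

    kernel = tilelang_sparse_attention(
        batch=1,             # hypothetical sizes, for illustration only
        heads=64,
        seq_len=8192,
        dim=128,             # dim is referenced in the kernel body; assumed to be part of the elided signature
        selected_blocks=16,
    )
    profiler = kernel.get_profiler()  # block_T / num_stages / threads are chosen by @tilelang.autotune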