
Commit e7b0807

[Example] Specify a fixed commit for the flash-linear-attention repository and optimize nsa examples (tile-ai#913)
- Updated the requirements.txt to specify a fixed commit for the flash-linear-attention repository.
- Refactored import paths in benchmark_nsa_fwd.py for better organization.
- Added a new function to generate configurations for autotuning.
- Modified the tilelang_sparse_attention function to accept parameters for block size, number of stages, and threads, enhancing flexibility.
- Changed allocation of shared memory for accumulators to optimize performance.
1 parent e8faffa commit e7b0807

File tree

3 files changed (+27, -16 lines)

.gitignore

Lines changed: 3 additions & 0 deletions

@@ -92,3 +92,6 @@ tilelang/jit/adapter/cython/.cycache
 
 # cache directory for clangd
 .cache/
+
+# claude
+**/.claude

examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py

Lines changed: 23 additions & 15 deletions

@@ -10,7 +10,7 @@
 from einops import rearrange, repeat
 import triton
 import triton.language as tl
-from fla.ops.common.utils import prepare_token_indices
+from fla.ops.utils import prepare_token_indices
 from fla.utils import autocast_custom_fwd, contiguous
 
 
@@ -439,6 +439,20 @@ def naive_nsa(q: torch.Tensor,
     return o_slc.to(dtype) + o_swa.to(dtype) if o_swa is not None else o_slc.to(dtype)
 
 
+def get_configs():
+    import itertools
+    iter_params = dict(
+        block_T=[128, 256, 512],
+        num_stages=[0, 1, 2, 4, 5],
+        threads=[32, 64, 128, 256, 512],
+    )
+    return [{
+        k: v for k, v in zip(iter_params, values)
+    } for values in itertools.product(*iter_params.values())]
+
+
+@tilelang.autotune(configs=get_configs(),)
+@tilelang.jit
 def tilelang_sparse_attention(batch,
                               heads,
                               seq_len,
@@ -447,7 +461,10 @@ def tilelang_sparse_attention(batch,
                               scale=None,
                               block_size=64,
                               groups=1,
-                              selected_blocks=16):
+                              selected_blocks=16,
+                              block_T=128,
+                              num_stages=2,
+                              threads=32):
     if scale is None:
         scale = (1.0 / dim)**0.5 * 1.44269504  # log2(e)
     else:
@@ -461,7 +478,7 @@ def tilelang_sparse_attention(batch,
     dtype = "float16"
     accum_dtype = "float"
     block_S = block_size
-    block_T = min(128, tilelang.math.next_power_of_2(dim))
+    block_T = min(block_T, tilelang.math.next_power_of_2(dim))
 
     NK = tilelang.cdiv(dim, block_T)
     NV = tilelang.cdiv(dim, block_T)
@@ -471,8 +488,6 @@ def tilelang_sparse_attention(batch,
     G = groups
     BS = block_S
     BK = BV = block_T
-    num_stages = 2
-    threads = 32
 
     @T.prim_func
     def tilelang_sparse_attention(
@@ -489,19 +504,15 @@ def tilelang_sparse_attention(
             O_shared = T.alloc_shared([G, BV], dtype)
 
             acc_s = T.alloc_fragment([G, BS], accum_dtype)
-            acc_s_cast = T.alloc_fragment([G, BS], dtype)
+            acc_s_cast = T.alloc_shared([G, BS], dtype)
             acc_o = T.alloc_fragment([G, BV], accum_dtype)
             scores_max = T.alloc_fragment([G], accum_dtype)
             scores_max_prev = T.alloc_fragment([G], accum_dtype)
             scores_scale = T.alloc_fragment([G], accum_dtype)
             scores_sum = T.alloc_fragment([G], accum_dtype)
             logsum = T.alloc_fragment([G], accum_dtype)
 
-            # T.use_swizzle(10)
-
-            T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
-            T.annotate_layout({K_shared: tilelang.layout.make_swizzled_layout(K_shared)})
-            T.annotate_layout({V_shared: tilelang.layout.make_swizzled_layout(V_shared)})
+            T.annotate_layout({O_shared: tilelang.layout.make_swizzled_layout(O_shared)})
 
             i_t, i_v, i_bh = bx, by, bz
             i_b, i_h = i_bh // head_kv, i_bh % head_kv
@@ -597,7 +608,7 @@ def benchmark_nsa(batch_size,
     torch.random.manual_seed(0)
 
     # Compile the NSA kernel
-    program = tilelang_sparse_attention(
+    kernel = tilelang_sparse_attention(
         batch=batch_size,
         heads=head_query,
         seq_len=seq_len,
@@ -608,9 +619,6 @@ def benchmark_nsa(batch_size,
         selected_blocks=selected_blocks,
         scale=scale,
     )
-    print(program)
-    kernel = tilelang.compile(program, out_idx=None, execution_backend="cython")
-    print(kernel.get_kernel_source())
 
     profiler = kernel.get_profiler()
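
The zip / itertools.product idiom in get_configs() above builds one keyword dictionary per point in the 3 x 5 x 5 search space, i.e. 75 candidate configurations for the autotuner. Below is a standalone sketch of that enumeration with no tilelang dependency; the parameter lists are copied from the diff above, dict(zip(...)) is equivalent to the dict comprehension used there, and the prints are illustrative only.

import itertools

def get_configs():
    # Same tuning space as in benchmark_nsa_fwd.py above:
    # tile size along T, number of pipeline stages, and thread count.
    iter_params = dict(
        block_T=[128, 256, 512],
        num_stages=[0, 1, 2, 4, 5],
        threads=[32, 64, 128, 256, 512],
    )
    # Cartesian product over the value lists, re-keyed with the parameter names.
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

configs = get_configs()
print(len(configs))  # 75 == 3 * 5 * 5
print(configs[0])    # {'block_T': 128, 'num_stages': 0, 'threads': 32}
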
requirements.txt

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-git+https://github.com/fla-org/flash-linear-attention
+git+https://github.com/fla-org/flash-linear-attention@c3bd56589033610264532b11f0972c69e4645f6e
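
The pinned revision and the import change in benchmark_nsa_fwd.py go together: the benchmark now imports prepare_token_indices from fla.ops.utils rather than fla.ops.common.utils, so it needs a flash-linear-attention checkout that provides that path. For environments that cannot use the pinned commit, a defensive import along the following lines keeps the benchmark importable across both layouts; this is only a sketch, not part of this commit, and both module paths are taken from the diff above.

# Sketch only: prefer the newer module path, fall back to the older one.
try:
    from fla.ops.utils import prepare_token_indices  # newer flash-linear-attention layout
except ImportError:
    from fla.ops.common.utils import prepare_token_indices  # older layout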
