
Commit dc2e90c

[Example] Adds example for top-k operation
Adds an example demonstrating the top-k operation using tilelang; a plain-PyTorch sketch of the same selection scheme follows below.
1 parent 9a86939 commit dc2e90c
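
For context, the kernel selects each row's top-k entries by repeatedly taking the row maximum, recording its column, masking every occurrence of that maximum, and moving to the next round. A rough plain-PyTorch sketch of that scheme is below; the helper name topk_by_masking is illustrative only and not part of this commit.

import torch


def topk_by_masking(logits: torch.Tensor, k: int):
    """Per-row top-k by repeated max-and-mask, mirroring the kernel's scheme."""
    work = logits.clone()
    gates, indices = [], []
    for _ in range(k):
        # Current row maximum and one column where it occurs.
        max_val, max_idx = work.max(dim=1)
        gates.append(max_val)
        indices.append(max_idx)
        # Mask every occurrence of the current maximum so the next round
        # picks the next-largest value, as the tilelang kernel does.
        work = torch.where(work == max_val.unsqueeze(1), torch.full_like(work, float("-inf")), work)
    return torch.stack(gates, dim=1), torch.stack(indices, dim=1)


logits = torch.rand(4, 8)
gates, idx = topk_by_masking(logits, 3)
ref_gates, ref_idx = logits.topk(3, dim=1)
torch.testing.assert_close(gates, ref_gates)  # matches torch.topk when row values are distinct

The kernel in the diff applies the same scheme per blk_m-row tile, masking with -10000.0 rather than -inf.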

1 file changed: +99 −0 lines

examples/topk/example_topk.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
import tilelang
import tilelang.language as T
import torch
import itertools

# tilelang.disable_cache()
# torch.manual_seed(42)

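# Autotuning search space: tile size along the row dimension and threads per block.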
def get_configs():
    iter_params = dict(
        blk_m=[64, 128, 256],
        threads=[128, 256, 512],
    )
    return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())]

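# Autotuned, JIT-compiled row-wise top-k kernel; out_idx=[1, 2] marks topk_gates
# and topk_indices as outputs allocated and returned by the JIT wrapper.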
@tilelang.autotune(configs=get_configs())
@tilelang.jit(out_idx=[1, 2])
def tl_topk(
    M,
    N,
    topk,
    blk_m,
    threads=128,
):
    dtype = "float32"

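    # One thread block handles a tile of blk_m rows; each full row of N logits
    # is kept in a register fragment while the top-k entries are extracted.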
    @T.prim_func
    def topk_kernel(
            logits: T.Tensor([M, N], dtype),
            topk_gates: T.Tensor([M, topk], dtype),
            topk_indices: T.Tensor([M, topk], "int32"),
    ):
        with T.Kernel(T.ceildiv(M, blk_m), threads=threads) as bx:
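            # Per-block working buffers held in register fragments.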
            logits_frag = T.alloc_fragment([blk_m, N], dtype=dtype)
            max_val = T.alloc_fragment([blk_m], dtype=dtype)
            expand_max_idx = T.alloc_fragment([blk_m, N], "int32")
            max_idx = T.alloc_fragment([blk_m], "int32")

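            # Load a blk_m x N tile of logits starting at row bx * blk_m.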
            T.copy(logits[bx * blk_m, 0], logits_frag)

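            # Extract the top-k entries per row by repeatedly taking the row
            # maximum and masking it out.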
            for k in T.serial(topk):
                T.fill(expand_max_idx, -1)
                T.reduce_max(logits_frag, max_val, dim=1, clear=True)

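                # Record the column index wherever the row maximum occurs; the
                # reduce_max below then keeps one index per row (the largest
                # matching column when there are ties).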
                for i, j in T.Parallel(blk_m, N):
                    expand_max_idx[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], j, expand_max_idx[i, j])

                T.reduce_max(expand_max_idx, max_idx, dim=1, clear=True)

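                # Mask out the entries just selected so the next iteration
                # finds the next-largest value in each row.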
                for i, j in T.Parallel(blk_m, N):
                    logits_frag[i, j] = T.if_then_else(max_val[i] == logits_frag[i, j], -10000.0, logits_frag[i, j])

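                # Write this round's gate value and index for every row in the tile.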
                for i in T.Parallel(blk_m):
                    topk_gates[bx * blk_m + i, k] = max_val[i]
                    topk_indices[bx * blk_m + i, k] = max_idx[i]

    return topk_kernel

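# PyTorch reference: torch.topk returns per-row values and indices in descending order.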
def ref_program(logits, top_k):
    top_k_gates, top_k_indices = logits.topk(top_k, dim=1)
    return top_k_gates, top_k_indices.to(torch.int32)

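# Build the autotuned kernel, check it against torch.topk, then benchmark it.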
def main():
    M = 320
    N = 128
    topk = 6

    logits = torch.rand(M, N).to("cuda")

    kernel = tl_topk(M=M, N=N, topk=topk)
    tl_gates, tl_indices = kernel(logits)
    # print(tl_gates)

    # print(kernel.get_kernel_source())
    print(kernel.config)

    torch_gates, torch_indices = ref_program(logits, topk)
    # print(torch_gates)

    # test accuracy
    torch.testing.assert_close(tl_gates, torch_gates)
    torch.testing.assert_close(tl_indices, torch_indices)

    # profile
    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
    tilelang_latency = profiler.do_bench()
    print(f"Tilelang latency: {tilelang_latency}")


if __name__ == "__main__":
    main()
