
Commit e59e919

[CI] Add hadamard example to CI (#549)

* [CI] Add hadamard example to CI
* Run yapf and ruff
* Run yapf and ruff

1 parent 4747edb commit e59e919

1 file changed: +160 −0 lines changed
@@ -0,0 +1,160 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

import tilelang
import tilelang.language as T
from tilelang.intrinsics import make_mma_swizzle_layout

import math
import argparse
import torch
from torch.nn import functional as F
import scipy


def is_pow_of_2(n):
    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0


def hadamard(b, n, dtype):
    assert is_pow_of_2(n), "n must be a power of 2"
    assert 2 <= n <= 32768, "n must be in [2, 32768]"
    elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype]

    logN = int(math.log2(n))
    threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN]
    thread_elem = n // threads  # Each thread is responsible for a chunk of elements
    thread_round = int(math.log2(thread_elem))

    warps = 1 if threads <= 32 else threads // 32
    warp_round = int(math.log2(threads / warps))
    warp_size = threads // warps

    block_round = int(math.log2(warps))

    exchange_round = n * elem_size // 32768 if n * elem_size > 32768 else 1  # Suppose we use 32KB shared memory at most
    thread_elem_in_smem = thread_elem // exchange_round if exchange_round > 1 else thread_elem
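    # When one row (n * elem_size bytes) exceeds the 32 KB shared-memory budget,
    # the block-level exchange in step 4 is split into exchange_round passes,
    # each staging thread_elem_in_smem elements per thread through shared memory.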

    # debug log
    # print(f'{threads=}, {thread_round=}')
    # print(f'{warps=}, {warp_round=}, {warp_size=}')
    # print(f'{block_round=}')
    # print(f'{exchange_round=}')

    @T.macro
    def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype),
                  round: int):
        tx = T.get_thread_binding(0)
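        # Round i pairs thread tx with partner tx ^ (1 << i): the pair exchanges
        # register chunks via warp shuffle and computes the butterfly
        # (a, b) -> (a + b, a - b); the thread whose i-th bit is 0 keeps the sum,
        # its partner keeps the difference.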
        for i in T.serial(round):
            tx_stride = 1 << i
            another_tx = tx ^ tx_stride
            sign = (
                tx >> i
            ) & 1  # get i-th lowest bit of tx, which determines the operation type for this thread's chunk

            for j in T.Pipelined(thread_elem, num_stages=1):
                buf[j] = T.tvm_warp_shuffle(
                    0xffffffff,  # mask of all threads
                    local[j],
                    another_tx % warp_size,
                    warp_size,
                    warp_size)
                local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j])

    @T.prim_func
    def main(A: T.Tensor((b, n), dtype), B: T.Tensor((b, n), dtype)):
        with T.Kernel(b, threads=threads) as bx:
            local = T.alloc_local((thread_elem,), dtype)
            shared = T.alloc_shared((threads, thread_elem_in_smem), dtype)
            T.annotate_layout({shared: make_mma_swizzle_layout(shared)})
            tx = T.get_thread_binding(0)

            # 1. Load from HBM to register
            for i in T.vectorized(thread_elem):
                local[i] = A[bx, tx * thread_elem + i]

            # 2. Hadamard inside thread, n<=8
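            # Butterfly within registers: with a = local[chunkbase + k] and
            # b = local[chunkbase + k + chunksize // 2], the pair becomes (a + b, a - b);
            # since a is overwritten with a + b first, the difference is recovered
            # as (a + b) - 2 * b.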
            for i in T.serial(thread_round):
                chunksize = 1 << (i + 1)
                chunknum = thread_elem // chunksize
                for j in T.serial(chunknum):
                    chunkbase = j * chunksize
                    for k in T.serial(chunksize // 2):
                        local[chunkbase +
                              k] = local[chunkbase + k] + local[chunkbase + k + chunksize // 2]
                        local[chunkbase + k + chunksize //
                              2] = local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2]

            # 3. Hadamard inside warp, n<=512
            # At the warp level, we rely on warp shuffles to exchange data inside each warp, without using shared memory
            another_val = T.alloc_local((thread_elem,), dtype)

            warp_shfl(local, another_val, warp_round)

            # 4. Hadamard inside block, n<=32768
            # Only exchange once for n<=8192, since shared mem can hold all elems
            if block_round > 0:
                warp_id = tx // warp_size
                lane_id = tx % warp_size
                src_tx = warp_id * warp_size + lane_id
                tgt_warp_id = tx % warps
                tgt_lane_id = tx // warps
                tgt_tx = tgt_warp_id * warp_size + tgt_lane_id
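                # src_tx is this thread's own row in shared memory, while tgt_tx is a
                # transposed index: after reading shared[tgt_tx, :], thread tx holds the
                # chunk whose original warp index equals the low bits of tx, so the
                # remaining cross-warp butterflies become intra-warp shuffles (4.2);
                # the exchange in 4.3 is the inverse permutation that restores the
                # original layout before the write-back.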

                # 4.1 Write to smem, swap, read from smem
                for cur_round in T.serial(exchange_round):
                    exchange_base = thread_elem_in_smem * cur_round
                    for j in T.vectorized(thread_elem_in_smem):
                        shared[src_tx, j] = local[exchange_base + j]

                    for j in T.vectorized(thread_elem_in_smem):
                        local[exchange_base + j] = shared[tgt_tx, j]

                # 4.2 Warp shuffle
                warp_shfl(local, another_val, block_round)

                # 4.3 Write to smem, swap, read from smem
                for cur_round in T.serial(exchange_round):
                    exchange_base = thread_elem_in_smem * cur_round
                    for j in T.vectorized(thread_elem_in_smem):
                        shared[tgt_tx, j] = local[exchange_base + j]

                    for j in T.vectorized(thread_elem_in_smem):
                        local[exchange_base + j] = shared[src_tx, j]

            # 5. Write back to HBM
            for i in T.vectorized(thread_elem):
                B[bx, tx * thread_elem + i] = local[i]

    return main


def ref_program(x: torch.Tensor):
    assert x.ndim == 2
    dim = x.shape[-1]
    assert is_pow_of_2(dim)
    return F.linear(
        x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=64, help='Batch size')
    parser.add_argument('--dim', type=int, default=32768, help='Dimension')
    args = parser.parse_args()

    B, D = args.batch, args.dim
    x = torch.randn((B, D), device='cuda')
    kernel = tilelang.compile(hadamard(B, D, 'float32'), out_idx=1)
    y = kernel(x)
    y_ref = ref_program(x)
    torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
    print('All tests passed.')

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
    latency = profiler.do_bench(warmup=100)
    print("Tile-lang: {:.2f} ms".format(latency))


if __name__ == '__main__':
    main()
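
Note on the algorithm: steps 2-4 of the kernel form an iterative fast Walsh-Hadamard transform, so the output matches B = A · H_n with the unnormalized Sylvester Hadamard matrix that ref_program builds via scipy. As a host-side cross-check (a minimal sketch, not part of this commit; the name fwht_ref is illustrative), the same result can be computed in O(n log n) without materializing the n × n matrix:

def fwht_ref(x: torch.Tensor) -> torch.Tensor:
    # Iterative fast Walsh-Hadamard transform along the last dimension.
    # For power-of-two n this equals x @ scipy.linalg.hadamard(n).
    b, n = x.shape
    y = x.clone()
    h = 1
    while h < n:
        # Split each block of 2*h elements into halves (a, c) and apply the
        # butterfly (a, c) -> (a + c, a - c).
        y = y.view(b, n // (2 * h), 2, h)
        a, c = y[:, :, 0, :], y[:, :, 1, :]
        y = torch.stack((a + c, a - c), dim=2)
        h *= 2
    return y.reshape(b, n)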
