# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.
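
# Fast (Walsh-)Hadamard transform written in TileLang. Each row of the (b, n)
# input is multiplied by the n x n Hadamard matrix with a three-level butterfly:
# inside each thread's registers, across the lanes of a warp via warp shuffles,
# and across warps via a shared-memory exchange. The result is validated against
# a SciPy/PyTorch reference and benchmarked.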

import tilelang
import tilelang.language as T
from tilelang.intrinsics import make_mma_swizzle_layout

import math
import argparse
import torch
from torch.nn import functional as F
import scipy.linalg


def is_pow_of_2(n):
    return isinstance(n, int) and n > 0 and (n & (n - 1)) == 0


def hadamard(b, n, dtype):
    assert is_pow_of_2(n), "n must be a power of 2"
    assert 2 <= n <= 32768, "n must be in [2, 32768]"
    elem_size = {'float32': 4, 'float16': 2, 'bfloat16': 2}[dtype]

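    # The length-n transform is split into three butterfly stages:
    #   thread_round rounds inside each thread's registers,
    #   warp_round rounds across the lanes of a warp (via warp shuffles),
    #   block_round rounds across warps (staged through shared memory).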
    logN = int(math.log2(n))
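    # Threads per block for each log2(n); index 0 is unused since n >= 2.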
    threads = [0, 1, 1, 1, 2, 4, 8, 16, 32, 32, 128, 256, 256, 256, 256, 256][logN]
    thread_elem = n // threads  # Each thread is responsible for a chunk of elements
    thread_round = int(math.log2(thread_elem))

    warps = 1 if threads <= 32 else threads // 32
    warp_round = int(math.log2(threads / warps))
    warp_size = threads // warps

    block_round = int(math.log2(warps))

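    # If one row (n * elem_size bytes) exceeds the 32KB shared-memory budget, the
    # block-level exchange is split into several passes over smaller slices.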
    exchange_round = n * elem_size // 32768 if n * elem_size > 32768 else 1  # Use at most 32KB of shared memory
    thread_elem_in_smem = thread_elem // exchange_round if exchange_round > 1 else thread_elem

    # debug log
    # print(f'{threads=}, {thread_round=}')
    # print(f'{warps=}, {warp_round=}, {warp_size=}')
    # print(f'{block_round=}')
    # print(f'{exchange_round=}')

    @T.macro
    def warp_shfl(local: T.Tensor((thread_elem,), dtype), buf: T.Tensor((thread_elem,), dtype),
                  round: int):
        tx = T.get_thread_binding(0)
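        # Round i pairs thread tx with thread tx ^ (1 << i); the pair turns (a, b)
        # into (a + b, a - b), each thread keeping one half of the result.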
        for i in T.serial(round):
            tx_stride = 1 << i
            another_tx = tx ^ tx_stride
            # The i-th lowest bit of tx decides this thread's role in the butterfly:
            # 0 keeps the sum (a + b), 1 keeps the difference (a - b).
            sign = (tx >> i) & 1

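            # Fetch the partner lane's j-th element with a shuffle, then combine it
            # with this thread's own value.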
            for j in T.Pipelined(thread_elem, num_stages=1):
                buf[j] = T.tvm_warp_shuffle(
                    0xffffffff,  # mask of all threads
                    local[j],
                    another_tx % warp_size,
                    warp_size,
                    warp_size)
                local[j] = T.if_then_else(sign == 0, local[j] + buf[j], buf[j] - local[j])

    @T.prim_func
    def main(A: T.Tensor((b, n), dtype), B: T.Tensor((b, n), dtype)):
        with T.Kernel(b, threads=threads) as bx:
            local = T.alloc_local((thread_elem,), dtype)
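            # Shared memory is only needed for the cross-warp exchange in step 4;
            # the swizzled layout helps avoid bank conflicts.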
            shared = T.alloc_shared((threads, thread_elem_in_smem), dtype)
            T.annotate_layout({shared: make_mma_swizzle_layout(shared)})
            tx = T.get_thread_binding(0)

            # 1. Load this thread's chunk from HBM into registers
            for i in T.vectorized(thread_elem):
                local[i] = A[bx, tx * thread_elem + i]

            # 2. Hadamard within each thread's registers (sufficient by itself for n <= 8)
            for i in T.serial(thread_round):
                chunksize = 1 << (i + 1)
                chunknum = thread_elem // chunksize
                for j in T.serial(chunknum):
                    chunkbase = j * chunksize
                    for k in T.serial(chunksize // 2):
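                        # In-place butterfly: after the first update the lower slot
                        # holds a + b, so (a + b) - 2 * b recovers a - b.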
                        local[chunkbase + k] = (
                            local[chunkbase + k] + local[chunkbase + k + chunksize // 2])
                        local[chunkbase + k + chunksize // 2] = (
                            local[chunkbase + k] - 2 * local[chunkbase + k + chunksize // 2])

            # 3. Hadamard across the lanes of each warp (steps 2-3 cover n <= 512)
            # At warp level, data is exchanged via warp shuffles, without using shared memory
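            # Scratch buffer that receives the partner lane's values inside warp_shfl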
            another_val = T.alloc_local((thread_elem,), dtype)

            warp_shfl(local, another_val, warp_round)

            # 4. Hadamard across the warps of the block (extends coverage up to n <= 32768)
            # Only exchange once for n <= 8192, since shared memory can hold all elements
            if block_round > 0:
                warp_id = tx // warp_size
                lane_id = tx % warp_size
                src_tx = warp_id * warp_size + lane_id
                tgt_warp_id = tx % warps
                tgt_lane_id = tx // warps
                tgt_tx = tgt_warp_id * warp_size + tgt_lane_id
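                # Transposed mapping: each thread writes its slice to shared row src_tx (== tx)
                # and reads back row tgt_tx. After this swap, elements that belonged to
                # different warps sit only a few lanes apart in the same warp, so the
                # remaining cross-warp butterflies (4.2) can again use warp shuffles.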

                # 4.1 Write to smem, swap, read from smem
                for cur_round in T.serial(exchange_round):
                    exchange_base = thread_elem_in_smem * cur_round
                    for j in T.vectorized(thread_elem_in_smem):
                        shared[src_tx, j] = local[exchange_base + j]

                    for j in T.vectorized(thread_elem_in_smem):
                        local[exchange_base + j] = shared[tgt_tx, j]

                # 4.2 Warp shuffle
                warp_shfl(local, another_val, block_round)

                # 4.3 Write to smem, swap back, read from smem
                for cur_round in T.serial(exchange_round):
                    exchange_base = thread_elem_in_smem * cur_round
                    for j in T.vectorized(thread_elem_in_smem):
                        shared[tgt_tx, j] = local[exchange_base + j]

                    for j in T.vectorized(thread_elem_in_smem):
                        local[exchange_base + j] = shared[src_tx, j]
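                # The reverse exchange restores the original ownership: thread tx again
                # holds elements [tx * thread_elem, (tx + 1) * thread_elem) of its row.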

            # 5. Write back to HBM
            for i in T.vectorized(thread_elem):
                B[bx, tx * thread_elem + i] = local[i]

    return main


def ref_program(x: torch.Tensor):
    assert x.ndim == 2
    dim = x.shape[-1]
    assert is_pow_of_2(dim)
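    # scipy.linalg.hadamard builds the unnormalized +/-1 Sylvester-Hadamard matrix,
    # which matches the unnormalized butterflies in the kernel.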
    return F.linear(
        x, torch.tensor(scipy.linalg.hadamard(dim, dtype=float), dtype=x.dtype, device=x.device))


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch', type=int, default=64, help='Batch size')
    parser.add_argument('--dim', type=int, default=32768, help='Dimension')
    args = parser.parse_args()

    B, D = args.batch, args.dim
    x = torch.randn((B, D), device='cuda')
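    # out_idx=1 marks the second kernel parameter (B) as the output, so the compiled
    # kernel allocates it and returns it from the call.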
    kernel = tilelang.compile(hadamard(B, D, 'float32'), out_idx=1)
    y = kernel(x)
    y_ref = ref_program(x)
    torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2)
    print('All tests passed.')

    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
    latency = profiler.do_bench(warmup=100)
    print("Tile-lang: {:.2f} ms".format(latency))


if __name__ == '__main__':
    main()