 import tilelang.language as T
 from einops import rearrange, repeat
 import itertools
+import math
+from tilelang.profiler import do_bench
+
+try:
+    from mamba_ssm.ops.triton.ssd_chunk_scan import _chunk_scan_fwd
+except ImportError as err:
+    raise ImportError("Please install mamba-ssm to use the triton chunk scan operator.") from err
+
+try:
+    import helion
+    from helion._testing import run_example
+    import helion.language as hl
+except ImportError as err:
+    raise ImportError("Please install helion to use the helion chunk scan operator.") from err


 def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
@@ -54,6 +68,119 @@ def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D):
     return out


+def chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D):
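+    # _chunk_scan_fwd returns a second, auxiliary tensor; only the main output is kept.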
+    out, _ = _chunk_scan_fwd(cb, x, dt, dA_cumsum, C, states, D)
+    return out
+
+
+def chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D):
+
+    @helion.kernel()
+    def helion_mamba2_chunk_scan_kernel(
+        cb: torch.Tensor,
+        x: torch.Tensor,
+        dt: torch.Tensor,
+        dA_cumsum: torch.Tensor,
+        C: torch.Tensor,
+        prev_states: torch.Tensor,
+        D: torch.Tensor,
+    ) -> torch.Tensor:
+        """
+        Arguments:
+            cb: (batch, nchunks, ngroups, chunk_size, chunk_size)
+            x: (batch, seqlen, nheads, headdim)
+            dt: (batch, nheads, nchunks, chunk_size)
+            dA_cumsum: (batch, nheads, nchunks, chunk_size)
+            C: (batch, seqlen, ngroups, dstate)
+            prev_states: (batch, nchunks, nheads, headdim, dstate)
+            D: (nheads,)
+        Returns:
+            out: (batch, seqlen, nheads, headdim)
+        """
+
+        batch, nchunks, ngroups, chunk_size, _ = cb.shape
+        _, seqlen, nheads, headdim = x.shape
+        _, _, _, dstate = C.shape
+        assert nchunks == (seqlen + chunk_size - 1) // chunk_size
+
+        block_m = hl.register_block_size(chunk_size)
+        block_n = hl.register_block_size(headdim)
+        block_k = hl.register_block_size(64, 64)
+        dstate = hl.specialize(dstate)
+
+        assert cb.shape == (batch, nchunks, ngroups, chunk_size, chunk_size)
+        assert x.shape == (batch, seqlen, nheads, headdim)
+        assert dt.shape == (batch, nheads, nchunks, chunk_size)
+        assert dA_cumsum.shape == (batch, nheads, nchunks, chunk_size)
+        assert C.shape == (batch, seqlen, ngroups, dstate)
+        assert prev_states.shape == (batch, nchunks, nheads, headdim, dstate)
+        assert D.shape == (nheads,)
+
+        dtype = cb.dtype
+        accum_dtype = torch.float32
+        assert (x.dtype == dt.dtype == dA_cumsum.dtype == C.dtype == prev_states.dtype == D.dtype ==
+                dtype)
+
+        out = torch.empty_like(x)
+
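+        # log2(e): rewrites exp(x) as exp2(x * log2(e)), which maps to the faster hardware exp2.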
+        p = 1.44269504
+
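+        # Launch grid: one tile per (head, chunk-row block, headdim block, batch, chunk).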
+        for tile_h, tile_m, tile_n, tile_b, tile_c in hl.tile(
+                [nheads, chunk_size, headdim, batch, nchunks],
+                block_size=[1, block_m, block_n, 1, 1],
+        ):
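+            # Inter-chunk term: C_m @ prev_state, scaled by exp(dA_cumsum[m]).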
+            acc_o = hl.zeros([tile_m, tile_n], dtype=accum_dtype)
+            dA_cumsum_local_m = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
+                                          tile_m].to(torch.float32)
+            scale_m_local = torch.exp2(dA_cumsum_local_m * p)
+
+            C_local = C[
+                tile_b.begin,
+                tile_m.index + tile_c.begin * chunk_size,
+                tile_h.begin // (nheads // ngroups),
+                :,
+            ]
+            prev_states_local = prev_states[tile_b.begin, tile_c.begin, tile_h.begin, tile_n, :]
+            acc_o = hl.dot(C_local, prev_states_local.T, acc=acc_o)
+            acc_o *= scale_m_local[:, None]
+
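+            # Intra-chunk term: causal matmul of (cb * dt) against x over the chunk prefix.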
+            for tile_k in hl.tile((tile_m.id + 1) * block_m, block_size=block_k):
+                cb_local = cb[
+                    tile_b.begin,
+                    tile_c.begin,
+                    tile_h.begin // (nheads // ngroups),
+                    tile_m,
+                    tile_k,
+                ]
+                dA_cumsum_local_k = dA_cumsum[tile_b.begin, tile_h.begin, tile_c.begin,
+                                              tile_k].to(torch.float32)
+                cb_local *= torch.exp2(dA_cumsum_local_m[:, None] * p -
+                                       dA_cumsum_local_k[None, :] * p)
+                dt_local = dt[tile_b.begin, tile_h.begin, tile_c.begin, tile_k].to(torch.float32)
+                cb_local = (cb_local * dt_local[None, :]).to(dtype)
+                pred = (tile_m.index + 0)[:, None] >= (tile_k.index + 0)[None, :]
+                cb_local = torch.where(pred, cb_local, torch.zeros_like(cb_local))
+                x_local = x[
+                    tile_b.begin,
+                    tile_c.begin * chunk_size + tile_k.index,
+                    tile_h.begin,
+                    tile_n,
+                ]
+                acc_o = hl.dot(cb_local, x_local, acc=acc_o)
+
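+            # Skip connection: add the D * x residual, then store the output tile.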
+            D_local = D[tile_h.begin].to(torch.float32)
+            x_residual = x[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
+                           tile_n].to(torch.float32)
+            acc_o += x_residual * D_local
+            out[tile_b.begin, tile_c.begin * chunk_size + tile_m.index, tile_h.begin,
+                tile_n] = acc_o.to(dtype=dtype)
+
+        return out
+
+    args = (cb, x, dt, dA_cumsum, C, states, D)
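+    # run_example validates the Helion kernel against ref_program and benchmarks both.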
+    run_example(helion_mamba2_chunk_scan_kernel, ref_program, args)
+
+
 def get_configs():
     iter_params = dict(
         block_M=[64, 128, 256],
@@ -212,12 +339,30 @@ def main(
     parser.add_argument('--tune', action='store_true', help='tune configs')
     args = parser.parse_args()
     batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate
+    nchunks = math.ceil(seq_len / chunk_size)
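+    # 0.5 factor: the intra-chunk (cb @ x) matmul is causal, so only ~half its entries are computed.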
     total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate

+    print("Benchmarking TileLang...")
     kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate)
     best_latency = kernel.latency
     best_config = kernel.config
     ref_latency = kernel.ref_latency
     print(f"Best latency: {best_latency}")
     print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
     print(f"Best config: {best_config}")
+
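+    # Random half-precision CUDA inputs shaped as the kernels expect.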
+    cb = torch.randn(batch, nchunks, groups, chunk_size, chunk_size).half().cuda()
+    x = torch.randn(batch, seq_len, heads, dim).half().cuda()
+    dt = torch.randn(batch, heads, nchunks, chunk_size).half().cuda()
+    dA_cumsum = torch.randn(batch, heads, nchunks, chunk_size).half().cuda()
+    C = torch.randn(batch, seq_len, groups, dstate).half().cuda()
+    states = torch.randn(batch, nchunks, heads, dim, dstate).half().cuda()
+    D = torch.randn(heads).half().cuda()
+
+    print("Benchmarking Triton...")
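+    # do_bench reports latency in milliseconds; flops / ms * 1e-9 converts to TFLOPs.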
+    triton_latency = do_bench(
+        lambda: chunk_scan_triton(cb, x, dt, dA_cumsum, C, states, D), _n_warmup=10, _n_repeat=10)
+    print(f"Triton TFlops: {total_flops / triton_latency * 1e-9}")
+
+    print("Benchmarking Helion...")
+    chunk_scan_helion(cb, x, dt, dA_cumsum, C, states, D)