|
| 1 | +import argparse |
| 2 | +import torch |
| 3 | +import tilelang |
| 4 | +from tilelang.autotuner import * |
| 5 | +import tilelang.language as T |
| 6 | +from einops import rearrange, repeat |
| 7 | +import itertools |
| 8 | + |
| 9 | + |
| 10 | +def ref_program(cb, x, dt, dA_cumsum, C, prev_states, D): |
| 11 | + """ |
| 12 | + Argument: |
| 13 | + cb: (batch, nchunks, ngroups, chunk_size, chunk_size) |
| 14 | + x: (batch, seqlen, nheads, headdim) |
| 15 | + dt: (batch, nheads, nchunks, chunk_size) |
| 16 | + dA_cumsum: (batch, nheads, nchunks, chunk_size) |
| 17 | + C: (batch, seqlen, ngroups, dstate) |
| 18 | + prev_states: (batch, nchunks, nheads, headdim, dstate) |
| 19 | + D: (nheads, headdim) or (nheads,) |
| 20 | + z: (batch, seqlen, nheads, headdim) |
| 21 | + Return: |
| 22 | + out: (batch, seqlen, nheads, headdim) |
| 23 | + """ |
| 24 | + _, _, ngroups, _, _ = cb.shape |
| 25 | + batch, seqlen, nheads, headdim = x.shape |
| 26 | + # _, _, ngroups, dstate = B.shape |
| 27 | + # assert B.shape == (batch, seqlen, ngroups, dstate) |
| 28 | + _, _, nchunks, chunk_size = dt.shape |
| 29 | + assert seqlen == nchunks * chunk_size |
| 30 | + # assert C.shape == B.shape |
| 31 | + # B = repeat(B, "b l g d -> b l (g h) d", h=nheads // ngroups) |
| 32 | + C = repeat(C, "b l g d -> b l (g h) d", h=nheads // ngroups) |
| 33 | + cb = repeat(cb, "b c g l s -> b c (g h) l s", h=nheads // ngroups) |
| 34 | + # CB = torch.einsum("bclhn,bcshn->bchls", rearrange(C, "b (c l) h n -> b c l h n", c=nchunks), |
| 35 | + # rearrange(B, "b (c s) h n -> b c s h n", c=nchunks)) |
| 36 | + # (batch, nheads, nchunks, chunksize, chunksize) |
| 37 | + dt_segment_sum = dA_cumsum[:, :, :, :, None] - dA_cumsum[:, :, :, None, :] |
| 38 | + decay = torch.exp(dt_segment_sum) |
| 39 | + scores_decay = cb * rearrange(decay, "b h c l s -> b c h l s") |
| 40 | + causal_mask = torch.tril( |
| 41 | + torch.ones(chunk_size, chunk_size, device=x.device, dtype=bool), diagonal=0) |
| 42 | + scores_decay = scores_decay.masked_fill(~causal_mask, 0) |
| 43 | + out = torch.einsum('bchls,bhcs,bcshp->bclhp', scores_decay.to(x.dtype), dt.to(x.dtype), |
| 44 | + rearrange(x, "b (c s) h p -> b c s h p", c=nchunks)) |
| 45 | + state_decay_out = torch.exp(rearrange(dA_cumsum, "b h c l -> b c l h 1")) |
| 46 | + out_prev = torch.einsum('bclhn,bchpn->bclhp', rearrange( |
| 47 | + C, "b (c l) h n -> b c l h n", c=nchunks), prev_states.to(C.dtype)) * state_decay_out |
| 48 | + out = out + out_prev |
| 49 | + out = rearrange(out, "b c l h p -> b (c l) h p") |
| 50 | + if D is not None: |
| 51 | + if D.dim() == 1: |
| 52 | + D = rearrange(D, "h -> h 1") |
| 53 | + out = out + x * D |
| 54 | + return out |
| 55 | + |
| 56 | + |
| 57 | +def get_configs(): |
| 58 | + iter_params = dict( |
| 59 | + block_M=[64, 128, 256], |
| 60 | + block_N=[32, 64], |
| 61 | + block_K=[64, 128, 256], |
| 62 | + block_Dstate=[128], |
| 63 | + num_stages=[1, 2, 3, 4, 5]) |
| 64 | + return [dict(zip(iter_params, values)) for values in itertools.product(*iter_params.values())] |
| 65 | + |
| 66 | + |
| 67 | +@autotune(configs=get_configs(), warmup=10, rep=10) |
| 68 | +@tilelang.jit( |
| 69 | + out_idx=[7], |
| 70 | + pass_configs={ |
| 71 | + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, |
| 72 | + }, |
| 73 | +) |
| 74 | +def chunk_scan_fwd(batch, |
| 75 | + seqlen, |
| 76 | + chunk_size, |
| 77 | + ngroups, |
| 78 | + nheads, |
| 79 | + headdim, |
| 80 | + dstate, |
| 81 | + block_M=64, |
| 82 | + block_N=64, |
| 83 | + block_K=64, |
| 84 | + block_Dstate=128, |
| 85 | + num_stages=2, |
| 86 | + threads=128): |
| 87 | + dtype = "float16" |
| 88 | + accum_dtype = "float" |
| 89 | + nchunks = T.ceildiv(seqlen, chunk_size) |
| 90 | + p = 1.44269504 |
| 91 | + |
| 92 | + @T.prim_func |
| 93 | + def main( |
| 94 | + cb: T.Tensor((batch, nchunks, ngroups, chunk_size, chunk_size), dtype), # type: ignore |
| 95 | + x: T.Tensor((batch, seqlen, nheads, headdim), dtype), # type: ignore |
| 96 | + dt: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore |
| 97 | + dA_cumsum: T.Tensor((batch, nheads, nchunks, chunk_size), dtype), # type: ignore |
| 98 | + C: T.Tensor((batch, seqlen, ngroups, dstate), dtype), # type: ignore |
| 99 | + prev_states: T.Tensor((batch, nchunks, nheads, headdim, dstate), dtype), # type: ignore |
| 100 | + D: T.Tensor((nheads), dtype), # type: ignore |
| 101 | + Output: T.Tensor((batch, seqlen, nheads, headdim), dtype) # type: ignore |
| 102 | + ): |
| 103 | + with T.Kernel( |
| 104 | + nheads, |
| 105 | + T.ceildiv(chunk_size, block_M) * T.ceildiv(headdim, block_N), |
| 106 | + batch * nchunks, |
| 107 | + threads=threads) as (bz, bx, by): |
| 108 | + acc_o = T.alloc_fragment((block_M, block_N), accum_dtype) |
| 109 | + acc_o_shared = T.alloc_shared((block_M, block_N), dtype) |
| 110 | + cb_shared = T.alloc_shared((block_M, block_K), dtype, scope="shared.dyn") |
| 111 | + cb_local = T.alloc_fragment((block_M, block_K), dtype) |
| 112 | + dA_cs_k_shared = T.alloc_shared((block_K), dtype, scope="shared") |
| 113 | + dA_cs_k_local = T.alloc_fragment((block_K), accum_dtype) |
| 114 | + dA_cs_m_local = T.alloc_fragment((block_M), accum_dtype) |
| 115 | + dt_shared = T.alloc_shared((block_K), dtype, scope="shared") |
| 116 | + dt_local = T.alloc_fragment((block_K), accum_dtype) |
| 117 | + x_shared = T.alloc_shared((block_K, block_N), dtype, scope="shared.dyn") |
| 118 | + dA_cs_m_shared = T.alloc_shared((block_M), dtype, scope="shared") |
| 119 | + scale_m_local = T.alloc_fragment((block_M), accum_dtype) |
| 120 | + C_shared = T.alloc_shared((block_M, block_Dstate), dtype) |
| 121 | + prev_state_shared = T.alloc_shared((block_N, block_Dstate), dtype) |
| 122 | + D_local = T.alloc_fragment((1), accum_dtype) |
| 123 | + x_residual_shared = T.alloc_shared((block_M, block_N), dtype, scope="shared.dyn") |
| 124 | + x_residual_local = T.alloc_fragment((block_M, block_N), accum_dtype) |
| 125 | + |
| 126 | + batch_idx = by % batch |
| 127 | + chunk_idx = by // batch |
| 128 | + # m: chunk_size |
| 129 | + # n : headdim |
| 130 | + m_idx = bx // T.ceildiv(headdim, block_N) |
| 131 | + n_idx = bx % T.ceildiv(headdim, block_N) |
| 132 | + |
| 133 | + T.annotate_layout({ |
| 134 | + acc_o_shared: tilelang.layout.make_swizzled_layout(acc_o_shared), |
| 135 | + cb_shared: tilelang.layout.make_swizzled_layout(cb_shared), |
| 136 | + x_residual_shared: tilelang.layout.make_swizzled_layout(x_residual_shared) |
| 137 | + }) |
| 138 | + |
| 139 | + T.no_set_max_nreg() |
| 140 | + |
| 141 | + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, m_idx * block_M:(m_idx + 1) * block_M], |
| 142 | + dA_cs_m_shared) |
| 143 | + T.copy(dA_cs_m_shared, dA_cs_m_local) |
| 144 | + T.clear(acc_o) |
| 145 | + |
| 146 | + for i in T.Parallel(block_M): |
| 147 | + scale_m_local[i] = T.exp2(dA_cs_m_local[i] * p) |
| 148 | + T.copy( |
| 149 | + C[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + |
| 150 | + (m_idx + 1) * block_M, bz // (nheads // ngroups), 0:block_Dstate], C_shared) |
| 151 | + T.copy( |
| 152 | + prev_states[batch_idx, chunk_idx, bz, n_idx * block_N:(n_idx + 1) * block_N, |
| 153 | + 0:block_Dstate], prev_state_shared) |
| 154 | + T.gemm(C_shared, prev_state_shared, acc_o, transpose_B=True) |
| 155 | + for i, j in T.Parallel(block_M, block_N): |
| 156 | + acc_o[i, j] *= scale_m_local[i] |
| 157 | + |
| 158 | + loop_range = T.ceildiv((m_idx + 1) * block_M, block_K) |
| 159 | + |
| 160 | + for k in T.Pipelined(loop_range, num_stages=num_stages): |
| 161 | + T.copy( |
| 162 | + cb[batch_idx, chunk_idx, bz // (nheads // ngroups), |
| 163 | + m_idx * block_M:(m_idx + 1) * block_M, k * block_K:(k + 1) * block_K], |
| 164 | + cb_shared) |
| 165 | + T.copy(cb_shared, cb_local) |
| 166 | + T.copy(dA_cumsum[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], |
| 167 | + dA_cs_k_shared) |
| 168 | + T.copy(dA_cs_k_shared, dA_cs_k_local) |
| 169 | + for i, j in T.Parallel(block_M, block_K): |
| 170 | + cb_local[i, |
| 171 | + j] = cb_local[i, |
| 172 | + j] * T.exp2(dA_cs_m_local[i] * p - dA_cs_k_local[j] * p) |
| 173 | + T.copy(dt[batch_idx, bz, chunk_idx, k * block_K:(k + 1) * block_K], dt_shared) |
| 174 | + T.copy(dt_shared, dt_local) |
| 175 | + for i, j in T.Parallel(block_M, block_K): |
| 176 | + cb_local[i, j] *= dt_local[j] |
| 177 | + for i, j in T.Parallel(block_M, block_K): |
| 178 | + cb_local[i, j] = T.if_then_else(m_idx * block_M + i >= k * block_K + j, |
| 179 | + cb_local[i, j], 0) |
| 180 | + T.copy( |
| 181 | + x[batch_idx, chunk_idx * chunk_size + k * block_K:chunk_idx * chunk_size + |
| 182 | + (k + 1) * block_K, bz, n_idx * block_N:(n_idx + 1) * block_N], x_shared) |
| 183 | + T.gemm(cb_local, x_shared, acc_o) |
| 184 | + |
| 185 | + D_local[0] = D[bz] |
| 186 | + T.copy( |
| 187 | + x[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + |
| 188 | + (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N], |
| 189 | + x_residual_shared) |
| 190 | + T.copy(x_residual_shared, x_residual_local) |
| 191 | + for i, j in T.Parallel(block_M, block_N): |
| 192 | + acc_o[i, j] += x_residual_local[i, j] * D_local[0] |
| 193 | + |
| 194 | + T.copy(acc_o, acc_o_shared) |
| 195 | + T.copy( |
| 196 | + acc_o_shared, |
| 197 | + Output[batch_idx, chunk_idx * chunk_size + m_idx * block_M:chunk_idx * chunk_size + |
| 198 | + (m_idx + 1) * block_M, bz, n_idx * block_N:(n_idx + 1) * block_N]) |
| 199 | + |
| 200 | + return main |
| 201 | + |
| 202 | + |
| 203 | +if __name__ == "__main__": |
| 204 | + parser = argparse.ArgumentParser() |
| 205 | + parser.add_argument('--batch', type=int, default=8, help='batch size') |
| 206 | + parser.add_argument('--heads', type=int, default=80, help='heads') |
| 207 | + parser.add_argument('--groups', type=int, default=1, help='groups') |
| 208 | + parser.add_argument('--seq_len', type=int, default=4096, help='sequence length') |
| 209 | + parser.add_argument('--chunk_size', type=int, default=256, help='chunk size') |
| 210 | + parser.add_argument('--dim', type=int, default=64, help='dim') |
| 211 | + parser.add_argument('--dstate', type=int, default=128, help='dstate') |
| 212 | + parser.add_argument('--tune', action='store_true', help='tune configs') |
| 213 | + args = parser.parse_args() |
| 214 | + batch, heads, groups, seq_len, chunk_size, dim, dstate = args.batch, args.heads, args.groups, args.seq_len, args.chunk_size, args.dim, args.dstate |
| 215 | + total_flops = 2 * batch * seq_len * chunk_size * heads * dim * 0.5 + 2 * batch * seq_len * heads * dim * dstate |
| 216 | + |
| 217 | + kernel = chunk_scan_fwd(batch, seq_len, chunk_size, groups, heads, dim, dstate) |
| 218 | + best_latency = kernel.latency |
| 219 | + best_config = kernel.config |
| 220 | + ref_latency = kernel.ref_latency |
| 221 | + print(f"Best latency: {best_latency}") |
| 222 | + print(f"Best TFlops: {total_flops / best_latency * 1e-9}") |
| 223 | + print(f"Best config: {best_config}") |
0 commit comments