# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

import torch
import tilelang as tl
import tilelang.language as T
from tilelang.profiler import do_bench

import argparse
from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA

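# This script implements the backward pass of chunked (causal, unnormalized)
# linear attention as a TileLang kernel, checks its gradients against FLA's
# fused_chunk_linear_attn, and benchmarks both. The sequence is processed in
# fixed-size chunks: a forward sweep carries the inter-chunk state
# h = sum_{j<t} V_j^T K_j to produce dQ, and a reverse sweep carries
# dh = sum_{j>t} (scale * Q_j)^T dO_j to produce dK and dV.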

def chunk_linear_attn_bwd_kernel(
    B,
    S,
    H,
    DK,
    DV,
    dtype: str = 'float16',
    scale: float = None,
):

    if scale is None:
        scale = DK**-0.5
    accum_dtype = 'float'

    chunk_size = 64
    BK = BV = 64  # tile sizes along the key / value head dimensions
    assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
    NK = tl.cdiv(DK, BK)  # number of key tiles
    NV = tl.cdiv(DV, BV)  # number of value tiles
    NT = tl.cdiv(S, chunk_size)  # number of sequence chunks

    @T.prim_func
    def main(
            Q: T.Tensor([B, S, H, DK], dtype),
            K: T.Tensor([B, S, H, DK], dtype),
            V: T.Tensor([B, S, H, DV], dtype),
            dO: T.Tensor([B, S, H, DV], dtype),
            dQ: T.Tensor([NV, B, S, H, DK], dtype),  # partial dQ, split over value tiles
            dK: T.Tensor([NV, B, S, H, DK], dtype),  # partial dK, split over value tiles
            dV: T.Tensor([NK, B, S, H, DV], dtype),  # partial dV, split over key tiles
    ):
        with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
            i_b = i_bh // H
            i_h = i_bh % H
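            # Each program instance handles one (value tile i_v, key tile i_k)
            # pair for one batch-head; the partial dQ/dK/dV it writes carry an
            # extra leading dimension that `postprocess` reduces afterwards.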

            ds = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
            ds_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
            dq = T.alloc_fragment([chunk_size, BK], accum_dtype)
            dk = T.alloc_fragment([chunk_size, BK], accum_dtype)
            dv = T.alloc_fragment([chunk_size, BV], accum_dtype)
            q = T.alloc_shared([chunk_size, BK], dtype)
            k = T.alloc_shared([chunk_size, BK], dtype)
            v = T.alloc_shared([chunk_size, BV], dtype)
            do = T.alloc_shared([chunk_size, BV], dtype)
            h = T.alloc_fragment([BV, BK], accum_dtype)
            h_shared = T.alloc_shared([BV, BK], dtype)
            dh = T.alloc_fragment([BK, BV], accum_dtype)
            dh_shared = T.alloc_shared([BK, BV], dtype)
            T.clear(h)
            T.clear(dh)

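            # Swizzled shared-memory layouts help avoid bank conflicts when
            # the gemms below read these buffers.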
            T.annotate_layout({
                ds_shared: tl.layout.make_swizzled_layout(ds_shared),
                q: tl.layout.make_swizzled_layout(q),
                k: tl.layout.make_swizzled_layout(k),
                v: tl.layout.make_swizzled_layout(v),
                do: tl.layout.make_swizzled_layout(do),
                h_shared: tl.layout.make_swizzled_layout(h_shared),
                dh_shared: tl.layout.make_swizzled_layout(dh_shared)
            })

            # Phase 1: forward sweep over chunks to compute dQ.
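            # Per chunk t (mask L keeps row >= col, diagonal included):
            #   dq_t = scale * ((dO_t @ V_t^T * L) @ K_t + dO_t @ h_t),
            # with inter-chunk state h_t = sum_{j<t} V_j^T @ K_j.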
            for i in T.Pipelined(0, NT, num_stages=1):
                T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
                T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
                T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
                       do)

                # ds = dO_i @ V_i^T, masked to the causal lower triangle
                T.gemm(do, v, ds, transpose_B=True, clear_accum=True)
                for row, col in T.Parallel(chunk_size, chunk_size):
                    ds_shared[row, col] = T.if_then_else(row >= col, ds[row, col], 0)

                T.gemm(ds_shared, k, dq, clear_accum=True)  # intra-chunk term
                T.copy(h, h_shared)
                T.gemm(do, h_shared, dq)  # inter-chunk term: dq += dO_i @ h
                T.gemm(v, k, h, transpose_A=True)  # h += V_i^T @ K_i for later chunks
                for row, col in T.Parallel(chunk_size, BK):
                    dq[row, col] *= scale
                T.copy(
                    dq, dQ[i_v, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
                           i_k * BK:(i_k + 1) * BK])

            # Phase 2: reverse sweep over chunks to compute dK and dV.
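            # Per chunk t, with dh_t = sum_{j>t} (scale * Q_j)^T @ dO_j and
            # mask U keeping row <= col:
            #   dk_t = (V_t @ dO_t^T * U) @ (scale * Q_t) + V_t @ dh_t^T
            #   dv_t = (K_t @ (scale * Q_t)^T * U) @ dO_t + K_t @ dh_t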
            for i in T.Pipelined(1, NT + 1, num_stages=1):
                start = NT - i
                # Load Q pre-scaled by `scale`
                for row, col in T.Parallel(chunk_size, BK):
                    q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale
                T.copy(
                    K[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                      i_k * BK:(i_k + 1) * BK], k)
                T.copy(
                    V[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                      i_v * BV:(i_v + 1) * BV], v)
                T.copy(
                    dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                       i_v * BV:(i_v + 1) * BV], do)
                # Snapshot dh before it absorbs the current chunk, so the gemms
                # below see the strict suffix sum over later chunks only.
                T.copy(dh, dh_shared)

                # Calculate dk
                T.gemm(
                    v, do, ds, transpose_B=True, clear_accum=True
                )  # ds here actually holds the scores `s`; we simply reuse the `ds` buffer
                for row, col in T.Parallel(chunk_size, chunk_size):
                    ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0)
                T.gemm(ds_shared, q, dk, clear_accum=True)
                T.gemm(v, dh_shared, dk, transpose_B=True)

                # Calculate dv
                T.gemm(k, q, ds, transpose_B=True, clear_accum=True)
                for row, col in T.Parallel(chunk_size, chunk_size):
                    ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0)
                T.gemm(ds_shared, do, dv, clear_accum=True)
                T.gemm(k, dh_shared, dv)

                # Update dh with the current chunk: dh += (scale * Q_t)^T @ dO_t
                T.gemm(q, do, dh, transpose_A=True)

                T.copy(
                    dk, dK[i_v, i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                           i_k * BK:(i_k + 1) * BK])
                T.copy(
                    dv, dV[i_k, i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                           i_v * BV:(i_v + 1) * BV])

    return main

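
# For reference, a plain PyTorch sketch of the same chunked recurrences,
# without the tiling machinery. This is an unoptimized illustration, not part
# of the original script: `chunk_linear_attn_bwd_ref` is a hypothetical name,
# it assumes DK == DV == D (as in the benchmark below), and it runs in float32.
def chunk_linear_attn_bwd_ref(q, k, v, do, chunk_size=64, scale=None):
    B, S, H, D = q.shape
    if scale is None:
        scale = D**-0.5
    # Work in [B, H, S, D] layout, detached from autograd.
    q, k, v, do = (x.detach().transpose(1, 2).float() for x in (q, k, v, do))
    NT = S // chunk_size
    L = torch.tril(torch.ones(chunk_size, chunk_size, device=q.device))
    U = L.transpose(0, 1)  # keeps row <= col, diagonal included
    dq, dk, dv = torch.zeros_like(q), torch.zeros_like(k), torch.zeros_like(v)
    h = q.new_zeros(B, H, D, D)  # sum_{j<t} V_j^T K_j
    dh = q.new_zeros(B, H, D, D)  # sum_{j>t} (scale * Q_j)^T dO_j
    for t in range(NT):  # forward sweep: dQ
        s = slice(t * chunk_size, (t + 1) * chunk_size)
        ds = (do[:, :, s] @ v[:, :, s].transpose(-1, -2)) * L
        dq[:, :, s] = scale * (ds @ k[:, :, s] + do[:, :, s] @ h)
        h = h + v[:, :, s].transpose(-1, -2) @ k[:, :, s]
    for t in reversed(range(NT)):  # reverse sweep: dK, dV
        s = slice(t * chunk_size, (t + 1) * chunk_size)
        qs = q[:, :, s] * scale
        dk[:, :, s] = ((v[:, :, s] @ do[:, :, s].transpose(-1, -2)) * U) @ qs \
                      + v[:, :, s] @ dh.transpose(-1, -2)
        dv[:, :, s] = ((k[:, :, s] @ qs.transpose(-1, -2)) * U) @ do[:, :, s] \
                      + k[:, :, s] @ dh
        dh = dh + qs.transpose(-1, -2) @ do[:, :, s]
    return tuple(x.transpose(1, 2) for x in (dq, dk, dv))
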

def postprocess(dQ, dK, dV):
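    # The kernel emits partial gradients with a leading split dimension
    # (NV for dQ/dK, NK for dV); reduce over it unless it is trivial.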
    dQ = dQ[0] if dQ.size(0) == 1 else dQ.sum(0)
    dK = dK[0] if dK.size(0) == 1 else dK.sum(0)
    dV = dV[0] if dV.size(0) == 1 else dV.sum(0)
    return dQ, dK, dV


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--B', type=int, default=8, help='Batch size')
    parser.add_argument('--S', type=int, default=2048, help='Sequence length')
    parser.add_argument('--H', type=int, default=64, help='Number of heads')
    parser.add_argument('--D', type=int, default=256, help='Head dimension')
    args = parser.parse_args()
    B, S, H, D = args.B, args.S, args.H, args.D

    q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
    k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
    v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
    do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)

    fn = chunk_linear_attn_bwd_kernel(B, S, H, D, D)
    kernel = tl.compile(fn, out_idx=[4, 5, 6], target='cuda')
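    # out_idx=[4, 5, 6] marks the dQ/dK/dV parameters as outputs that the
    # compiled kernel allocates and returns when called with (q, k, v, do).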
    dq, dk, dv = postprocess(*kernel(q, k, v, do))
    o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
    o_ref.backward(do, retain_graph=True)  # retain the graph so backward can be re-run below
    if torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad):
        print('Passed all tests!✅')
    else:
        print('Failed some tests!❌')
    t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100)
    q.grad = k.grad = v.grad = None
    o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
    t2 = do_bench(lambda: postprocess(*kernel(q, k, v, do)), warmup=25, rep=100)
    print(f'Triton latency: {t1:.3f} ms')
    print(f'TileLang latency: {t2:.3f} ms')
    print(f'Speedup: {t1/t2:.3f}x')


if __name__ == '__main__':
    main()