From 9e5a793e2932ce8f080d9383221601702995bca1 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Wed, 4 Jun 2025 18:18:56 +0800 Subject: [PATCH 1/4] Add linear attention examples. --- .../example_linear_attn_bwd.py | 181 ++++++++++++++++++ .../example_linear_attn_fwd.py | 129 +++++++++++++ 2 files changed, 310 insertions(+) create mode 100644 examples/linear_attention/example_linear_attn_bwd.py create mode 100644 examples/linear_attention/example_linear_attn_fwd.py diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py new file mode 100644 index 000000000..ea79d9733 --- /dev/null +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -0,0 +1,181 @@ +from numpy import transpose +import torch +import tilelang as tl +import tilelang.language as T +from tilelang.profiler import do_bench + +import argparse +from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA + + +def chunk_linear_attn_bwd( + B, S, H, DK, DV, + dtype: str = 'float16', + scale: float = None, +) -> torch.Tensor: + + if scale is None: + scale = DK ** -0.5 + accum_dtype = 'float' + + chunk_size = 64 + BK = BV = 64 + assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 + NK = tl.cdiv(DK, BK) + NV = tl.cdiv(DV, BV) + NT = tl.cdiv(S, chunk_size) + + @T.prim_func + def main( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + dO: T.Tensor([B, S, H, DV], dtype), # type: ignore + dQ: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore + dK: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore + dV: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + ): + with T.Kernel(NV, NK, B*H) as (i_v, i_k, i_bh): + i_b = i_bh // H + i_h = i_bh % H + + ds = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + ds_shared = T.alloc_shared([chunk_size, chunk_size], dtype) + dq = T.alloc_fragment([chunk_size, BK], accum_dtype) + dk = T.alloc_fragment([chunk_size, BK], accum_dtype) + dv = T.alloc_fragment([chunk_size, BV], accum_dtype) + q = T.alloc_shared([chunk_size, BK], dtype) + k = T.alloc_shared([chunk_size, BK], dtype) + v = T.alloc_shared([chunk_size, BV], dtype) + do = T.alloc_shared([chunk_size, BV], dtype) + h = T.alloc_fragment([BV, BK], accum_dtype) + h_shared = T.alloc_shared([BV, BK], dtype) + dh = T.alloc_fragment([BK, BV], accum_dtype) + dh_shared = T.alloc_shared([BK, BV], dtype) + T.clear(h) + T.clear(dh) + + T.annotate_layout({ + ds_shared: tl.layout.make_swizzled_layout(ds_shared), + q: tl.layout.make_swizzled_layout(q), + k: tl.layout.make_swizzled_layout(k), + v: tl.layout.make_swizzled_layout(v), + do: tl.layout.make_swizzled_layout(do), + h_shared: tl.layout.make_swizzled_layout(h_shared), + dh_shared: tl.layout.make_swizzled_layout(dh_shared) + }) + + + # Calculate dQ + for i in T.Pipelined(0, NT, num_stages=1): + T.copy(K[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_k * BK:(i_k+1) * BK], k) + T.copy(V[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV], v) + T.copy(dO[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV], do) + + T.gemm(do, v, ds, transpose_B=True, clear_accum=True) + for row, col in T.Parallel(chunk_size, chunk_size): + ds_shared[row, col] = T.if_then_else( + row >= col, + ds[row, col], + 0 + ) + + T.gemm(ds_shared, k, dq, clear_accum=True) + T.copy(h, h_shared) + T.gemm(do, h_shared, dq) + T.gemm(v, k, h, transpose_A=True) + for row, col in T.Parallel(chunk_size, BK): + dq[row, col] *= scale + T.copy(dq, dQ[i_v, i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_k * BK:(i_k+1) * BK]) + + # Calculate dK, dV (reversely) + for i in T.Pipelined(1, NT+1, num_stages=1): + start = NT-i + for row, col in T.Parallel(chunk_size, BK): + q[row, col] = Q[i_b, start*chunk_size+row, i_h, i_k*BK+col] * scale + T.copy(K[i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_k * BK:(i_k+1) * BK], k) + T.copy(V[i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV], v) + T.copy(dO[i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV], do) + T.copy(dh, dh_shared) + + # Calculate dk + T.gemm(v, do, ds, transpose_B=True, clear_accum=True) # ds here actually means `s`, but we simply reuse the buffer `ds` + for row, col in T.Parallel(chunk_size, chunk_size): + ds_shared[row, col] = T.if_then_else( + row <= col, + ds[row, col], + 0 + ) + T.gemm(ds_shared, q, dk, clear_accum=True) + T.gemm(v, dh_shared, dk, transpose_B=True) + + # Calculate dv + T.gemm(k, q, ds, transpose_B=True, clear_accum=True) + for row, col in T.Parallel(chunk_size, chunk_size): + ds_shared[row, col] = T.if_then_else( + row <= col, + ds[row, col], + 0 + ) + T.gemm(ds_shared, do, dv, clear_accum=True) + T.gemm(k, dh_shared, dv) + + # Update dh + T.gemm(q, do, dh, transpose_A=True) + + T.copy(dk, dK[i_v, i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_k * BK:(i_k+1) * BK]) + T.copy(dv, dV[i_k, i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV]) + + return main + + +def postprocess(dQ, dK, dV): + dQ = dQ[0] if dQ.size(0) == 1 else dQ.sum(0) + dK = dK[0] if dK.size(0) == 1 else dK.sum(0) + dV = dV[0] if dV.size(0) == 1 else dV.sum(0) + return dQ, dK, dV + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--B', type=int, default=8, help='Batch size') + parser.add_argument('--S', type=int, default=2048, help='Seq len') + parser.add_argument('--H', type=int, default=64, help='Num heads') + parser.add_argument('--D', type=int, default=256, help='Head dim') + args = parser.parse_args() + B, S, H, D = args.B, args.S, args.H, args.D + + q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) + k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) + v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) + do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + + fn = chunk_linear_attn_bwd(B, S, H, D, D) + kernel = tl.compile(fn, out_idx=[4, 5, 6], target='cuda') + dq, dk, dv = postprocess(*kernel(q, k, v, do)) + o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) + o_ref.backward(do, retain_graph=True) + if torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad): + print('Passed all tests!✅') + else: + print('Failed some tests!❌') + t1 = do_bench ( + lambda: o_ref.backward(do, retain_graph=True), + warmup=25, + rep=100 + ) + q.grad = k.grad = v.grad = None + o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) + t2 = do_bench ( + lambda: postprocess(*kernel(q, k, v, do)), + warmup=25, + rep=100 + ) + print(f'Triton latency: {t1:.3f} ms') + print(f'TileLang latency: {t2:.3f} ms') + print(f'Speedup: {t1/t2:.3f}x') + + +if __name__ == '__main__': + main() + \ No newline at end of file diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py new file mode 100644 index 000000000..655fd0940 --- /dev/null +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -0,0 +1,129 @@ +import torch +import tilelang as tl +import tilelang.language as T +from tilelang.profiler import do_bench + +import argparse +from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA + + +def chunk_linear_attn_fwd_kernel( + B, S, H, DK, DV, + dtype: str = 'float16', + scale: float = None, +) -> torch.Tensor: + + if scale is None: + scale = DK ** -0.5 + accum_dtype = 'float' + + chunk_size = 64 + BK = BV = 64 + assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 + NK = tl.cdiv(DK, BK) + NV = tl.cdiv(DV, BV) + NT = tl.cdiv(S, chunk_size) + + @T.prim_func + def main( + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + final_state: T.Tensor([B, H, DK, DV], accum_dtype) # type: ignore + ): + with T.Kernel(NV, NK, B*H) as (i_v, i_k, i_bh): + i_b = i_bh // H + i_h = i_bh % H + + q = T.alloc_shared([chunk_size, BK], dtype) + k = T.alloc_shared([chunk_size, BK], dtype) + v = T.alloc_shared([chunk_size, BV], dtype) + h = T.alloc_fragment([BK, BV], accum_dtype) + h_shared = T.alloc_shared([BK, BV], dtype) + s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + s_shared = T.alloc_shared([chunk_size, chunk_size], dtype) + o = T.alloc_fragment([chunk_size, BV], accum_dtype) + T.clear(h) + + T.annotate_layout({ + q: tl.layout.make_swizzled_layout(q), + k: tl.layout.make_swizzled_layout(k), + v: tl.layout.make_swizzled_layout(v), + h_shared: tl.layout.make_swizzled_layout(h_shared), + s_shared: tl.layout.make_swizzled_layout(s_shared), + }) + T.use_swizzle(8) + + for i in T.Pipelined(0, NT, num_stages=1): + for row, col in T.Parallel(chunk_size, BK): + q[row, col] = Q[i_b, i*chunk_size+row, i_h, i_k*BK+col] * scale + T.copy(K[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) + T.copy(V[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + + T.gemm(q, k, s, clear_accum=True, transpose_B=True) + for row, col in T.Parallel(chunk_size, chunk_size): + s_shared[row, col] = T.if_then_else( + row >= col, + s[row, col], + 0 + ) + + T.gemm(s_shared, v, o, clear_accum=True) + T.copy(h, h_shared) + T.gemm(q, h_shared, o) + T.gemm(k, v, h, transpose_A=True) + T.copy(o, O[i_k, i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV]) + + # Output final state + T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k+1) * BK, i_v * BV:(i_v+1) * BV]) + + return main + + +def postprocess(o, h): + o = o[0] if o.size(0) == 1 else o.sum(0) + return o, h + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--B', type=int, default=8, help='Batch size') + parser.add_argument('--S', type=int, default=2048, help='Seq len') + parser.add_argument('--H', type=int, default=64, help='Num heads') + parser.add_argument('--D', type=int, default=256, help='Head dim') + args = parser.parse_args() + B, S, H, D = args.B, args.S, args.H, args.D + + q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + + fn = chunk_linear_attn_fwd_kernel(B, S, H, D, D) + kernel = tl.compile(fn, out_idx=[3, 4], target='cuda') + o, h = postprocess(*kernel(q, k, v)) + o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) + + if torch.allclose(o, o_ref) and torch.allclose(h, h_ref): + print('Passed all tests!✅') + else: + print('Failed some tests!❌') + + t1 = do_bench ( + lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)[0], + warmup=25, + rep=100 + ) + t2 = do_bench ( + lambda: kernel(q, k, v)[0].sum(0), + warmup=25, + rep=100 + ) + print(f'Triton latency: {t1:.3f} ms') + print(f'TileLang latency: {t2:.3f} ms') + print(f'Speedup: {t1/t2:.3f}x') + + +if __name__ == '__main__': + main() + \ No newline at end of file From b93512c45ffa1c372124121db51e921de4f2872b Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Wed, 4 Jun 2025 18:21:41 +0800 Subject: [PATCH 2/4] Add license --- .../example_linear_attn_bwd.py | 151 +++++++++--------- .../example_linear_attn_fwd.py | 95 ++++++----- 2 files changed, 122 insertions(+), 124 deletions(-) diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index ea79d9733..72a3618a7 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -1,45 +1,51 @@ -from numpy import transpose +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. + import torch import tilelang as tl -import tilelang.language as T +import tilelang.language as T from tilelang.profiler import do_bench import argparse -from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA - - -def chunk_linear_attn_bwd( - B, S, H, DK, DV, +from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA + + +def chunk_linear_attn_bwd_kernel( + B, + S, + H, + DK, + DV, dtype: str = 'float16', scale: float = None, ) -> torch.Tensor: - + if scale is None: - scale = DK ** -0.5 + scale = DK**-0.5 accum_dtype = 'float' - + chunk_size = 64 BK = BV = 64 assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 NK = tl.cdiv(DK, BK) NV = tl.cdiv(DV, BV) NT = tl.cdiv(S, chunk_size) - + @T.prim_func def main( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - dO: T.Tensor([B, S, H, DV], dtype), # type: ignore - dQ: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore - dK: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore - dV: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + dO: T.Tensor([B, S, H, DV], dtype), # type: ignore + dQ: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore + dK: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore + dV: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore ): - with T.Kernel(NV, NK, B*H) as (i_v, i_k, i_bh): + with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H - - ds = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) + + ds = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) ds_shared = T.alloc_shared([chunk_size, chunk_size], dtype) dq = T.alloc_fragment([chunk_size, BK], accum_dtype) dk = T.alloc_fragment([chunk_size, BK], accum_dtype) @@ -54,7 +60,7 @@ def main( dh_shared = T.alloc_shared([BK, BV], dtype) T.clear(h) T.clear(dh) - + T.annotate_layout({ ds_shared: tl.layout.make_swizzled_layout(ds_shared), q: tl.layout.make_swizzled_layout(q), @@ -65,67 +71,69 @@ def main( dh_shared: tl.layout.make_swizzled_layout(dh_shared) }) - # Calculate dQ for i in T.Pipelined(0, NT, num_stages=1): - T.copy(K[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_k * BK:(i_k+1) * BK], k) - T.copy(V[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV], v) - T.copy(dO[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV], do) - + T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], + do) + T.gemm(do, v, ds, transpose_B=True, clear_accum=True) for row, col in T.Parallel(chunk_size, chunk_size): - ds_shared[row, col] = T.if_then_else( - row >= col, - ds[row, col], - 0 - ) - + ds_shared[row, col] = T.if_then_else(row >= col, ds[row, col], 0) + T.gemm(ds_shared, k, dq, clear_accum=True) T.copy(h, h_shared) T.gemm(do, h_shared, dq) T.gemm(v, k, h, transpose_A=True) for row, col in T.Parallel(chunk_size, BK): dq[row, col] *= scale - T.copy(dq, dQ[i_v, i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_k * BK:(i_k+1) * BK]) - + T.copy( + dq, dQ[i_v, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, + i_k * BK:(i_k + 1) * BK]) + # Calculate dK, dV (reversely) - for i in T.Pipelined(1, NT+1, num_stages=1): - start = NT-i + for i in T.Pipelined(1, NT + 1, num_stages=1): + start = NT - i for row, col in T.Parallel(chunk_size, BK): - q[row, col] = Q[i_b, start*chunk_size+row, i_h, i_k*BK+col] * scale - T.copy(K[i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_k * BK:(i_k+1) * BK], k) - T.copy(V[i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV], v) - T.copy(dO[i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV], do) + q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale + T.copy( + K[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, + i_k * BK:(i_k + 1) * BK], k) + T.copy( + V[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, + i_v * BV:(i_v + 1) * BV], v) + T.copy( + dO[i_b, start * chunk_size:(start + 1) * chunk_size, i_h, + i_v * BV:(i_v + 1) * BV], do) T.copy(dh, dh_shared) - - # Calculate dk - T.gemm(v, do, ds, transpose_B=True, clear_accum=True) # ds here actually means `s`, but we simply reuse the buffer `ds` + + # Calculate dk + T.gemm( + v, do, ds, transpose_B=True, clear_accum=True + ) # ds here actually means `s`, but we simply reuse the buffer `ds` for row, col in T.Parallel(chunk_size, chunk_size): - ds_shared[row, col] = T.if_then_else( - row <= col, - ds[row, col], - 0 - ) + ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) T.gemm(ds_shared, q, dk, clear_accum=True) T.gemm(v, dh_shared, dk, transpose_B=True) - + # Calculate dv T.gemm(k, q, ds, transpose_B=True, clear_accum=True) for row, col in T.Parallel(chunk_size, chunk_size): - ds_shared[row, col] = T.if_then_else( - row <= col, - ds[row, col], - 0 - ) + ds_shared[row, col] = T.if_then_else(row <= col, ds[row, col], 0) T.gemm(ds_shared, do, dv, clear_accum=True) T.gemm(k, dh_shared, dv) - + # Update dh T.gemm(q, do, dh, transpose_A=True) - - T.copy(dk, dK[i_v, i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_k * BK:(i_k+1) * BK]) - T.copy(dv, dV[i_k, i_b, start*chunk_size:(start+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV]) - + + T.copy( + dk, dK[i_v, i_b, start * chunk_size:(start + 1) * chunk_size, i_h, + i_k * BK:(i_k + 1) * BK]) + T.copy( + dv, dV[i_k, i_b, start * chunk_size:(start + 1) * chunk_size, i_h, + i_v * BV:(i_v + 1) * BV]) + return main @@ -144,13 +152,13 @@ def main(): parser.add_argument('--D', type=int, default=256, help='Head dim') args = parser.parse_args() B, S, H, D = args.B, args.S, args.H, args.D - + q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - - fn = chunk_linear_attn_bwd(B, S, H, D, D) + + fn = chunk_linear_attn_bwd_kernel(B, S, H, D, D) kernel = tl.compile(fn, out_idx=[4, 5, 6], target='cuda') dq, dk, dv = postprocess(*kernel(q, k, v, do)) o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) @@ -159,23 +167,14 @@ def main(): print('Passed all tests!✅') else: print('Failed some tests!❌') - t1 = do_bench ( - lambda: o_ref.backward(do, retain_graph=True), - warmup=25, - rep=100 - ) + t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100) q.grad = k.grad = v.grad = None o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) - t2 = do_bench ( - lambda: postprocess(*kernel(q, k, v, do)), - warmup=25, - rep=100 - ) + t2 = do_bench(lambda: postprocess(*kernel(q, k, v, do)), warmup=25, rep=100) print(f'Triton latency: {t1:.3f} ms') print(f'TileLang latency: {t2:.3f} ms') - print(f'Speedup: {t1/t2:.3f}x') - + print(f'Speedup: {t1/t2:.3f}x') + if __name__ == '__main__': main() - \ No newline at end of file diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index 655fd0940..3d28be316 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -1,51 +1,58 @@ +# Copyright (c) Tile-AI Corporation. +# Licensed under the MIT License. + import torch import tilelang as tl -import tilelang.language as T +import tilelang.language as T from tilelang.profiler import do_bench import argparse -from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA - - +from fla.ops.linear_attn import fused_chunk_linear_attn # We compare with FLA + + def chunk_linear_attn_fwd_kernel( - B, S, H, DK, DV, + B, + S, + H, + DK, + DV, dtype: str = 'float16', scale: float = None, ) -> torch.Tensor: - + if scale is None: - scale = DK ** -0.5 + scale = DK**-0.5 accum_dtype = 'float' - + chunk_size = 64 BK = BV = 64 assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0 NK = tl.cdiv(DK, BK) NV = tl.cdiv(DV, BV) NT = tl.cdiv(S, chunk_size) - + @T.prim_func def main( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore - final_state: T.Tensor([B, H, DK, DV], accum_dtype) # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), # type: ignore + K: T.Tensor([B, S, H, DK], dtype), # type: ignore + V: T.Tensor([B, S, H, DV], dtype), # type: ignore + O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + final_state: T.Tensor([B, H, DK, DV], accum_dtype) # type: ignore ): - with T.Kernel(NV, NK, B*H) as (i_v, i_k, i_bh): + with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H - + q = T.alloc_shared([chunk_size, BK], dtype) k = T.alloc_shared([chunk_size, BK], dtype) v = T.alloc_shared([chunk_size, BV], dtype) - h = T.alloc_fragment([BK, BV], accum_dtype) + h = T.alloc_fragment([BK, BV], accum_dtype) h_shared = T.alloc_shared([BK, BV], dtype) s = T.alloc_fragment([chunk_size, chunk_size], accum_dtype) s_shared = T.alloc_shared([chunk_size, chunk_size], dtype) o = T.alloc_fragment([chunk_size, BV], accum_dtype) T.clear(h) - + T.annotate_layout({ q: tl.layout.make_swizzled_layout(q), k: tl.layout.make_swizzled_layout(k), @@ -54,30 +61,28 @@ def main( s_shared: tl.layout.make_swizzled_layout(s_shared), }) T.use_swizzle(8) - + for i in T.Pipelined(0, NT, num_stages=1): for row, col in T.Parallel(chunk_size, BK): - q[row, col] = Q[i_b, i*chunk_size+row, i_h, i_k*BK+col] * scale - T.copy(K[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) - T.copy(V[i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) - + q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale + T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k) + T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v) + T.gemm(q, k, s, clear_accum=True, transpose_B=True) for row, col in T.Parallel(chunk_size, chunk_size): - s_shared[row, col] = T.if_then_else( - row >= col, - s[row, col], - 0 - ) - + s_shared[row, col] = T.if_then_else(row >= col, s[row, col], 0) + T.gemm(s_shared, v, o, clear_accum=True) T.copy(h, h_shared) T.gemm(q, h_shared, o) T.gemm(k, v, h, transpose_A=True) - T.copy(o, O[i_k, i_b, i*chunk_size:(i+1)*chunk_size, i_h, i_v * BV:(i_v+1) * BV]) - + T.copy( + o, O[i_k, i_b, i * chunk_size:(i + 1) * chunk_size, i_h, + i_v * BV:(i_v + 1) * BV]) + # Output final state - T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k+1) * BK, i_v * BV:(i_v+1) * BV]) - + T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV]) + return main @@ -94,36 +99,30 @@ def main(): parser.add_argument('--D', type=int, default=256, help='Head dim') args = parser.parse_args() B, S, H, D = args.B, args.S, args.H, args.D - + q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) - + fn = chunk_linear_attn_fwd_kernel(B, S, H, D, D) kernel = tl.compile(fn, out_idx=[3, 4], target='cuda') o, h = postprocess(*kernel(q, k, v)) o_ref, h_ref = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False) - + if torch.allclose(o, o_ref) and torch.allclose(h, h_ref): print('Passed all tests!✅') else: print('Failed some tests!❌') - - t1 = do_bench ( + + t1 = do_bench( lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)[0], warmup=25, - rep=100 - ) - t2 = do_bench ( - lambda: kernel(q, k, v)[0].sum(0), - warmup=25, - rep=100 - ) + rep=100) + t2 = do_bench(lambda: kernel(q, k, v)[0].sum(0), warmup=25, rep=100) print(f'Triton latency: {t1:.3f} ms') print(f'TileLang latency: {t2:.3f} ms') - print(f'Speedup: {t1/t2:.3f}x') - + print(f'Speedup: {t1/t2:.3f}x') + if __name__ == '__main__': main() - \ No newline at end of file From 6c8baec5ab180d8cbf8f9acdb19627298dc1ff0f Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Wed, 4 Jun 2025 18:24:23 +0800 Subject: [PATCH 3/4] Remove comments --- .../linear_attention/example_linear_attn_bwd.py | 14 +++++++------- .../linear_attention/example_linear_attn_fwd.py | 10 +++++----- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index 72a3618a7..ddd514688 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -33,13 +33,13 @@ def chunk_linear_attn_bwd_kernel( @T.prim_func def main( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - dO: T.Tensor([B, S, H, DV], dtype), # type: ignore - dQ: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore - dK: T.Tensor([NV, B, S, H, DK], dtype), # type: ignore - dV: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), + K: T.Tensor([B, S, H, DK], dtype), + V: T.Tensor([B, S, H, DV], dtype), + dO: T.Tensor([B, S, H, DV], dtype), + dQ: T.Tensor([NV, B, S, H, DK], dtype), + dK: T.Tensor([NV, B, S, H, DK], dtype), + dV: T.Tensor([NK, B, S, H, DV], dtype), ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index 3d28be316..61282d049 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -33,11 +33,11 @@ def chunk_linear_attn_fwd_kernel( @T.prim_func def main( - Q: T.Tensor([B, S, H, DK], dtype), # type: ignore - K: T.Tensor([B, S, H, DK], dtype), # type: ignore - V: T.Tensor([B, S, H, DV], dtype), # type: ignore - O: T.Tensor([NK, B, S, H, DV], dtype), # type: ignore - final_state: T.Tensor([B, H, DK, DV], accum_dtype) # type: ignore + Q: T.Tensor([B, S, H, DK], dtype), + K: T.Tensor([B, S, H, DK], dtype), + V: T.Tensor([B, S, H, DV], dtype), + O: T.Tensor([NK, B, S, H, DV], dtype), + final_state: T.Tensor([B, H, DK, DV], accum_dtype) ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H From e487af73209f25844e8c2cc5db69bd73ad4e5394 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Wed, 4 Jun 2025 18:25:10 +0800 Subject: [PATCH 4/4] Run yapf and ruff --- .../linear_attention/example_linear_attn_bwd.py | 14 +++++++------- .../linear_attention/example_linear_attn_fwd.py | 10 +++------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index ddd514688..d03398627 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -33,13 +33,13 @@ def chunk_linear_attn_bwd_kernel( @T.prim_func def main( - Q: T.Tensor([B, S, H, DK], dtype), - K: T.Tensor([B, S, H, DK], dtype), - V: T.Tensor([B, S, H, DV], dtype), - dO: T.Tensor([B, S, H, DV], dtype), - dQ: T.Tensor([NV, B, S, H, DK], dtype), - dK: T.Tensor([NV, B, S, H, DK], dtype), - dV: T.Tensor([NK, B, S, H, DV], dtype), + Q: T.Tensor([B, S, H, DK], dtype), + K: T.Tensor([B, S, H, DK], dtype), + V: T.Tensor([B, S, H, DV], dtype), + dO: T.Tensor([B, S, H, DV], dtype), + dQ: T.Tensor([NV, B, S, H, DK], dtype), + dK: T.Tensor([NV, B, S, H, DK], dtype), + dV: T.Tensor([NK, B, S, H, DV], dtype), ): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index 61282d049..e9e96de13 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -32,13 +32,9 @@ def chunk_linear_attn_fwd_kernel( NT = tl.cdiv(S, chunk_size) @T.prim_func - def main( - Q: T.Tensor([B, S, H, DK], dtype), - K: T.Tensor([B, S, H, DK], dtype), - V: T.Tensor([B, S, H, DV], dtype), - O: T.Tensor([NK, B, S, H, DV], dtype), - final_state: T.Tensor([B, H, DK, DV], accum_dtype) - ): + def main(Q: T.Tensor([B, S, H, DK], dtype), K: T.Tensor([B, S, H, DK], dtype), + V: T.Tensor([B, S, H, DV], dtype), O: T.Tensor([NK, B, S, H, DV], dtype), + final_state: T.Tensor([B, H, DK, DV], accum_dtype)): with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh): i_b = i_bh // H i_h = i_bh % H