Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/norm/rms_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ def ref_program(x):
latency = profiler.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} ms".format(latency))
78 changes: 78 additions & 0 deletions examples/norm/test_rms_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

import torch
import tilelang
import tilelang.language as T


def rms_norm_splitk(M, N, blk_m, blk_k):
dtype = "float"

@T.prim_func
def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
A_shared = T.alloc_shared((blk_m, blk_k), dtype)
A_local = T.alloc_fragment((blk_m, blk_k), dtype)
A_powsum = T.alloc_fragment((blk_m,), dtype)

num_k_step = T.ceildiv(N, blk_k)
T.clear(A_local)
for k in range(num_k_step):
T.copy(A[bx * blk_m, k * blk_k], A_shared)
for i, j in T.Parallel(blk_m, blk_k):
A_local[i, j] += A_shared[i, j] * A_shared[i, j]
T.reduce_sum(A_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12

for k in range(num_k_step):
# reverse, better cache hit rate
T.copy(A[bx * blk_m, (num_k_step - 1 - k) * blk_k], A_shared)
for i, j in T.Parallel(blk_m, blk_k):
A_shared[i, j] *= A_powsum[i]
T.copy(A_shared, B[bx * blk_m, (num_k_step - 1 - k) * blk_k])

return main


def rms_norm(M, N, blk_m):
dtype = "float"

@T.prim_func
def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
with T.Kernel(T.ceildiv(M, blk_m), threads=128) as bx:
A_shared = T.alloc_shared((blk_m, N), dtype)
A_pow_local = T.alloc_fragment((blk_m, N), dtype)
A_local = T.alloc_fragment((blk_m, N), dtype)
A_powsum = T.alloc_fragment((blk_m,), dtype)

T.copy(A[bx * blk_m:(bx + 1) * blk_m, :], A_shared)
T.copy(A_shared, A_local)
for i, j in T.Parallel(blk_m, N):
A_pow_local[i, j] = A_local[i, j] * A_local[i, j]
T.reduce_sum(A_pow_local, A_powsum, dim=1)
for i in T.Parallel(blk_m):
A_powsum[i] = T.rsqrt(A_powsum[i] / N) + 1e-12
for i, j in T.Parallel(blk_m, N):
A_local[i, j] *= A_powsum[i]
T.copy(A_local, B[bx * blk_m:(bx + 1) * blk_m, :])

return main


def ref_program(x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-12)


def test_rms_norm():
M, N, blk_m = 8192, 8192, 1
program = rms_norm(M, N, blk_m)
kernel = tilelang.compile(
program,
out_idx=-1,
target="cuda",
execution_backend="cython",
pass_configs={"tl.disable_tma_lower": True})
profiler = kernel.get_profiler()
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
Loading