6 changes: 6 additions & 0 deletions .github/workflows/ci.yml
@@ -63,6 +63,12 @@ jobs:
          source tilelang_ci/bin/activate
          python -m pip install .

+    - name: Run examples
+      run: |
+        source tilelang_ci/bin/activate
+        cd examples
+        python -m pytest **/test*.py
+
    - name: Run tests
      run: |
        source tilelang_ci/bin/activate
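The added step runs every examples/**/test*.py file under pytest, reusing the tilelang_ci virtualenv built earlier in the job. A minimal local equivalent, as a sketch (assumes pytest is installed and the repository root is the working directory; the CI virtualenv is not needed locally):

# Expand examples/**/test*.py the way the shell glob in ci.yml does,
# then hand the resulting file list to pytest.
import sys
from pathlib import Path

import pytest

files = [str(p) for p in Path("examples").rglob("test*.py")]
sys.exit(pytest.main(files))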
38 changes: 16 additions & 22 deletions examples/warp_specialize/example_warp_specialize_flashmla.py
@@ -7,7 +7,6 @@
import tilelang
import tilelang.language as T
from einops import rearrange, einsum
-import argparse


def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split):
@@ -46,40 +45,36 @@ def flash_attn(
T.annotate_layout({
O_shared: tilelang.layout.make_swizzled_layout(O_shared),
})

T.create_list_of_mbarrier(128, 128, 256, 128)

loop_range = T.ceildiv(seqlen_kv, block_N)
with T.ws(2):
T.dec_max_nreg(24)
T.copy(Q[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_shared)
T.copy(Q_pe[bx, by * VALID_BLOCK_H:(by + 1) * VALID_BLOCK_H, :], Q_pe_shared)
-T.mbarrier_arrive(T.get_mbarrier(3))
+T.barrier_arrive(barrier_id=3)
for k in T.serial(loop_range):
-T.mbarrier_wait_parity(
-    T.FloorMod(k, 1) + 2, T.bitwise_xor(T.FloorDiv(k, 1) % 2, 1))
+T.barrier_wait(barrier_id=(k % 1) + 2, parity=(k % 2) ^ 1)
T.copy(KV[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], KV_shared)
-T.mbarrier_arrive(T.FloorMod(k, 1))
+T.barrier_arrive(k % 1)
T.copy(K_pe[bx, k * block_N:(k + 1) * block_N, cur_kv_head, :], K_pe_shared)
-T.mbarrier_arrive(T.FloorMod(k, 1) + 1)
+T.barrier_arrive(k % 1 + 1)
with T.ws(0, 1):
T.inc_max_nreg(240)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))
-T.mbarrier_wait_parity(T.get_mbarrier(3), 0)
+T.barrier_wait(3, 0)
for k in T.serial(loop_range):
T.clear(acc_s)
-T.mbarrier_wait_parity(T.get_mbarrier(T.FloorMod(k, 1)), T.FloorDiv(k, 1) % 2)
+T.barrier_wait(barrier_id=k % 1, parity=(k // 1) % 2)
T.gemm(
Q_shared,
KV_shared,
acc_s,
transpose_B=True,
policy=T.GemmWarpPolicy.FullCol)
-T.mbarrier_wait_parity(
-    T.get_mbarrier(T.FloorMod(k, 1) + 1),
-    T.FloorDiv(k, 1) % 2)
+T.barrier_wait(barrier_id=k % 1 + 1, parity=(k // 1) % 2)
T.gemm(
Q_pe_shared,
K_pe_shared,
@@ -100,7 +95,7 @@ def flash_attn(
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] *= scores_scale[i]
T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
-T.mbarrier_arrive(T.get_mbarrier(T.FloorMod(k, 1) + 2))
+T.barrier_arrive(barrier_id=k % 1 + 2)
for i, j in T.Parallel(block_H, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
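The kernel-side change is mechanical: each mbarrier_* call becomes a barrier_* call, and the T.FloorMod/T.FloorDiv index expressions become plain Python % and // operators (equivalent here, since k is non-negative). A side-by-side sketch of the mapping, with both signatures taken verbatim from the changed lines above (these calls are only meaningful inside a kernel body after T.create_list_of_mbarrier):

# Old: arrive/wait through an explicit mbarrier handle.
T.mbarrier_arrive(T.get_mbarrier(3))
T.mbarrier_wait_parity(T.get_mbarrier(T.FloorMod(k, 1)), T.FloorDiv(k, 1) % 2)

# New: address the barrier by id and pass the phase as a keyword.
T.barrier_arrive(barrier_id=3)
T.barrier_wait(barrier_id=k % 1, parity=(k // 1) % 2)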
@@ -165,15 +160,13 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):


def main():
-parser = argparse.ArgumentParser()
-parser.add_argument('--batch', type=int, default=128, help='batch size')
-parser.add_argument('--heads', type=int, default=128, help='q heads number')
-parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
-parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
-parser.add_argument('--dim', type=int, default=512, help='head dim')
-parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
-args = parser.parse_args()
-batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
+batch = 128
+heads = 128
+kv_heads = 1
+kv_ctx = 8192
+dim = 512
+pe_dim = 64

qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
pv_flops = 2 * batch * heads * kv_ctx * dim
total_flops = qk_flops + pv_flops
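With the hardcoded sizes, the FLOP count is easy to sanity-check by hand (pure arithmetic, runnable on its own; same formulas as in main()):

batch, heads, kv_ctx, dim, pe_dim = 128, 128, 8192, 512, 64
qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)  # Q @ K^T over dim + pe_dim
pv_flops = 2 * batch * heads * kv_ctx * dim             # attention weights @ V
total_flops = qk_flops + pv_flops
print(total_flops / 1e12)  # ~0.292 TFLOPs per forward call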
@@ -183,6 +176,7 @@ def main():

program = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split)
kernel = tilelang.compile(program, out_idx=[6])
+print(kernel.get_kernel_source())

profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
profiler.assert_allclose(ref_program, rtol=0.01, atol=0.01)
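assert_allclose drives the compiled kernel and ref_program on the same randomly drawn inputs (TensorSupplyType.Randn) and checks agreement within 1% relative and absolute tolerance. If benchmarking follows the correctness check, a hypothetical continuation could look like the sketch below (do_bench and its millisecond return value are assumptions; the rest of main() is collapsed in this diff):

# Hypothetical: do_bench and its return unit (ms) are assumed, not shown here.
latency = profiler.do_bench()
print(f"Latency: {latency:.3f} ms")
print(f"Throughput: {total_flops / (latency * 1e-3) / 1e12:.2f} TFLOPS")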