Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 241 additions & 0 deletions examples/flash_attention/example_gqa_fwd_bshd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch
import torch.nn.functional as F
import tilelang
from tilelang import Profiler
from tilelang.autotuner import *
import tilelang.language as T
import itertools
import argparse
from functools import partial


def get_configs():
block_M = [128]
block_N = [128]
num_stages = [2]
threads = [256]
_configs = list(itertools.product(block_M, block_N, num_stages, threads))

configs = [{
'block_M': c[0],
'block_N': c[1],
'num_stages': c[2],
'threads': c[3]
} for c in _configs]
return configs


def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1):
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
head_kv = heads // groups
q_shape = [batch, seq_len, heads, dim]
kv_shape = [batch, seq_len, head_kv, dim]
dtype = "float16"
accum_dtype = "float"

def kernel_func(block_M, block_N, num_stages, threads):

@T.macro
def MMA0(
K: T.Buffer(kv_shape, dtype),
Q_shared: T.Buffer([block_M, dim], dtype),
K_shared: T.Buffer([block_N, dim], dtype),
acc_s: T.Buffer([block_M, block_N], accum_dtype),
k: T.int32,
bx: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared)
if is_causal:
for i, j in T.Parallel(block_M, block_N):
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
-T.infinity(acc_s.dtype))
else:
T.clear(acc_s)
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)

@T.macro
def MMA1(
V: T.Buffer(kv_shape, dtype),
V_shared: T.Buffer([block_M, dim], dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
acc_o: T.Buffer([block_M, dim], accum_dtype),
k: T.int32,
by: T.int32,
bz: T.int32,
):
T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared)
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)

@T.macro
def Softmax(
acc_s: T.Buffer([block_M, block_N], accum_dtype),
acc_s_cast: T.Buffer([block_M, block_N], dtype),
scores_max: T.Buffer([block_M], accum_dtype),
scores_max_prev: T.Buffer([block_M], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
scores_sum: T.Buffer([block_M], accum_dtype),
logsum: T.Buffer([block_M], accum_dtype),
):
T.copy(scores_max, scores_max_prev)
T.fill(scores_max, -T.infinity(accum_dtype))
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
# To do causal softmax, we need to set the scores_max to 0 if it is -inf
# This process is called Check_inf in FlashAttention3 code, and it only need to be done
# in the first ceil_div(kBlockM, kBlockN) steps.
# for i in T.Parallel(block_M):
# scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i])
for i in T.Parallel(block_M):
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
for i, j in T.Parallel(block_M, block_N):
# Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
# max * log_2(e)) This allows the compiler to use the ffma
# instruction instead of fadd and fmul separately.
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
T.reduce_sum(acc_s, scores_sum, dim=1)
for i in T.Parallel(block_M):
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
T.copy(acc_s, acc_s_cast)

@T.macro
def Rescale(
acc_o: T.Buffer([block_M, dim], accum_dtype),
scores_scale: T.Buffer([block_M], accum_dtype),
):
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] *= scores_scale[i]

@T.prim_func
def main(
Q: T.Buffer(q_shape, dtype),
K: T.Buffer(kv_shape, dtype),
V: T.Buffer(kv_shape, dtype),
Output: T.Buffer(q_shape, dtype),
):
with T.Kernel(
T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz):
Q_shared = T.alloc_shared([block_M, dim], dtype)
K_shared = T.alloc_shared([block_N, dim], dtype)
V_shared = T.alloc_shared([block_N, dim], dtype)
O_shared = T.alloc_shared([block_M, dim], dtype)
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
scores_max = T.alloc_fragment([block_M], accum_dtype)
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
scores_scale = T.alloc_fragment([block_M], accum_dtype)
scores_sum = T.alloc_fragment([block_M], accum_dtype)
logsum = T.alloc_fragment([block_M], accum_dtype)

T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
T.fill(acc_o, 0)
T.fill(logsum, 0)
T.fill(scores_max, -T.infinity(accum_dtype))

loop_range = (
T.min(T.ceildiv(seq_len, block_N), T.ceildiv(
(bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N))

for k in T.Pipelined(loop_range, num_stages=num_stages):
MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz)
Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale,
scores_sum, logsum)
Rescale(acc_o, scores_scale)
MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz)
for i, j in T.Parallel(block_M, dim):
acc_o[i, j] /= logsum[i]
T.copy(acc_o, O_shared)
T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])

return main

if tune:

@autotune(
configs=get_configs(),
keys=["block_M", "block_N", "num_stages", "threads"],
warmup=10,
rep=10)
@jit(
out_idx=[3],
supply_type=tilelang.TensorSupplyType.Integer,
ref_prog=None,
profiler="auto")
def kernel(block_M=None, block_N=None, num_stages=None, threads=None):
return kernel_func(block_M, block_N, num_stages, threads)

return kernel()
else:

def kernel(block_M, block_N, num_stages, threads):
return kernel_func(block_M, block_N, num_stages, threads)

return kernel


def ref_program(Q, K, V, is_causal, groups=1):
# Q: [B, T, HQ, D]
# K: [B, T, HK, D]
# V: [B, T, HV, D]
# HQ = HKV * groups
assert Q.size(2) == K.size(
2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}"
assert Q.size(2) == V.size(
2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}"

dim = Q.size(-1)
K = K.repeat_interleave(groups, dim=2)
V = V.repeat_interleave(groups, dim=2)
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
if is_causal:
seq_len = Q.size(1)
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
mask = mask.unsqueeze(0).unsqueeze(0)
scores = scores.masked_fill(mask == 0, float('-inf'))
attention_weights = F.softmax(scores, dim=-1)
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
return output


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=8, help='batch size')
parser.add_argument('--heads', type=int, default=32, help='heads')
parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
parser.add_argument('--dim', type=int, default=128, help='dim')
parser.add_argument('--is_causal', action='store_true', help='causal')
parser.add_argument('--tune', action='store_true', help='tune configs')
parser.add_argument('--groups', type=int, default=8, help='groups')
args = parser.parse_args()
batch, heads, seq_len, dim, is_causal, groups = args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups
flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
total_flops = 2 * flops_per_matmul
if is_causal:
total_flops *= 0.5

if (not args.tune):
program = flashattn(
batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)(
block_M=128, block_N=128, num_stages=1, threads=128)
ref_program = partial(ref_program, is_causal=is_causal, groups=groups)
mod, params = tilelang.lower(program)
mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
print("All checks pass.")
latency = mod.do_bench(ref_program, warmup=500)
print("Ref: {:.2f} ms".format(latency))
print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
latency = mod.do_bench(mod.func, warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
else:
best_latency, best_config, _ = flashattn(
batch, heads, seq_len, dim, is_causal, tune=args.tune)
print(f"Best latency: {best_latency}")
print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
print(f"Best config: {best_config}")
73 changes: 10 additions & 63 deletions tilelang/engine/lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under the MIT License.
"""The compiler for TL programs."""

import tilelang as tl
import os
import os.path as osp
from typing import Union, Optional, Callable
Expand All @@ -12,6 +11,10 @@
from tvm.target import Target
from tilelang.contrib import hipcc, nvcc
from tilelang.utils.target import determine_target
from tilelang.engine.phase import (
LowerAndLegalize,
OptimizeForTarget,
)


def is_cpu_device_backend(target: Target):
Expand Down Expand Up @@ -152,68 +155,12 @@ def lower(
_is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target))
_is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target))

mod = tir.transform.BindTarget(target)(mod)

mod = tl.transform.FrontendLegalize()(mod)
mod = tir.transform.Simplify()(mod)
mod = tl.transform.LayoutInference()(mod)
mod = tl.transform.LowerTileOp()(mod)
mod = tl.transform.LegalizeVectorizedLoop()(mod)
mod = tl.transform.LegalizeSafeMemoryAccess()(mod)
# Inject Simplify to remove the duplicated conditions
mod = tir.transform.Simplify()(mod)

# which may be introduced by the LegalizeSafeMemoryAccess
if target.arch == "sm_90":
mod = tl.transform.MultiVersionBuffer()(mod)
mod = tl.transform.WarpSpecialized()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)
mod = tir.transform.LowerOpaqueBlock()(mod)
# mod = tl.transform.WarpSpecializedPipeline()(mod)
mod = tl.transform.InjectFenceProxy()(mod)
else:
mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
mod = tl.transform.PipelinePlanning()(mod)
mod = tl.transform.InjectSoftwarePipeline()(mod)

mod = tir.transform.LowerOpaqueBlock()(mod)
mod = tir.transform.FlattenBuffer()(mod)
mod = tir.transform.NarrowDataType(32)(mod)
mod = tir.transform.Simplify()(mod)
mod = tl.transform.VectorizeLoop()(mod)
mod = tir.transform.StorageRewrite()(mod)
mod = tir.transform.UnrollLoop()(mod)
mod = tir.transform.RenormalizeSplitPattern()(mod)
mod = tir.transform.Simplify()(mod)
mod = tir.transform.RemoveNoOp()(mod)
mod = tir.transform.RewriteUnsafeSelect()(mod)
mod = tir.transform.HoistIfThenElse()(mod)

mod = tir.transform.VerifyMemory()(mod)
mod = tir.transform.AnnotateEntryFunc()(mod)
# TODO(lei): This is a hack to make sure the
# thread level allreduce pass can be applied
# in TL. As Tl only use one thread dimension
# the var binding information will be lost
# in the lowering process with Legalization
# and Simplify pass.
# We can find a way better to create var instead
# of putting the LowerThreadAllreduce before
# the Legalization.
mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
mod = tir.transform.InferFragment()(mod)
mod = tir.transform.LowerThreadAllreduce()(mod)
mod = tl.transform.LowerHopperIntrin()(mod)
mod = tl.transform.ThreadSync("shared")(mod)
mod = tl.transform.ThreadSync("shared.dyn")(mod)
mod = tir.transform.InjectPTXAsyncCopy()(mod)

mod = tl.transform.AnnotateDeviceRegions()(mod)
mod = tir.transform.SplitHostDevice()(mod)
mod = tir.transform.MergeSharedMemoryAllocations()(mod)

mod = tl.transform.MakePackedAPI()(mod)
mod = tir.transform.LowerDeviceKernelLaunch()(mod)
# Phase 1: Lower and legalize the IR
mod = LowerAndLegalize(mod, target)

# Phase 2: Optimize the IR for the target
mod = OptimizeForTarget(mod, target)

host_mod = tir.transform.Filter(_is_host_call)(mod)
host_mod = tir.transform.BindTarget(target_host)(host_mod)
host_mod = tir.transform.FP8StorageLegalize()(host_mod)
Expand Down
Loading
Loading