diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py new file mode 100644 index 000000000..845cee648 --- /dev/null +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -0,0 +1,241 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import torch +import torch.nn.functional as F +import tilelang +from tilelang import Profiler +from tilelang.autotuner import * +import tilelang.language as T +import itertools +import argparse +from functools import partial + + +def get_configs(): + block_M = [128] + block_N = [128] + num_stages = [2] + threads = [256] + _configs = list(itertools.product(block_M, block_N, num_stages, threads)) + + configs = [{ + 'block_M': c[0], + 'block_N': c[1], + 'num_stages': c[2], + 'threads': c[3] + } for c in _configs] + return configs + + +def flashattn(batch, heads, seq_len, dim, is_causal, tune=False, groups=1): + scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e) + head_kv = heads // groups + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = "float16" + accum_dtype = "float" + + def kernel_func(block_M, block_N, num_stages, threads): + + @T.macro + def MMA0( + K: T.Buffer(kv_shape, dtype), + Q_shared: T.Buffer([block_M, dim], dtype), + K_shared: T.Buffer([block_N, dim], dtype), + acc_s: T.Buffer([block_M, block_N], accum_dtype), + k: T.int32, + bx: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(K[bz, k * block_N:(k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, + -T.infinity(acc_s.dtype)) + else: + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def MMA1( + V: T.Buffer(kv_shape, dtype), + V_shared: T.Buffer([block_M, dim], dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + acc_o: T.Buffer([block_M, dim], accum_dtype), + k: T.int32, + by: T.int32, + bz: T.int32, + ): + T.copy(V[bz, k * block_N:(k + 1) * block_N, by // groups, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + @T.macro + def Softmax( + acc_s: T.Buffer([block_M, block_N], accum_dtype), + acc_s_cast: T.Buffer([block_M, block_N], dtype), + scores_max: T.Buffer([block_M], accum_dtype), + scores_max_prev: T.Buffer([block_M], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + scores_sum: T.Buffer([block_M], accum_dtype), + logsum: T.Buffer([block_M], accum_dtype), + ): + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + # To do causal softmax, we need to set the scores_max to 0 if it is -inf + # This process is called Check_inf in FlashAttention3 code, and it only need to be done + # in the first ceil_div(kBlockM, kBlockN) steps. + # for i in T.Parallel(block_M): + # scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + # Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + # max * log_2(e)) This allows the compiler to use the ffma + # instruction instead of fadd and fmul separately. 
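+                # Concretely, scale = log2(e) / sqrt(dim) here, so
+                # exp2(x * scale - max * scale) == exp((x - max) / sqrt(dim)),
+                # i.e. the usual softmax numerator with the 1/sqrt(dim) scaling folded in.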
+ acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + @T.macro + def Rescale( + acc_o: T.Buffer([block_M, dim], accum_dtype), + scores_scale: T.Buffer([block_M], accum_dtype), + ): + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + @T.prim_func + def main( + Q: T.Buffer(q_shape, dtype), + K: T.Buffer(kv_shape, dtype), + V: T.Buffer(kv_shape, dtype), + Output: T.Buffer(q_shape, dtype), + ): + with T.Kernel( + T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv( + (bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N)) + + for k in T.Pipelined(loop_range, num_stages=num_stages): + MMA0(K, Q_shared, K_shared, acc_s, k, bx, by, bz) + Softmax(acc_s, acc_s_cast, scores_max, scores_max_prev, scores_scale, + scores_sum, logsum) + Rescale(acc_o, scores_scale) + MMA1(V, V_shared, acc_s_cast, acc_o, k, by, bz) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, O_shared) + T.copy(O_shared, Output[bz, bx * block_M:(bx + 1) * block_M, by, :]) + + return main + + if tune: + + @autotune( + configs=get_configs(), + keys=["block_M", "block_N", "num_stages", "threads"], + warmup=10, + rep=10) + @jit( + out_idx=[3], + supply_type=tilelang.TensorSupplyType.Integer, + ref_prog=None, + profiler="auto") + def kernel(block_M=None, block_N=None, num_stages=None, threads=None): + return kernel_func(block_M, block_N, num_stages, threads) + + return kernel() + else: + + def kernel(block_M, block_N, num_stages, threads): + return kernel_func(block_M, block_N, num_stages, threads) + + return kernel + + +def ref_program(Q, K, V, is_causal, groups=1): + # Q: [B, T, HQ, D] + # K: [B, T, HK, D] + # V: [B, T, HV, D] + # HQ = HKV * groups + assert Q.size(2) == K.size( + 2) * groups, f"Q.size(2): {Q.size(2)}, K.size(2): {K.size(2)}, groups: {groups}" + assert Q.size(2) == V.size( + 2) * groups, f"Q.size(2): {Q.size(2)}, V.size(2): {V.size(2)}, groups: {groups}" + + dim = Q.size(-1) + K = K.repeat_interleave(groups, dim=2) + V = V.repeat_interleave(groups, dim=2) + scores = torch.einsum('bqhd,bkhd->bhqk', Q, K) + scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) + if is_causal: + seq_len = Q.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float('-inf')) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum('bhqk,bkhd->bqhd', 
attention_weights, V)
+    return output
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--batch', type=int, default=8, help='batch size')
+    parser.add_argument('--heads', type=int, default=32, help='heads')
+    parser.add_argument('--seq_len', type=int, default=4096, help='sequence length')
+    parser.add_argument('--dim', type=int, default=128, help='dim')
+    parser.add_argument('--is_causal', action='store_true', help='causal')
+    parser.add_argument('--tune', action='store_true', help='tune configs')
+    parser.add_argument('--groups', type=int, default=8, help='groups')
+    args = parser.parse_args()
+    batch, heads, seq_len, dim, is_causal, groups = args.batch, args.heads, args.seq_len, args.dim, args.is_causal, args.groups
+    flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim
+    total_flops = 2 * flops_per_matmul
+    if is_causal:
+        total_flops *= 0.5
+
+    if not args.tune:
+        program = flashattn(
+            batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)(
+                block_M=128, block_N=128, num_stages=1, threads=128)
+        ref_program = partial(ref_program, is_causal=is_causal, groups=groups)
+        mod, params = tilelang.lower(program)
+        mod = Profiler(mod, params, [3], tilelang.TensorSupplyType.Normal)
+        mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
+        print("All checks pass.")
+        latency = mod.do_bench(ref_program, warmup=500)
+        print("Ref: {:.2f} ms".format(latency))
+        print("Ref: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+        latency = mod.do_bench(mod.func, warmup=500)
+        print("Tile-lang: {:.2f} ms".format(latency))
+        print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+    else:
+        best_latency, best_config, _ = flashattn(
+            batch, heads, seq_len, dim, is_causal, tune=args.tune, groups=groups)
+        print(f"Best latency: {best_latency}")
+        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
+        print(f"Best config: {best_config}")
diff --git a/tilelang/engine/lower.py b/tilelang/engine/lower.py
index 964262cfd..d9791f869 100644
--- a/tilelang/engine/lower.py
+++ b/tilelang/engine/lower.py
@@ -2,7 +2,6 @@
 # Licensed under the MIT License.
"""The compiler for TL programs.""" -import tilelang as tl import os import os.path as osp from typing import Union, Optional, Callable @@ -12,6 +11,10 @@ from tvm.target import Target from tilelang.contrib import hipcc, nvcc from tilelang.utils.target import determine_target +from tilelang.engine.phase import ( + LowerAndLegalize, + OptimizeForTarget, +) def is_cpu_device_backend(target: Target): @@ -152,68 +155,12 @@ def lower( _is_host_call = get_host_call(is_device_c=is_cpu_device_backend(target)) _is_device_call = get_device_call(is_device_c=is_cpu_device_backend(target)) - mod = tir.transform.BindTarget(target)(mod) - - mod = tl.transform.FrontendLegalize()(mod) - mod = tir.transform.Simplify()(mod) - mod = tl.transform.LayoutInference()(mod) - mod = tl.transform.LowerTileOp()(mod) - mod = tl.transform.LegalizeVectorizedLoop()(mod) - mod = tl.transform.LegalizeSafeMemoryAccess()(mod) - # Inject Simplify to remove the duplicated conditions - mod = tir.transform.Simplify()(mod) - - # which may be introduced by the LegalizeSafeMemoryAccess - if target.arch == "sm_90": - mod = tl.transform.MultiVersionBuffer()(mod) - mod = tl.transform.WarpSpecialized()(mod) - mod = tl.transform.InjectSoftwarePipeline()(mod) - mod = tir.transform.LowerOpaqueBlock()(mod) - # mod = tl.transform.WarpSpecializedPipeline()(mod) - mod = tl.transform.InjectFenceProxy()(mod) - else: - mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tl.transform.PipelinePlanning()(mod) - mod = tl.transform.InjectSoftwarePipeline()(mod) - - mod = tir.transform.LowerOpaqueBlock()(mod) - mod = tir.transform.FlattenBuffer()(mod) - mod = tir.transform.NarrowDataType(32)(mod) - mod = tir.transform.Simplify()(mod) - mod = tl.transform.VectorizeLoop()(mod) - mod = tir.transform.StorageRewrite()(mod) - mod = tir.transform.UnrollLoop()(mod) - mod = tir.transform.RenormalizeSplitPattern()(mod) - mod = tir.transform.Simplify()(mod) - mod = tir.transform.RemoveNoOp()(mod) - mod = tir.transform.RewriteUnsafeSelect()(mod) - mod = tir.transform.HoistIfThenElse()(mod) - - mod = tir.transform.VerifyMemory()(mod) - mod = tir.transform.AnnotateEntryFunc()(mod) - # TODO(lei): This is a hack to make sure the - # thread level allreduce pass can be applied - # in TL. As Tl only use one thread dimension - # the var binding information will be lost - # in the lowering process with Legalization - # and Simplify pass. - # We can find a way better to create var instead - # of putting the LowerThreadAllreduce before - # the Legalization. 
- mod = tl.transform.ThreadPartialSync("shared.dyn")(mod) - mod = tir.transform.InferFragment()(mod) - mod = tir.transform.LowerThreadAllreduce()(mod) - mod = tl.transform.LowerHopperIntrin()(mod) - mod = tl.transform.ThreadSync("shared")(mod) - mod = tl.transform.ThreadSync("shared.dyn")(mod) - mod = tir.transform.InjectPTXAsyncCopy()(mod) - - mod = tl.transform.AnnotateDeviceRegions()(mod) - mod = tir.transform.SplitHostDevice()(mod) - mod = tir.transform.MergeSharedMemoryAllocations()(mod) - - mod = tl.transform.MakePackedAPI()(mod) - mod = tir.transform.LowerDeviceKernelLaunch()(mod) + # Phase 1: Lower and legalize the IR + mod = LowerAndLegalize(mod, target) + + # Phase 2: Optimize the IR for the target + mod = OptimizeForTarget(mod, target) + host_mod = tir.transform.Filter(_is_host_call)(mod) host_mod = tir.transform.BindTarget(target_host)(host_mod) host_mod = tir.transform.FP8StorageLegalize()(host_mod) diff --git a/tilelang/engine/phase.py b/tilelang/engine/phase.py new file mode 100644 index 000000000..2ac15215c --- /dev/null +++ b/tilelang/engine/phase.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from tvm import tir, IRModule +from tvm.target import Target +import tilelang as tl + + +def LowerAndLegalize(mod: IRModule, target: Target) -> IRModule: + # Bind the target device information to the module + mod = tir.transform.BindTarget(target)(mod) + + # Legalize the frontend IR to make it compatible with TVM + mod = tl.transform.FrontendLegalize()(mod) + # Simplify the IR expressions + mod = tir.transform.Simplify()(mod) + # Infer memory layouts for fragments and shared memory + mod = tl.transform.LayoutInference()(mod) + # Lower high-level tile operations to low-level operations + mod = tl.transform.LowerTileOp()(mod) + # Legalize vectorized loops to ensure they are valid + mod = tl.transform.LegalizeVectorizedLoop()(mod) + # Add safety checks for memory accesses + mod = tl.transform.LegalizeSafeMemoryAccess()(mod) + # Simplify again to clean up any duplicated conditions + # that may have been introduced by safety checks + mod = tir.transform.Simplify()(mod) + + return mod + + +def OptimizeForTarget(mod: IRModule, target: Target) -> IRModule: + # which may be introduced by the LegalizeSafeMemoryAccess + if target.arch == "sm_90": + mod = tl.transform.MultiVersionBuffer()(mod) + mod = tl.transform.WarpSpecialized()(mod) + mod = tl.transform.InjectSoftwarePipeline()(mod) + mod = tir.transform.LowerOpaqueBlock()(mod) + # mod = tl.transform.WarpSpecializedPipeline()(mod) + mod = tl.transform.InjectFenceProxy()(mod) + else: + mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) + mod = tl.transform.PipelinePlanning()(mod) + mod = tl.transform.InjectSoftwarePipeline()(mod) + + mod = tir.transform.LowerOpaqueBlock()(mod) + mod = tir.transform.FlattenBuffer()(mod) + mod = tir.transform.NarrowDataType(32)(mod) + mod = tir.transform.Simplify()(mod) + mod = tl.transform.VectorizeLoop()(mod) + mod = tir.transform.StorageRewrite()(mod) + mod = tir.transform.UnrollLoop()(mod) + mod = tir.transform.RenormalizeSplitPattern()(mod) + mod = tir.transform.Simplify()(mod) + mod = tir.transform.RemoveNoOp()(mod) + mod = tir.transform.RewriteUnsafeSelect()(mod) + mod = tir.transform.HoistIfThenElse()(mod) + + mod = tir.transform.VerifyMemory()(mod) + mod = tir.transform.AnnotateEntryFunc()(mod) + # TODO(lei): This is a hack to make sure the + # thread level allreduce pass can be applied + # in TL. 
As TL only uses one thread dimension,
+    # the var binding information will be lost
+    # in the lowering process with the Legalization
+    # and Simplify passes.
+    # We could find a better way to create vars instead
+    # of putting LowerThreadAllreduce before
+    # the Legalization.
+    mod = tl.transform.ThreadPartialSync("shared.dyn")(mod)
+    mod = tir.transform.InferFragment()(mod)
+    mod = tir.transform.LowerThreadAllreduce()(mod)
+    mod = tl.transform.LowerHopperIntrin()(mod)
+    mod = tl.transform.ThreadSync("shared")(mod)
+    mod = tl.transform.ThreadSync("shared.dyn")(mod)
+    mod = tir.transform.InjectPTXAsyncCopy()(mod)
+
+    mod = tl.transform.AnnotateDeviceRegions()(mod)
+    mod = tir.transform.SplitHostDevice()(mod)
+    mod = tir.transform.MergeSharedMemoryAllocations()(mod)
+
+    mod = tl.transform.MakePackedAPI()(mod)
+    mod = tir.transform.LowerDeviceKernelLaunch()(mod)
+
+    return mod
diff --git a/tilelang/jit/adapter/ctypes/utils.py b/tilelang/jit/adapter/ctypes/utils.py
index 287d63c72..f2be9527a 100644
--- a/tilelang/jit/adapter/ctypes/utils.py
+++ b/tilelang/jit/adapter/ctypes/utils.py
@@ -5,12 +5,15 @@ from tilelang import tvm as tvm
 from tvm import IRModule, tir
 from tvm.target import Target
-import tilelang.transform
 from tilelang.engine.lower import (
     is_device_call,
     determine_target,
     canon_target_host,
 )
+from tilelang.engine.phase import (
+    LowerAndLegalize,
+    OptimizeForTarget,
+)


 def match_global_kernel(source: str) -> int:
@@ -47,58 +50,8 @@ def get_annotated_device_mod(
     target_host = tvm.target.Target.canon_target(target_host)
     target = tvm.target.Target(target, target_host)

-    mod = tir.transform.BindTarget(target)(mod)
-
-    mod = tilelang.transform.FrontendLegalize()(mod)
-    mod = tir.transform.Simplify()(mod)
-    mod = tilelang.transform.LayoutInference()(mod)
-    mod = tilelang.transform.LowerTileOp()(mod)
-    mod = tir.transform.Simplify()(mod)
-
-    if target.arch == "sm_90":
-        mod = tilelang.transform.WarpSpecializedPipeline()(mod)
-    else:
-        mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
-        mod = tilelang.transform.PipelinePlanning()(mod)
-        mod = tilelang.transform.InjectSoftwarePipeline()(mod)
-
-    mod = tir.transform.LowerOpaqueBlock()(mod)
-    mod = tir.transform.FlattenBuffer()(mod)
-    mod = tir.transform.NarrowDataType(32)(mod)
-    mod = tir.transform.Simplify()(mod)
-
-    mod = tir.transform.VectorizeLoop()(mod)
-    mod = tir.transform.StorageRewrite()(mod)
-    mod = tir.transform.UnrollLoop()(mod)
-    mod = tir.transform.RenormalizeSplitPattern()(mod)
-    mod = tir.transform.Simplify()(mod)
-    mod = tir.transform.RemoveNoOp()(mod)
-    mod = tir.transform.RewriteUnsafeSelect()(mod)
-    mod = tir.transform.HoistIfThenElse()(mod)
-
-    mod = tir.transform.VerifyMemory()(mod)
-    mod = tir.transform.AnnotateEntryFunc()(mod)
-    mod = tir.transform.ThreadSync("shared")(mod)
-    # TODO(lei): This is a hack to make sure the
-    # thread level allreduce pass can be applied
-    # in TL. As Tl only use one thread dimension
-    # the var binding information will be lost
-    # in the lowering process with Legalization
-    # and Simplify pass.
-    # We can find a way better to create var instead
-    # of putting the LowerThreadAllreduce before
-    # the Legalization.
- mod = tir.transform.LowerThreadAllreduce()(mod) - mod = tir.transform.ThreadSync("shared.dyn")(mod) - mod = tilelang.transform.LowerHopperIntrin()(mod) - mod = tir.transform.InjectPTXAsyncCopy()(mod) - - mod = tir.transform.AnnotateDeviceRegions()(mod) - mod = tir.transform.SplitHostDevice()(mod) - mod = tir.transform.MergeSharedMemoryAllocations()(mod) - mod = tir.transform.MakePackedAPI()(mod) - mod = tir.transform.LowerDeviceKernelLaunch()(mod) - + mod = LowerAndLegalize(mod, target) + mod = OptimizeForTarget(mod, target) device_mod = tir.transform.Filter(is_device_call)(mod) return device_mod diff --git a/tilelang/jit/adapter/cython/utils.py b/tilelang/jit/adapter/cython/utils.py index 287d63c72..c03c231e3 100644 --- a/tilelang/jit/adapter/cython/utils.py +++ b/tilelang/jit/adapter/cython/utils.py @@ -5,12 +5,15 @@ from tilelang import tvm as tvm from tvm import IRModule, tir from tvm.target import Target -import tilelang.transform from tilelang.engine.lower import ( is_device_call, determine_target, canon_target_host, ) +from tilelang.engine.phase import ( + LowerAndLegalize, + OptimizeForTarget, +) def match_global_kernel(source: str) -> int: @@ -47,57 +50,8 @@ def get_annotated_device_mod( target_host = tvm.target.Target.canon_target(target_host) target = tvm.target.Target(target, target_host) - mod = tir.transform.BindTarget(target)(mod) - - mod = tilelang.transform.FrontendLegalize()(mod) - mod = tir.transform.Simplify()(mod) - mod = tilelang.transform.LayoutInference()(mod) - mod = tilelang.transform.LowerTileOp()(mod) - mod = tir.transform.Simplify()(mod) - - if target.arch == "sm_90": - mod = tilelang.transform.WarpSpecializedPipeline()(mod) - else: - mod = tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) - mod = tilelang.transform.PipelinePlanning()(mod) - mod = tilelang.transform.InjectSoftwarePipeline()(mod) - - mod = tir.transform.LowerOpaqueBlock()(mod) - mod = tir.transform.FlattenBuffer()(mod) - mod = tir.transform.NarrowDataType(32)(mod) - mod = tir.transform.Simplify()(mod) - - mod = tir.transform.VectorizeLoop()(mod) - mod = tir.transform.StorageRewrite()(mod) - mod = tir.transform.UnrollLoop()(mod) - mod = tir.transform.RenormalizeSplitPattern()(mod) - mod = tir.transform.Simplify()(mod) - mod = tir.transform.RemoveNoOp()(mod) - mod = tir.transform.RewriteUnsafeSelect()(mod) - mod = tir.transform.HoistIfThenElse()(mod) - - mod = tir.transform.VerifyMemory()(mod) - mod = tir.transform.AnnotateEntryFunc()(mod) - mod = tir.transform.ThreadSync("shared")(mod) - # TODO(lei): This is a hack to make sure the - # thread level allreduce pass can be applied - # in TL. As Tl only use one thread dimension - # the var binding information will be lost - # in the lowering process with Legalization - # and Simplify pass. - # We can find a way better to create var instead - # of putting the LowerThreadAllreduce before - # the Legalization. - mod = tir.transform.LowerThreadAllreduce()(mod) - mod = tir.transform.ThreadSync("shared.dyn")(mod) - mod = tilelang.transform.LowerHopperIntrin()(mod) - mod = tir.transform.InjectPTXAsyncCopy()(mod) - - mod = tir.transform.AnnotateDeviceRegions()(mod) - mod = tir.transform.SplitHostDevice()(mod) - mod = tir.transform.MergeSharedMemoryAllocations()(mod) - mod = tir.transform.MakePackedAPI()(mod) - mod = tir.transform.LowerDeviceKernelLaunch()(mod) + mod = LowerAndLegalize(mod, target) + mod = OptimizeForTarget(mod, target) device_mod = tir.transform.Filter(is_device_call)(mod)