diff --git a/python/tvm/dlight/__init__.py b/python/tvm/dlight/__init__.py index 421d4017d1bd..bd70acf00f90 100644 --- a/python/tvm/dlight/__init__.py +++ b/python/tvm/dlight/__init__.py @@ -16,12 +16,15 @@ # under the License. """DLight package provides efficient schedules out-of-box for deep learning workloads.""" from . import gpu -from .base import ( - ApplyDefaultSchedule, +from . import cpu +from .analysis import ( BlockInfo, IterInfo, - ScheduleRule, normalize_prim_func, +) +from .base import ( + ApplyDefaultSchedule, + ScheduleRule, try_inline, try_inline_contiguous_spatial, ) diff --git a/python/tvm/dlight/analysis/__init__.py b/python/tvm/dlight/analysis/__init__.py new file mode 100644 index 000000000000..bf68d0855015 --- /dev/null +++ b/python/tvm/dlight/analysis/__init__.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Base infra""" +from .common_analysis import ( + BlockInfo, + IterInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + detect_dominant_read, + is_broadcast_epilogue, + normalize_prim_func, + get_root_block, +) +from .gemv import ( + is_gemv, + normalize, +) diff --git a/python/tvm/dlight/base/analysis.py b/python/tvm/dlight/analysis/common_analysis.py similarity index 100% rename from python/tvm/dlight/base/analysis.py rename to python/tvm/dlight/analysis/common_analysis.py diff --git a/python/tvm/dlight/analysis/gemv.py b/python/tvm/dlight/analysis/gemv.py new file mode 100644 index 000000000000..c502081ba320 --- /dev/null +++ b/python/tvm/dlight/analysis/gemv.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Analysis for GEMV.""" +from typing import List, Optional + +from tvm import arith, ir, tir + +from .common_analysis import ( + BlockInfo, + collect_block_iter_vars_used_in_access_region, + collect_vars_used_in_prim_expr, + detect_dominant_read, +) + + +def get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: + """Extracts the reduction expression from a TIR block. 
+ + This function checks whether the given TIR block follows a reduction pattern + of the form `X[...] = X[...] + Y` and returns `Y` as the reduction expression. + + Parameters: + ---------- + block : tir.Block + The TIR block to analyze. + + Returns: + ------- + Optional[tir.PrimExpr] + The reduction expression (`Y`) if detected, otherwise None. + """ + + buffer_store = block.body + if not isinstance(buffer_store, tir.BufferStore): + return None + if not isinstance(buffer_store.value, tir.Add): + return None + if not ir.structural_equal( + buffer_store.value.a, + tir.BufferLoad(buffer_store.buffer, block.body.indices), + map_free_vars=True, + ): + return None + return buffer_store.value.b + + +def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: + """Check if the block is a GEMV. + + Parameters + ---------- + + sch : tir.Schedule + The schedule + + block_info : BlockInfo + The block info to be checked + + + Returns + ------- + ret : Optional[List[tir.Buffer]] + The vector buffers used in the GEMV if it is a GEMV, otherwise None. + """ + block = block_info.block_rv + block_stmt = sch.get(block) + conditions = [] + conditions.append(block_info.is_reduction()) + conditions.append(len(block_stmt.reads) >= 2) + conditions.append(len(block_stmt.writes) == 1) + conditions.append(get_reduction_expr(block_stmt) is not None) + conditions.append( + len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) + > 0 + ) + if not all(conditions): + return None + + iter_num = len(block_stmt.iter_vars) + ret = [ + read.buffer + for read in block_stmt.reads + if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num + and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0 + ] + return ret if 0 < len(ret) < len(block_stmt.reads) else None + + +def normalize( + sch: tir.Schedule, + block_info: BlockInfo, +) -> Optional[bool]: + """Normalize the main block.""" + block_stmt: tir.Block = sch.get(block_info.block_rv) + access = arith.normalize_to_iter_sum( + detect_dominant_read(block_stmt), + input_iters={i.var: i.dom for i in block_stmt.iter_vars}, + ) + buffers_use_vars = [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.writes + ] + buffers_use_vars.extend( + [ + collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) + for buf in block_stmt.reads + ] + ) + if collect_vars_used_in_prim_expr(access.base) & set( + iter_var.var for iter_var in block_stmt.iter_vars + ): + return None + iter_to_info = {i.var: i for i in block_info.iters} + batch_loops, s_loops, r_loops, c_loops = [], [], [], [] + inner_axis = access.args[-1].source.source + is_inner_reduction = iter_to_info[inner_axis].kind == "R" + + for split_expr in access.args: + var = split_expr.source.source + info = iter_to_info.get(var) + loop = info.loop_rv + is_reduction = info.kind == "R" + if split_expr.lower_factor > 1: + if c_loops: + return None + loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) + # we only support the reduction dim being grouped atm + if not is_reduction: + return None + c_loops.append(c_loop) + if is_reduction: + r_loops.append(loop) + elif all([var in buf_vars for buf_vars in buffers_use_vars]): + batch_loops.append(loop) + else: + s_loops.append(loop) + + assert s_loops + assert r_loops + if not c_loops: + c_loops = [sch.add_unit_loop(block_info.block_rv)] + if not batch_loops: + batch_loops = 
[sch.add_unit_loop(block_info.block_rv)] + sch.reorder(*batch_loops, *s_loops, *r_loops, *c_loops) + sch.fuse(*batch_loops) + sch.fuse(*s_loops) + sch.fuse(*r_loops) + return is_inner_reduction diff --git a/python/tvm/dlight/base/__init__.py b/python/tvm/dlight/base/__init__.py index a19a292fa13e..9d90c4f8e171 100644 --- a/python/tvm/dlight/base/__init__.py +++ b/python/tvm/dlight/base/__init__.py @@ -15,15 +15,13 @@ # specific language governing permissions and limitations # under the License. """Base infra""" -from .analysis import ( - BlockInfo, - IterInfo, - collect_block_iter_vars_used_in_access_region, - collect_vars_used_in_prim_expr, - detect_dominant_read, - is_broadcast_epilogue, - normalize_prim_func, -) from .common_schedules import try_inline, try_inline_contiguous_spatial from .schedule_rule import ScheduleRule from .transform import ApplyDefaultSchedule +from .utils import ( + auto_vectorize, + get_bytes, + get_extent, + max_threads_per_block, + suggest_threads_per_block, +) diff --git a/python/tvm/dlight/base/common_schedules.py b/python/tvm/dlight/base/common_schedules.py index fe005cec5d70..c205b78390bc 100644 --- a/python/tvm/dlight/base/common_schedules.py +++ b/python/tvm/dlight/base/common_schedules.py @@ -19,7 +19,7 @@ from tvm import tir -from .analysis import BlockInfo +from ..analysis import BlockInfo def try_inline( diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/base/utils.py similarity index 100% rename from python/tvm/dlight/gpu/utils.py rename to python/tvm/dlight/base/utils.py diff --git a/python/tvm/dlight/cpu/__init__.py b/python/tvm/dlight/cpu/__init__.py new file mode 100644 index 000000000000..3282275862f3 --- /dev/null +++ b/python/tvm/dlight/cpu/__init__.py @@ -0,0 +1,20 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +CPU-generic schedule rules. +""" +from .gemv import GEMV diff --git a/python/tvm/dlight/cpu/base.py b/python/tvm/dlight/cpu/base.py new file mode 100644 index 000000000000..4d16f9726bff --- /dev/null +++ b/python/tvm/dlight/cpu/base.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Base schedule rule for CPU operators."""
+
+from tvm.target import Target
+
+from ..base import ScheduleRule
+
+
+class CPUScheduleRule(ScheduleRule):  # pylint: disable=too-few-public-methods
+    """The Schedule Rule specific to CPU targets; it returns None if the target is not CPU."""
+
+    def is_target_available(self, target: Target) -> bool:
+        """Check whether the target is available for the CPU rule.
+
+        Parameters
+        ----------
+        target : Target
+            The compilation target to check.
+
+        Returns
+        -------
+        available : bool
+            Whether the target is available for this rule.
+        """
+        return super().is_target_available(target) and "llvm" == target.kind.name
diff --git a/python/tvm/dlight/cpu/gemv.py b/python/tvm/dlight/cpu/gemv.py
new file mode 100644
index 000000000000..15b47de919a7
--- /dev/null
+++ b/python/tvm/dlight/cpu/gemv.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""A rule for GEMV and DecodeGEMV."""
+from typing import List, Optional, Union
+
+from tvm import tir
+from tvm.target import Target
+
+from ..analysis import BlockInfo, normalize_prim_func
+from ..analysis.gemv import is_gemv, normalize
+from ..base import get_extent, try_inline_contiguous_spatial
+from .base import CPUScheduleRule
+
+
+class GEMV(CPUScheduleRule):
+    """A rule for GEMV and DecodeGEMV."""
+
+    def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-statements, no-else-return
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Union[None, tir.Schedule, List[tir.Schedule]]:
+        if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target):
+            return None
+        sch = tir.Schedule(func)
+        block_infos = normalize_prim_func(sch)
+        block_infos = try_inline_contiguous_spatial(sch, block_infos)
+        if block_infos is None:
+            return None
+        if len(block_infos) == 1:
+            epilogue = None
+        elif len(block_infos) == 2:
+            epilogue = block_infos[1]
+            if not epilogue.is_injective():
+                return None
+        else:
+            return None
+
+        block_info = block_infos[0]
+        if len(block_info.iters) not in [2, 3]:
+            # either [B, S, R] = [B, S, R] * [B, R]
+            # or [S, R] = [S, R] * [R]
+            return None
+        block = block_info.block_rv
+        vector_input_buffers = is_gemv(sch, block_info)
+        if vector_input_buffers is None:
+            return None
+
+        # Step 1. Normalize the block, merge spatial and reduction iters
+        is_inner_reduction = normalize(sch, block_info)
+
+        # Step 2. Do the scheduling
+        if is_inner_reduction is None:
+            return None
+        elif is_inner_reduction:
+            return self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue)
+        else:
+            # Outer-reduction scheduling is not supported yet
+            return None
+
+    def sch_inner_reduction(  # pylint: disable=too-many-arguments, too-many-positional-arguments, invalid-name, unused-argument
+        self,
+        sch: tir.Schedule,
+        target: Target,
+        block: tir.schedule.BlockRV,
+        vector_input_buffers: List[tir.Buffer],
+        epilogue_info: Optional[BlockInfo],
+    ):
+        """Schedule the inner reduction block."""
+
+        def apply(  # pylint: disable=unused-variable, too-many-locals
+            sch: tir.Schedule,
+            gemv,
+            vector_width: int = 8,
+            parallel_threads: int = 8,
+            unroll_factor: int = 256,
+        ):
+            batch, s, r, c = sch.get_loops(block)
+            len_batch, len_s, len_r, len_c = (
+                get_extent(sch, batch),
+                get_extent(sch, s),
+                get_extent(sch, r),
+                get_extent(sch, c),
+            )
+            len_S = len_batch * len_s
+            len_R = len_r * len_c
+
+            if isinstance(len_S, int) and isinstance(len_R, int):
+                if len_S > len_R:
+                    tile_s, tile_r = 128, 64  # Larger tiling for s-axis when len_S is larger
+                else:
+                    tile_s, tile_r = 64, 128  # Larger tiling for r-axis when len_R is larger
+            else:
+                tile_s, tile_r = 64, 64  # Default tile sizes for unknown extents
+
+            tile_c = min(vector_width, len_c)  # Ensure c-axis tiling aligns with SIMD vector width
+
+            # Apply loop tiling (improves cache locality)
+            s_outer, s_inner = sch.split(s, factors=[None, tile_s])
+            r_outer, r_inner = sch.split(r, factors=[None, tile_r])
+            c_outer, c_inner = sch.split(c, factors=[None, tile_c])
+
+            # Apply vectorization (SIMD optimization)
+            sch.vectorize(s_inner)  # Vectorize the inner spatial (s-axis) loop for AVX/NEON
+
+            # Enable parallel execution
+            sch.parallel(s_outer)  # Parallelize along the s-axis (major computation loop)
+
+            # Apply loop unrolling for better CPU performance
+            sch.annotate(r_outer, "pragma_auto_unroll_max_step", unroll_factor)
+            sch.annotate(r_outer, "pragma_unroll_explicit", 1)
+            return sch
+
+        return apply(
+            sch,
+            gemv=block,
+        )
diff --git a/python/tvm/dlight/gpu/fallback.py b/python/tvm/dlight/gpu/fallback.py
index 7139c7ea4199..bcbfda791fb3 100644
--- a/python/tvm/dlight/gpu/fallback.py
+++ b/python/tvm/dlight/gpu/fallback.py
@@ -21,8 +21,9 @@
 from tvm import tir
 from tvm.target import Target
 
-from ..base import normalize_prim_func, try_inline
-from . import utils
+from ..
import base +from ..analysis import normalize_prim_func +from ..base import try_inline from .base import GPUScheduleRule @@ -40,7 +41,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring ) -> tir.Schedule: if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None - max_threads_per_block = utils.max_threads_per_block(target) + max_threads_per_block = base.max_threads_per_block(target) sch = tir.Schedule(func) block_infos = normalize_prim_func(sch) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index cff234140e50..ebb19ad72c3a 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -18,140 +18,18 @@ from functools import reduce from typing import List, Optional, Union -from tvm import arith, ir, tir +from tvm import tir from tvm.target import Target -from ..base import ( +from ..analysis import ( BlockInfo, - collect_block_iter_vars_used_in_access_region, - collect_vars_used_in_prim_expr, - detect_dominant_read, is_broadcast_epilogue, + is_gemv, + normalize, normalize_prim_func, - try_inline_contiguous_spatial, ) +from ..base import auto_vectorize, get_bytes, get_extent, try_inline_contiguous_spatial from .base import GPUScheduleRule -from .utils import auto_vectorize, get_bytes, get_extent - - -def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: - # Detect and return `Y` in `X[...] = X[...] + Y` - buffer_store = block.body - if not isinstance(buffer_store, tir.BufferStore): - return None - if not isinstance(buffer_store.value, tir.Add): - return None - if not ir.structural_equal( - buffer_store.value.a, - tir.BufferLoad(buffer_store.buffer, block.body.indices), - map_free_vars=True, - ): - return None - return buffer_store.value.b - - -def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: - """Check if the block is a GEMV. - - Parameters - ---------- - - sch : tir.Schedule - The schedule - - block_info : BlockInfo - The block info to be checked - - - Returns - ------- - ret : Optional[List[tir.Buffer]] - The vector buffers used in the GEMV if it is a GEMV, otherwise None. 
- """ - block = block_info.block_rv - block_stmt = sch.get(block) - conditions = [] - conditions.append(block_info.is_reduction()) - conditions.append(len(block_stmt.reads) >= 2) - conditions.append(len(block_stmt.writes) == 1) - conditions.append(_get_reduction_expr(block_stmt) is not None) - conditions.append( - len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region)) - > 0 - ) - if not all(conditions): - return None - - iter_num = len(block_stmt.iter_vars) - ret = [ - read.buffer - for read in block_stmt.reads - if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num - and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0 - ] - return ret if 0 < len(ret) < len(block_stmt.reads) else None - - -def normalize( - sch: tir.Schedule, - block_info: BlockInfo, -) -> Optional[bool]: - """Normalize the main block.""" - block_stmt: tir.Block = sch.get(block_info.block_rv) - access = arith.normalize_to_iter_sum( - detect_dominant_read(block_stmt), - input_iters={i.var: i.dom for i in block_stmt.iter_vars}, - ) - buffers_use_vars = [ - collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) - for buf in block_stmt.writes - ] - buffers_use_vars.extend( - [ - collect_block_iter_vars_used_in_access_region(block_stmt, buf.region) - for buf in block_stmt.reads - ] - ) - if collect_vars_used_in_prim_expr(access.base) & set( - iter_var.var for iter_var in block_stmt.iter_vars - ): - return None - iter_to_info = {i.var: i for i in block_info.iters} - batch_loops, s_loops, r_loops, c_loops = [], [], [], [] - inner_axis = access.args[-1].source.source - is_inner_reduction = iter_to_info[inner_axis].kind == "R" - - for split_expr in access.args: - var = split_expr.source.source - info = iter_to_info.get(var) - loop = info.loop_rv - is_reduction = info.kind == "R" - if split_expr.lower_factor > 1: - if c_loops: - return None - loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) - # we only support the reduction dim being grouped atm - if not is_reduction: - return None - c_loops.append(c_loop) - if is_reduction: - r_loops.append(loop) - elif all([var in buf_vars for buf_vars in buffers_use_vars]): - batch_loops.append(loop) - else: - s_loops.append(loop) - - assert s_loops - assert r_loops - if not c_loops: - c_loops = [sch.add_unit_loop(block_info.block_rv)] - if not batch_loops: - batch_loops = [sch.add_unit_loop(block_info.block_rv)] - sch.reorder(*batch_loops, *s_loops, *r_loops, *c_loops) - sch.fuse(*batch_loops) - sch.fuse(*s_loops) - sch.fuse(*r_loops) - return is_inner_reduction class GEMV(GPUScheduleRule): diff --git a/python/tvm/dlight/gpu/general_reduction.py b/python/tvm/dlight/gpu/general_reduction.py index 404b73a6f0cc..a068e732b986 100644 --- a/python/tvm/dlight/gpu/general_reduction.py +++ b/python/tvm/dlight/gpu/general_reduction.py @@ -21,7 +21,8 @@ from tvm import arith, tir from tvm.target import Target -from ..base import normalize_prim_func, try_inline_contiguous_spatial +from ..analysis import normalize_prim_func +from ..base import try_inline_contiguous_spatial from .base import GPUScheduleRule diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py index b528086a1626..f5e3669ad0f3 100644 --- a/python/tvm/dlight/gpu/low_batch_gemv.py +++ b/python/tvm/dlight/gpu/low_batch_gemv.py @@ -21,16 +21,15 @@ from tvm import arith, ir, tir from tvm.target import Target -from ..base import ( +from ..analysis import ( BlockInfo, 
collect_block_iter_vars_used_in_access_region, collect_vars_used_in_prim_expr, is_broadcast_epilogue, normalize_prim_func, - try_inline_contiguous_spatial, ) +from ..base import auto_vectorize, get_bytes, get_extent, try_inline_contiguous_spatial from .base import GPUScheduleRule -from .utils import auto_vectorize, get_bytes, get_extent def _get_reduction_expr(block: tir.Block) -> Optional[tir.PrimExpr]: diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index d9d4b7ebd4d2..368552c88d43 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -22,13 +22,13 @@ from tvm import tir from tvm.ir import Range +from tvm.script import tir as T from tvm.target import Target from tvm.tir import IterVar, PrimExpr, Var from tvm.tir.analysis import undefined_vars from tvm.tir.schedule.schedule import BlockRV -from tvm.script import tir as T -from ..base import analysis, BlockInfo, IterInfo +from ..analysis import BlockInfo, IterInfo, get_root_block from .base import GPUScheduleRule @@ -358,7 +358,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) + root_block = get_root_block(sch) blocks = sch.get_child_blocks(root_block) reduction_blocks = get_reduction_blocks(sch, blocks) @@ -499,7 +499,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) + root_block = get_root_block(sch) blocks = sch.get_child_blocks(root_block) if "dlight.do_not_tensorize" in func.attrs.keys(): @@ -720,7 +720,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): return None sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) + root_block = get_root_block(sch) blocks = sch.get_child_blocks(root_block) if "dlight.do_not_tensorize" in func.attrs.keys(): @@ -971,7 +971,7 @@ def apply( # pylint: disable=too-many-locals,missing-docstring return None sch = tir.Schedule(func) config = self.get_configs(target) - root_block = analysis.get_root_block(sch) + root_block = get_root_block(sch) blocks = sch.get_child_blocks(root_block) reduction_blocks = get_reduction_blocks(sch, blocks) @@ -1130,7 +1130,6 @@ def sch_outer_reduction( reduction_block: tir.schedule.BlockRV, blocks: List[tir.schedule.BlockRV], ) -> Optional[tir.Schedule]: - """Get vectorization factor""" def get_max_factor(n, factors): diff --git a/python/tvm/dlight/gpu/reduction.py b/python/tvm/dlight/gpu/reduction.py index fc63e4836849..9851bb9800fa 100644 --- a/python/tvm/dlight/gpu/reduction.py +++ b/python/tvm/dlight/gpu/reduction.py @@ -21,14 +21,13 @@ from tvm import arith, ir, tir from tvm.target import Target -from ..base import ( +from ..analysis import ( BlockInfo, detect_dominant_read, is_broadcast_epilogue, normalize_prim_func, - try_inline_contiguous_spatial, ) -from . 
import utils +from ..base import suggest_threads_per_block, try_inline_contiguous_spatial from .base import GPUScheduleRule @@ -181,7 +180,7 @@ def _sch_inner_reduction( # pylint: disable=too-many-arguments ): # pylint: disable=invalid-name _, r, _ = sch.get_loops(block) - (len_tx,) = utils.suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking + (len_tx,) = suggest_threads_per_block( # pylint: disable=unbalanced-tuple-unpacking target, [sch.get(r)] ) diff --git a/python/tvm/dlight/gpu/rmsnorm.py b/python/tvm/dlight/gpu/rmsnorm.py index 4047721c9aa8..5dc6887c782c 100644 --- a/python/tvm/dlight/gpu/rmsnorm.py +++ b/python/tvm/dlight/gpu/rmsnorm.py @@ -19,9 +19,9 @@ import tvm from tvm import tir -from tvm.tir import Block, BufferStore -from tvm.tir.expr import Cast, BufferLoad, Call from tvm.target import Target +from tvm.tir import Block, BufferStore +from tvm.tir.expr import BufferLoad, Call, Cast from ..base import ScheduleRule diff --git a/python/tvm/dlight/gpu/transpose.py b/python/tvm/dlight/gpu/transpose.py index 3bef3d61e536..125af538cdb8 100644 --- a/python/tvm/dlight/gpu/transpose.py +++ b/python/tvm/dlight/gpu/transpose.py @@ -22,11 +22,8 @@ from tvm.tir import Schedule from tvm.tir.schedule import BlockRV -from ..base import ( - detect_dominant_read, - normalize_prim_func, - try_inline_contiguous_spatial, -) +from ..analysis import detect_dominant_read, normalize_prim_func +from ..base import try_inline_contiguous_spatial from .base import GPUScheduleRule diff --git a/python/tvm/relax/frontend/nn/llm/tree_attn.py b/python/tvm/relax/frontend/nn/llm/tree_attn.py index 33614633fc77..9aa27ca83d70 100644 --- a/python/tvm/relax/frontend/nn/llm/tree_attn.py +++ b/python/tvm/relax/frontend/nn/llm/tree_attn.py @@ -762,6 +762,8 @@ def tree_attn_paged_kv_cpu( for h_qo in T.serial(h_q): for b_idx in T.serial(batch_size): with T.block("attn"): + T.reads() + T.writes() O_local = T.alloc_buffer((d, ), "float32") Q_local = T.alloc_buffer((d, ), "float32") K_local = T.alloc_buffer((d, ), "float32") diff --git a/tests/python/dlight/test_cpu_gemv.py b/tests/python/dlight/test_cpu_gemv.py new file mode 100644 index 000000000000..c7eebf58aa9e --- /dev/null +++ b/tests/python/dlight/test_cpu_gemv.py @@ -0,0 +1,595 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +import pytest + +import tvm.testing +from tvm import dlight as dl +from tvm.script import tir as T +from tvm.target import Target + + +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + @pytest.fixture + def transform(self): + def transform(mod): + with Target("llvm"): + return dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + + return transform + + +class TestGEMV(BaseBeforeAfter): + # fmt: off + + @T.prim_func + def before(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int32() + lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") + lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + var_T_divide_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + var_T_maximum_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + var_T_minimum_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + for i0, i1, i2, i3, k in T.grid(1, 32, 1, n, 128): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k]) + T.reads(lv1637[v_i0, v_i1, v_i2, v_k], lv1638[v_i0, v_i1, v_i3, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1637[v_i0, v_i1, v_i2, v_k] * lv1638[v_i0, v_i1, v_i3, v_k] + for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): + with T.block("T_divide"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float16(0.088397790055248615) + for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): + with T.block("T_maximum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float16(-65504)) + for ax0, ax1, ax2, ax3 in T.grid(1, 32, 1, n): + with T.block("T_minimum"): + v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) + T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1614[v_ax0, 0, v_ax2, v_ax3]) + T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3]) + var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1614[v_ax0, 0, v_ax2, v_ax3]) + for i0, i1, i2, i3 in T.grid(1, 32, 1, n): + with T.block("compute"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + T.writes(var_compute_intermediate[v_i0, v_i1, v_i2, v_i3]) + var_compute_intermediate[v_i0, v_i1, v_i2, v_i3] = T.Cast("float32", var_T_minimum_intermediate[v_i0, v_i1, v_i2, v_i3]) + + @T.prim_func + def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p_lv1614: T.handle, p_output0: T.handle): + 
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int32() + lv1638 = T.match_buffer(p_lv1638, (1, 32, n, 128), "float16") + lv1614 = T.match_buffer(p_lv1614, (1, 1, 1, n), "float16") + var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((1, 32, 1, n), "float16") + for ax0_fused in range(32): + for ax1_fused_0 in T.parallel((n + 63) // 64): + for ax1_fused_1 in T.vectorized(64): + for ax2_fused_0 in T.serial(2, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_1, u_0, u_1 in T.grid(64, 1, 1): + with T.block("NT_matmul"): + v0 = T.axis.spatial(32, ax0_fused) + v1 = T.axis.spatial(n, ax1_fused_0 * 64 + ax1_fused_1) + v2 = T.axis.reduce(128, ax2_fused_0 * 64 + ax2_fused_1) + T.where(ax1_fused_0 * 64 + ax1_fused_1 < n) + T.reads(lv1637[0, v0, 0, v2], lv1638[0, v0, v1, v2]) + T.writes(var_NT_matmul_intermediate[0, v0, 0, v1]) + with T.init(): + var_NT_matmul_intermediate[0, v0, 0, v1] = T.float16(0.0) + var_NT_matmul_intermediate[0, v0, 0, v1] = var_NT_matmul_intermediate[0, v0, 0, v1] + lv1637[0, v0, 0, v2] * lv1638[0, v0, v1, v2] + for ax0, ax1 in T.grid(32, n): + with T.block("compute"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + T.reads(var_NT_matmul_intermediate[0, v0, 0, v1], lv1614[0, 0, 0, v1]) + T.writes(var_compute_intermediate[0, v0, 0, v1]) + var_compute_intermediate[0, v0, 0, v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate[0, v0, 0, v1] * T.float16(0.088397790055248615), T.float16(-65504.0)), lv1614[0, 0, 0, v1])) + + # fmt: on + + +def test_decode_gemv_256_threads(): + # fmt: off + @T.prim_func(private=True) + def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate = T.alloc_buffer((22016, 4096), "float16") + for i, j in T.grid(22016, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv571[v_i, v_j // 8], lv572[v_i, v_j // 32]) + T.writes(p_output0_intermediate[v_i, v_j]) + p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i, v_j // 32] + for i0, i1, i2, k in T.grid(1, 1, 22016, 4096): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k] + + @T.prim_func(private=True) + def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for u_fused in range(1): + for ax0_fused_0 in T.parallel(172): + for ax0_fused_1 in T.vectorized(128): + for ax1_0_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_0_fused_1, ax1_1_0, 
ax1_1_1 in T.grid(64, 1, 8): + with T.block("NT_matmul"): + v0 = T.axis.spatial(22016, ax0_fused_0 * 128 + ax0_fused_1) + v1 = T.axis.reduce(4096, ax1_0_fused_0 * 512 + ax1_0_fused_1 * 8 + ax1_1_0 * 8 + ax1_1_1) + T.reads(lv1654[0, 0, v1], lv571[v0, v1 // 8], lv572[v0, v1 // 32]) + T.writes(var_NT_matmul_intermediate[0, 0, v0]) + with T.init(): + var_NT_matmul_intermediate[0, 0, v0] = T.float16(0.0) + var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + lv1654[0, 0, v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v0, v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv572[v0, v1 // 32]) + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_decode_gemv1(): + # fmt: off + + @T.prim_func(private=True) + def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate = T.alloc_buffer((22016, 4096), "float16") + for i, j in T.grid(22016, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv571[v_i, v_j // 8], lv572[v_i, v_j // 32]) + T.writes(p_output0_intermediate[v_i, v_j]) + p_output0_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v_i, v_j // 32] + for i0, i1, i2, k in T.grid(1, 1, 22016, 4096): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1654[v_i0, v_i1, v_k], p_output0_intermediate[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k] + + @T.prim_func(private=True) + def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + for u_fused in range(1): + for ax0_fused_0 in T.parallel(172): + for ax0_fused_1 in T.vectorized(128): + for ax1_0_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(64, 1, 8): + with T.block("NT_matmul"): + v0 = T.axis.spatial(22016, ax0_fused_0 * 128 + ax0_fused_1) + v1 = T.axis.reduce(4096, ax1_0_fused_0 * 512 + ax1_0_fused_1 * 8 + ax1_1_0 * 8 + ax1_1_1) + T.reads(lv1654[0, 0, v1], lv571[v0, v1 // 8], lv572[v0, v1 // 32]) + T.writes(var_NT_matmul_intermediate[0, 0, v0]) + with T.init(): + var_NT_matmul_intermediate[0, 0, v0] = T.float16(0.0) + var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + lv1654[0, 0, v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v0, v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv572[v0, v1 // 32]) + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + 
tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_decode_gemv2(): + # fmt: off + + @T.prim_func(private=True) + def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((32000, 4096), "float16") + var_NT_matmul_intermediate = T.alloc_buffer((1, 1, 32000), "float16") + for i, j in T.grid(32000, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv771[v_i, v_j // 8], lv772[v_i, v_j // 32]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v_i, v_j // 8], T.Cast("uint32", v_j % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v_i, v_j // 32] + for i0, i1, i2, k in T.grid(1, 1, 32000, 4096): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv3216[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv3216[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] + for i0, i1, i2 in T.grid(1, 1, 32000): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func(private=True) + def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((1, 1, 32000), "float16") + for u_fused in range(1): + for ax0_fused_0 in T.parallel(250): + for ax0_fused_1 in T.vectorized(128): + for ax1_0_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(64, 1, 8): + with T.block("NT_matmul"): + v0 = T.axis.spatial(32000, ax0_fused_0 * 128 + ax0_fused_1) + v1 = T.axis.reduce(4096, ax1_0_fused_0 * 512 + ax1_0_fused_1 * 8 + ax1_1_0 * 8 + ax1_1_1) + T.reads(lv3216[0, 0, v1], lv771[v0, v1 // 8], lv772[v0, v1 // 32]) + T.writes(var_NT_matmul_intermediate[0, 0, v0]) + with T.init(): + var_NT_matmul_intermediate[0, 0, v0] = T.float16(0.0) + var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + lv3216[0, 0, v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v0, v1 // 8], T.Cast("uint32", v1 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv772[v0, v1 // 32]) + for ax0 in range(32000): + with T.block("compute"): + v0 = T.axis.spatial(32000, ax0) + T.reads(var_NT_matmul_intermediate[0, 0, v0]) + T.writes(p_output0_intermediate[0, 0, v0]) + p_output0_intermediate[0, 0, v0] = T.Cast("float32", var_NT_matmul_intermediate[0, 0, v0]) + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + 
tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_decode_gemv3(): + # fmt: off + + @T.prim_func(private=True) + def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv575[v_i, v_j // T.int64(8)], lv576[v_i, v_j // T.int64(32)]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv570[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + @T.prim_func(private=True) + def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + # with T.block("root"): + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for u_fused in range(1): + for ax0_fused_0 in T.parallel(T.int64(64)): + for ax0_fused_1 in T.vectorized(T.int64(64)): + for ax1_0_fused_0 in T.serial(T.int64(11), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_0_fused_1, ax1_1_0, ax1_1_1 in T.grid(T.int64(128), T.int64(1), T.int64(8)): + with T.block("NT_matmul"): + v0 = T.axis.spatial(T.int64(4096), ax0_fused_0 * T.int64(64) + ax0_fused_1) + v1 = T.axis.reduce(T.int64(11008), (ax1_0_fused_0 * T.int64(128) + ax1_0_fused_1) * T.int64(8) + ax1_1_0 * T.int64(8) + ax1_1_1) + T.where(ax1_0_fused_0 * T.int64(128) + ax1_0_fused_1 < T.int64(1376)) + T.reads(lv574[T.int64(0), T.int64(0), v1], lv575[v0, v1 // T.int64(8)], lv576[v0, v1 // T.int64(32)]) + 
T.writes(var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0]) + with T.init(): + var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = T.float16(0.0) + var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] = var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] + lv574[T.int64(0), T.int64(0), v1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v0, v1 // T.int64(8)], T.Cast("uint32", v1 % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv576[v0, v1 // T.int64(32)]) + for ax0 in range(T.int64(4096)): + with T.block("T_add"): + v0 = T.axis.spatial(T.int64(4096), ax0) + T.reads(lv570[T.int64(0), T.int64(0), v0], var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0]) + T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0]) + p_output0_intermediate[T.int64(0), T.int64(0), v0] = lv570[T.int64(0), T.int64(0), v0] + var_NT_matmul_intermediate[T.int64(0), T.int64(0), v0] + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_autogptq_decode_gemv(): + # fmt: off + @T.prim_func(private=True) + def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer((T.int64(32), T.int64(512)), "uint32"), lv11: T.Buffer((T.int64(32), T.int64(4096)), "float16"), lv12: T.Buffer((T.int64(4096),), "uint32"), lv8: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1613: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv9[v_i // T.int64(8), v_j], lv10[lv12[v_i], v_j // T.int64(8)], lv12[v_i], lv11[lv12[v_i], v_j]) + T.writes(decode_intermediate[v_i, v_j]) + decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv9[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) - (T.Cast("float16", T.bitwise_and(T.shift_right(lv10[lv12[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) + T.float16(1))) * lv11[lv12[v_i], v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv8[v_i0, v_i1, v_k], decode_intermediate[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv8[v_i0, v_i1, v_k] * decode_intermediate[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv1613[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1613[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + # fmt: on + + # The GeMV rule does not yet support the inner dim being grouped. 
+ # So the rule is expected to skip transforming this function. + mod = tvm.IRModule({"main": func}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], func) + + +def test_outer_reduction_adreno(): + # fmt: off + @T.prim_func(private=True) + def before( + lv575: T.Buffer((1376, 4096), "uint32"), + lv576: T.Buffer((344, 4096), "float16"), + lv574: T.Buffer((1, 1, 11008), "float16"), + lv570: T.Buffer((1, 1, 4096), "float16"), + p_output0_intermediate: T.Buffer((1, 1, 4096), "float16"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") + var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") + for i, j in T.grid(11008, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15)))- T.float16(7)) * lv576[v_i // 32, v_j] + for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(1, 1, 4096): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + + @T.prim_func(private=True) + def expected(lv575: T.Buffer((1376, 4096), "uint32"), lv576: T.Buffer((344, 4096), "float16"), lv574: T.Buffer((1, 1, 11008), "float16"), lv570: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 4096), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((11008, 4096), "float16") + var_matmul_intermediate = T.alloc_buffer((1, 1, 4096), "float16") + for i, j in T.grid(11008, 4096): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv575[v_i // 8, v_j], lv576[v_i // 32, v_j]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i // 8, v_j], T.Cast("uint32", v_i % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv576[v_i // 32, v_j] + for i0, i1, i2, k in T.grid(1, 1, 4096, 11008): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0.0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] + for ax0, ax1, ax2 in T.grid(1, 1, 4096): + with T.block("T_add"): + v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(lv570[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2]) + T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) + p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2] + # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod 
= dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_outer_reduction_adreno_dynamic(): + # fmt: off + @T.prim_func(private=True) + def before(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + v = T.int64() + lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") + lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") + p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), v), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16") + for i, j in T.grid(T.int64(4096), v): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv612[v_i // T.int64(8), v_j], lv613[v_i // T.int64(32), v_j]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv612[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv613[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), v, T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1607[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1607[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), v): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func(private=True) + def expected(p_lv612: T.handle, p_lv613: T.handle, lv1607: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0: T.handle): + T.func_attr({"tir.noalias": T.bool(True)}) + v = T.int64() + lv612 = T.match_buffer(p_lv612, (T.int64(512), v), "uint32") + lv613 = T.match_buffer(p_lv613, (T.int64(128), v), "float16") + p_output0_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(1), v)) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), v), "float16") + var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), v), "float16") + for i, j in T.grid(T.int64(4096), v): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv612[v_i // T.int64(8), v_j], lv613[v_i // T.int64(32), v_j]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv612[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7.0)) * lv613[v_i // T.int64(32), v_j] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), v, T.int64(4096)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv1607[v_i0, v_i1, v_k], p_output0_intermediate_1[v_k, v_i2]) + T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = 
T.float16(0.0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv1607[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_k, v_i2] + for i0, i1, i2 in T.grid(T.int64(1), T.int64(1), v): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(var_matmul_intermediate[v_i0, v_i1, v_i2]) + T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) + p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_matmul_intermediate[v_i0, v_i1, v_i2]) + # fmt: on + + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_blockized_gemv(): + # fmt: off + @T.prim_func(private=True) + def before(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): + # with T.block("root"): + for expert_id in T.thread_binding(2, thread="blockIdx.y"): + with T.block("gemv_o"): + v_expert_id_o = T.axis.spatial(2, expert_id) + vi_o = T.axis.spatial(1, 0) + vj_o = T.axis.reduce(1, 0) + T.reads(x[0, 0:4096], w[indptr[v_expert_id_o], 0:16384, 0:4096], indptr[v_expert_id_o]) + T.writes(o[v_expert_id_o, 0:16384]) + for i, j in T.grid(16384, 4096): + with T.block("gemv"): + vi_i, vj_i = T.axis.remap("SR", [i, j]) + T.reads(x[0, vj_i], w[indptr[v_expert_id_o], vi_i, vj_i], indptr[v_expert_id_o]) + T.writes(o[v_expert_id_o, vi_i]) + with T.init(): + o[v_expert_id_o, vi_i] = T.float16(0) + o[v_expert_id_o, vi_i] = o[v_expert_id_o, vi_i] + x[0, vj_i] * w[indptr[v_expert_id_o], vi_i, vj_i] + + @T.prim_func(private=True) + def expected(x: T.Buffer((1, 4096), "float16"), w: T.Buffer((8, 16384, 4096), "float16"), indptr: T.Buffer((2,), "int32"), o: T.Buffer((2, 16384), "float16")): + T.func_attr({"tir.is_scheduled": 1}) + # with T.block("root"): + for expert_id in T.thread_binding(2, thread="blockIdx.y"): + with T.block("gemv_o"): + v_expert_id_o = T.axis.spatial(2, expert_id) + vi_o = T.axis.spatial(1, 0) + vj_o = T.axis.reduce(1, 0) + T.reads(x[0, 0:4096], w[indptr[v_expert_id_o], 0:16384, 0:4096], indptr[v_expert_id_o]) + T.writes(o[v_expert_id_o, 0:16384]) + for u_fused in range(1): + for ax0_fused_0 in T.parallel(128): + for ax0_fused_1 in T.vectorized(128): + for ax1_fused_0 in T.serial(64, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax1_fused_1, u_0, u_1 in T.grid(64, 1, 1): + with T.block("gemv"): + v0 = T.axis.spatial(16384, ax0_fused_0 * 128 + ax0_fused_1) + v1 = T.axis.reduce(4096, ax1_fused_0 * 64 + ax1_fused_1) + T.reads(x[0, v1], w[indptr[v_expert_id_o], v0, v1], indptr[v_expert_id_o]) + T.writes(o[v_expert_id_o, v0]) + with T.init(): + o[v_expert_id_o, v0] = T.float16(0.0) + o[v_expert_id_o, v0] = o[v_expert_id_o, v0] + x[0, v1] * w[indptr[v_expert_id_o], v0, v1] + # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_func_to_skip(): + @T.prim_func + def before(var_A: T.handle, var_exclusive_scan_thrust: T.handle, seq_len: T.int64): + data_buf = T.match_buffer(var_A, (seq_len * T.int64(8),), "int32", align=8) + output_buf = T.match_buffer( + var_exclusive_scan_thrust, (seq_len * T.int64(8),), "int32", align=8 + ) + with T.block("exclusive_scan_thrust"): + T.reads() + T.writes() + T.call_packed( + "tvm.contrib.thrust.sum_scan", + T.tvm_stack_make_array( + 
data_buf.data, T.tvm_stack_make_shape(seq_len * T.int64(8)), 0, 1, 0, T.int64(0) + ), + T.tvm_stack_make_array( + output_buf.data, + T.tvm_stack_make_shape(seq_len * T.int64(8)), + 0, + 1, + 0, + T.int64(0), + ), + T.bool(False), + ) + + # This function should be skipped. + mod = tvm.IRModule({"main": before}) + with Target("llvm"): + mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], before) + + +if __name__ == "__main__": + tvm.testing.main()
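
# --- Usage sketch (not part of the patch) ---
# A minimal sketch of how the new CPU GEMV rule introduced by this patch is applied,
# mirroring the tests above. The `gemv` PrimFunc below is a hypothetical example
# workload (not taken from the diff); only dl.cpu.GEMV, dl.ApplyDefaultSchedule, and
# the "llvm" target usage come from the patch itself.

import tvm
from tvm import dlight as dl
from tvm.script import tir as T
from tvm.target import Target


@T.prim_func
def gemv(
    x: T.Buffer((4096,), "float32"),
    w: T.Buffer((4096, 4096), "float32"),
    o: T.Buffer((4096,), "float32"),
):
    T.func_attr({"tir.noalias": T.bool(True)})
    # o[i] = sum_k w[i, k] * x[k], the [S, R] = [S, R] * [R] pattern handled by the rule
    for i, k in T.grid(4096, 4096):
        with T.block("gemv"):
            vi, vk = T.axis.remap("SR", [i, k])
            with T.init():
                o[vi] = T.float32(0)
            o[vi] = o[vi] + w[vi, vk] * x[vk]


mod = tvm.IRModule({"main": gemv})
with Target("llvm"):
    # ApplyDefaultSchedule tries each rule in order; here only the CPU GEMV rule is given.
    mod = dl.ApplyDefaultSchedule(dl.cpu.GEMV())(mod)
mod.show()  # the scheduled function carries the "tir.is_scheduled" attribute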