diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index 735d5ba6c0a1..3adb1bf5c204 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -455,12 +455,13 @@ inline Var Handle(runtime::DataType dtype = runtime::DataType::Void(), return is_size_var ? tvm::tir::SizeVar("", type_annotation) : tvm::tir::Var("", type_annotation); } -#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ - inline PrimExpr FuncName(Optional expr = NullOpt, bool is_size_var = false) { \ - DataType dtype = DType; \ - return expr.defined() \ - ? tvm::cast(dtype, expr.value()) \ - : (is_size_var ? tvm::tir::SizeVar("", dtype) : tvm::tir::Var("", dtype)); \ +#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName, DType) \ + inline PrimExpr FuncName(Optional expr = NullOpt, bool is_size_var = false, \ + int64_t min_value = 0) { \ + DataType dtype = DType; \ + return expr.defined() ? tvm::cast(dtype, expr.value()) \ + : (is_size_var ? tvm::tir::SizeVar("", dtype, min_value) \ + : tvm::tir::Var("", dtype)); \ } #define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \ diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 6c2c6dd5fc86..67e291acf161 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -142,6 +142,12 @@ class Var : public PrimExpr { */ class SizeVarNode : public VarNode { public: + int64_t min_value; + void VisitAttrs(tvm::AttrVisitor* v) { + VarNode::VisitAttrs(v); + v->Visit("min_value", &min_value); + } + static constexpr const char* _type_key = "tir.SizeVar"; TVM_DECLARE_FINAL_OBJECT_INFO(SizeVarNode, VarNode); }; @@ -157,14 +163,15 @@ class SizeVar : public Var { * \param span The location of this object in the source code. */ TVM_DLL explicit SizeVar(String name_hint = "s", DataType t = DataType::Int(32), - Span span = Span()); + int64_t min_value = 0, Span span = Span()); /*! * \brief Constructor which provides a more detailed type annotation. * \param name_hint variable name. * \param type_annotation The type annotation. * \param span The location of this object in the source code. */ - TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, Span span = Span()); + TVM_DLL explicit SizeVar(String name_hint, Type type_annotation, int64_t min_value = 0, + Span span = Span()); /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index b2bad2ec0646..1e3638b58a70 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -195,7 +195,7 @@ def canonical_simplify(self, expr): """ return self._canonical_simplify(expr) - def int_set(self, expr, dom_map): + def int_set(self, expr, dom_map=None): """Compute a symbolic IntSet that covers expr for all values in dom_map. Parameters @@ -203,8 +203,9 @@ def int_set(self, expr, dom_map): expr : PrimExpr The expression. - dom_map : Dict[Var, tvm.arith.IntSet] - The domain for variables to be relaxed. + dom_map : Optional[Dict[Var, tvm.arith.IntSet]] + The domain for variables to be relaxed. If None, use the domain map defined by bound + variables. Returns ------- diff --git a/python/tvm/dlight/gpu/__init__.py b/python/tvm/dlight/gpu/__init__.py index a09d6b8704d8..b70e5a371c8c 100644 --- a/python/tvm/dlight/gpu/__init__.py +++ b/python/tvm/dlight/gpu/__init__.py @@ -18,11 +18,14 @@ GPU-generic schedule rules. 
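(Side note on the `Analyzer.int_set` hunk above: with `dom_map` now optional, the analyzer falls back to ranges registered via `bind`. A minimal sketch of that behaviour, assuming the post-change API; the variable names are illustrative only and the resulting IntSet is described only roughly.)

import tvm
from tvm import arith, tir

ana = arith.Analyzer()
x = tir.Var("x", "int32")
ana.bind(x, tvm.ir.Range(0, 10))   # register x in [0, 10)
s = ana.int_set(x + 1)             # dom_map omitted: the bound range of x is used
# s roughly covers the interval [1, 10]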
For CUDA/ROCm/Vulkan/Metal-specific rules, use `tvm.dlight.cuda/rocm/vulkan/metal` instead """ -from .gemv import GEMV from .fallback import Fallback -from .matmul import Matmul, MatmulWMMATensorization, MatmulMMATensorization +from .gemv import GEMV +from .general_reduction import GeneralReduction +from .matmul import ( + Matmul, + MatmulTensorizationMMA, + MatmulTensorizationWMMA, + MatmulTensorizationLegacy, +) from .reduction import Reduction from .transpose import Transpose -from .general_reduction import GeneralReduction -from .element_wise import ElementWise -from .rmsnorm import RMSNorm diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index e3fdad7d16a0..f3eaa58ddd22 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -17,1342 +17,25 @@ # pylint: disable=missing-docstring, invalid-name """A GEMM schedule rule for GPU operators.""" from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple -from functools import reduce +from typing import Optional from tvm import tir from tvm.target import Target -from tvm.tir import PrimExpr -from tvm.tir import IndexMap +from tvm.tir.stmt import ForKind + +from ..base import analysis from .base import GPUScheduleRule -from ..base import ScheduleRule, analysis -from ..base.roller.rasterization import NoRasterization -from ..base.analysis import ( - IterKind, - IterTrait, - detect_iter_traits, - get_reduction_blocks, - get_index_map, +from . import utils +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_producers, get_in_out_dtypes, - normalize_to_matmul, + get_index_map, + get_reduction_blocks, ) - - -def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): - result = [] - for producer in sch.get_producers(block): - result.append(producer) - result.extend(_collect_producers(sch, producer)) - return result - - -def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): - result = [] - for consumer in sch.get_consumers(block): - result.append(consumer) - result.extend(_collect_consumers(sch, consumer)) - return result - - -def auto_inline_producers( - sch: tir.Schedule, - block: tir.schedule.BlockRV, -): - while True: - inlined_cnt = 0 - producers = _collect_producers(sch, block) - for producer in producers: - try: - sch.compute_inline(producer) - inlined_cnt += 1 - except: # pylint: disable=bare-except - continue - if inlined_cnt == 0: - return - - -def auto_inline_consumers( - sch: tir.Schedule, - block: tir.schedule.BlockRV, -): - while True: - inlined_cnt = 0 - consumers = _collect_consumers(sch, block) - for consumer in consumers: - try: - sch.compute_inline(consumer) - inlined_cnt += 1 - except: # pylint: disable=bare-except - continue - for consumer in consumers: - try: - sch.reverse_compute_inline(consumer) - inlined_cnt += 1 - except: # pylint: disable=bare-except - continue - if inlined_cnt == 0: - return - - -def auto_inline_consumer_chain( - sch: tir.Schedule, - block: tir.schedule.BlockRV, -): - auto_inline_consumers(sch, block) - remaining_consumers = sch.get_consumers(block) - - if len(remaining_consumers) != 0: - # Some blocks have failed to be inlined to the producer cache-write stage. - # This could be due to another producer block that has not been scheduled. - for c in remaining_consumers: - for p in sch.get_producers(c): - if sch.get(p) != sch.get(block): - auto_inline_producers(sch, p) - sch.compute_inline(p) - - # Try inlining into the cache-write stage again, this time it should succeed. 
- auto_inline_consumers(sch, block) - - msg = "There are some consumers of the cache-write stage that are not properly inlined." - assert len(sch.get_consumers(block)) == 0, msg - - -def check_sm_version(arch: str) -> int: - sm_version = arch.replace("sm_", "") - return int(sm_version) if sm_version.isdigit() else -1 - - -class MatmulAdvancedTensorizationMMA(ScheduleRule): - """ - The advanced schedule rule for float16 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. - """ - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Optional[tir.Schedule]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - LDMATRIX_f16_A_INTRIN, - LDMATRIX_f16_B_INTRIN, - LDMATRIX_f16_B_TRANS_INTRIN, - MMA_f16f16f16_INTRIN, - MMA_f16f16f16_TRANS_B_INTRIN, - MMA_fill_16x16_f16_INTRIN, - MMA_store_16x16_f16_shared_INTRIN, - shared_16x16_to_mma_32x8_layout, - ldmatrix_32x8_to_shared_16x16_layout, - ldmatrix_trans_32x8_to_shared_16x16_layout, - ) - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - def make_iter_fusion_index_map( - traits: List[IterTrait], - kind_order: List[IterKind], - keep_last_dims: int = 0, - ) -> tir.IndexMap: - fused_iters: Dict[IterKind, PrimExpr] = {} - keep_iters: List[tir.Var] = [] - input_iters: List[tir.Var] = [] - for i, trait in enumerate(traits): - v_i = tir.Var(f"i{i}", trait.extent.dtype) - input_iters.append(v_i) - if trait.kind == IterKind.kIter_T: - continue - if trait.kind not in kind_order: - raise ValueError(f"Unknown iter kind {trait.kind}") - if i + keep_last_dims < len(traits): - if trait.kind in fused_iters: - fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i - else: - fused_iters[trait.kind] = v_i - else: - keep_iters.append(v_i) - final_indices: List[tir.PrimExpr] = [ - fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order - ] - final_indices.extend(keep_iters) - - return tir.IndexMap(input_iters, final_indices, None) - - def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]: - """Get index maps for the block - - Parameters - ---------- - block : tir.Block - The block to be analyzed - - Returns - ------- - index_maps : Optional[Tuple[tir.IndexMap]] - The index maps for the block, or None if the block is not a gemm-liked kernel - """ - traits = detect_iter_traits(block) - if traits is None: - return None - A_traits, B_traits, C_traits, block_traits = traits - - A_index_map = make_iter_fusion_index_map( - A_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_K], keep_last_dims=2 - ) - B_index_map = make_iter_fusion_index_map( - B_traits, [IterKind.kIter_S, IterKind.kIter_J, IterKind.kIter_K], keep_last_dims=2 - ) - C_index_map = make_iter_fusion_index_map( - C_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J], keep_last_dims=3 - ) - matmul_index_map = make_iter_fusion_index_map( - block_traits, - [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K], - keep_last_dims=3, - ) - - return ( - matmul_index_map, - A_index_map, - B_index_map, - C_index_map, - ) - - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - if func.attrs 
is not None and "transform_kind" in func.attrs.keys(): - if func.attrs["transform_kind"] >= 2: - - def ldmatrix_permutation_16x16_32x8_16x16(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 8 - local_id = kernel_j % 8 - return ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id) - - def ldmatrix_trans_permutation_16x16_32x8_16x16(kernel_i, kernel_j): - thread_id = kernel_i * 2 + kernel_j // 8 - local_id = kernel_j % 8 - return ldmatrix_trans_32x8_to_shared_16x16_layout(thread_id, local_id) - - index_map = IndexMap.from_func(ldmatrix_permutation_16x16_32x8_16x16) - inversed_index_map = index_map.inverse([16, 16]) - - def A_permutation_inverse(i, j, kernel_i, kernel_j): - return (i, j, *inversed_index_map.map_indices([kernel_i, kernel_j])) - - sch.transform_layout( - sch.get_block("A_reindex_reindex"), ("read", 0), A_permutation_inverse - ) - index_map = IndexMap.from_func(ldmatrix_trans_permutation_16x16_32x8_16x16) - inversed_index_map = index_map.inverse([16, 16]) - - def B_permutation_inverse(i, j, kernel_i, kernel_j): - return (i, j, *inversed_index_map.map_indices([kernel_i, kernel_j])) - - sch.transform_layout( - sch.get_block("B_reindex_reindex"), ("read", 0), B_permutation_inverse - ) - - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 - - warp_size = 32 - vector_size = 4 - - i_factors, j_factors, k_factors = ( - [None, 1, 4, 2], - [1, None, 4, 2], - [None, 2], - ) - - num_ty = i_factors[2] * j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) - sch.transform_block_layout(main_block, matmul_index_map) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - x_pad_factor, - y_pad_factor, - k_pad_factor, - micro_size_x, - micro_size_y, - micro_size_k, - ], - ) - - # Step 3. 
Schedule matmul to use tensor core - block = main_block - - batch, i, j, k, i_inner, j_inner, k_inner = sch.get_loops(block) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, "shared") - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) - - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) - sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) - sch.annotate(block_read, "double_buffer_scope", 0) - return block_read - - a_g2s = fetch_to_shared(block_outer, 0, 4) - b_g2s = fetch_to_shared(block_outer, 1, 4) - - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared") - sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, thread_idy) - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - - def index_map(*args): - other_args = args[:-2] - inner_i, inner_j = args[-2:] - return ( - *other_args, - *shared_16x16_to_mma_32x8_layout(inner_i, inner_j), - ) - - sch.transform_layout(A_mat, ("write", 0), index_map) - sch.transform_layout(B_mat, ("write", 0), index_map) - sch.transform_layout(store, ("read", 0), index_map) - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - intrin_group = { - "load_a": LDMATRIX_f16_A_INTRIN, - "load_b": LDMATRIX_f16_B_TRANS_INTRIN, - "init": MMA_fill_16x16_f16_INTRIN, - "compute": MMA_f16f16f16_TRANS_B_INTRIN, - "store": MMA_store_16x16_f16_shared_INTRIN, - } - - try: - i, j, ii, jj = sch.get_loops(A_mat)[-4:] - sch.tensorize(ii, intrin_group["load_a"]) - - i, j, ii, jj = sch.get_loops(B_mat)[-4:] - sch.tensorize(ii, intrin_group["load_b"]) - except: # pylint: disable=bare-except - return None - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - tensorize_init_store_compute() - - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-4:]) - _, f1, f2 = sch.split(fused, 
factors=[None, warp_size, vector_size]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - - return sch - - -class MatmulMMATensorization(GPUScheduleRule): - """ - The schedule rule for float16 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. - """ - - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config, - ) -> Optional[tir.Schedule]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group, - shared_16x16_to_mma_32x8_layout, - ) - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - intrin_info = config.intrin_info - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - chunk = config.rstep[0] - - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], - ) - - num_ty = i_factors[2] * j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. 
Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - - # plan rasteration - if ( - not isinstance(config.rasterization_plan, NoRasterization) - and sch.get(batch).extent.value == 1 - ): - device_func, invoke_func = config.rasterization_plan.get_code() - factor = config.rasterization_plan.panel_width_ - - # TODO(lei): this is a trick for rasterization implementation - # wait for https://github.com/apache/tvm/pull/16113 to be merged - # require a solution for general block rasterization - factor = 8 # should be divisible by block_idy - if sch.get(block_idy).extent.value % factor == 0: - block_k, block_idy = sch.split(block_idy, factors=[None, factor]) - sch.bind(block_k, "blockIdx.z") - else: - sch.bind(batch, "blockIdx.z") - - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim, vec_len, dtype="float16"): - block_read = sch.cache_read(block, idx, "shared.dyn") - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vec_len]) - - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - return block_read - - a_g2s = fetch_to_shared( - block_outer, - 0, - 2, - vec_len=list(config.vectorize.values())[0], - dtype=intrin_info.in_dtype, - ) - b_g2s = fetch_to_shared( - block_outer, - 1, - 2, - vec_len=list(config.vectorize.values())[1], - dtype=intrin_info.in_dtype, - ) - - sch.annotate(a_g2s, ann_key="permuted_layout", ann_val="g2s_A") - sch.annotate(b_g2s, ann_key="permuted_layout", ann_val="g2s_B") - - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "warp") - B_mat = sch.cache_read(block_outer, 1, "warp") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") - sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) - - store = sch.cache_write(block_outer, 0, "warp") - sch.reverse_compute_at(store, thread_idy) - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - intrin_group = get_mma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype=intrin_info.in_dtype, - 
out_dtype=intrin_info.out_dtype, - trans_a=False, - trans_b=intrin_info.trans_b, - ) - - def index_map_A(b, i, j): - return ( - b, - i // 16, - j // 16, - *shared_16x16_to_mma_32x8_layout(i % 16, j % 16), - ) - - def index_map_B(b, i, j): - return ( - b, - i // 16, - j // 16, - *shared_16x16_to_mma_32x8_layout(i % 16, j % 16), - ) - - def index_map_C(b, i, j): - return ( - b, - i // 16, - j // 16, - *shared_16x16_to_mma_32x8_layout(i % 16, j % 16), - ) - - sch.transform_layout(A_mat, ("write", 0), index_map_A) - sch.transform_layout(B_mat, ("write", 0), index_map_A) - sch.transform_layout(store, ("read", 0), index_map_C) - - try: - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - ba = sch.blockize(i1) - sch.annotate(ba, ann_key="permuted_layout", ann_val="s2l_A") - sch.tensorize(ba, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, micro_size_x]) - j0, j1 = sch.split(j, factors=[None, micro_size_y]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - bb = sch.blockize(i1) - sch.annotate(bb, ann_key="permuted_layout", ann_val="s2l_B") - sch.tensorize(bb, intrin_group["load_b"]) - except: # pylint: disable=bare-except - return None - - # Try to tensorize the init - tensorize_success: bool = False - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - try: - tensorize_init_store_compute() - tensorize_success = True - except: # pylint: disable=bare-except - return None - - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) - _, f1, f2 = sch.split( - fused, factors=[None, warp_size, max(list(config.vectorize.values()))] - ) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - - if stage > 1: - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - - return sch if tensorize_success else None - - -class MatmulWMMATensorization(GPUScheduleRule): - """ - The schedule rule for float16 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. - """ - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Optional[tir.Schedule]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group, - ) - - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): - return None - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - - # Start Schedule - # Step 0. Get schedule config. 
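(The layout shuffles in these tensorization rules — `index_map_A/B/C`, the ldmatrix permutations and their inverses — are all expressed through `tir.IndexMap`. A toy sketch of that API using a plain 2-D transpose map, deliberately unrelated to the specific MMA/ldmatrix layouts above; the printed results are approximate since `map_indices` returns PrimExprs.)

from tvm import tir

imap = tir.IndexMap.from_func(lambda i, j: (j, i))  # toy transpose mapping
inv = imap.inverse([16, 16])                        # inverse over a 16x16 domain
print(imap.map_indices([3, 5]))                     # roughly [5, 3]
print(inv.map_indices([5, 3]))                      # roughly [3, 5]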
- # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 - - warp_size = 32 - vector_size = 4 - - i_factors, j_factors, k_factors = ( - [None, 1, 4, 2], - [1, None, 4, 2], - [None, 4], - ) - - num_ty = i_factors[2] * j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) - sch.transform_block_layout(main_block, matmul_index_map) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) - sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) - sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) - sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, "shared.dyn") - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) - - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) - sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) - sch.annotate(block_read, "double_buffer_scope", 0) - return block_read - - a_g2s = fetch_to_shared(block_outer, 0, 2) - b_g2s = fetch_to_shared(block_outer, 1, 2) - - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") - B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") - sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) - - store = sch.cache_write(block_outer, 0, "wmma.accumulator") - sch.reverse_compute_at(store, 
thread_idy) - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - intrin_group = get_wmma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype="float16", - out_dtype="float32", - trans_b=True, - ) - - try: - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_b"]) - except: # pylint: disable=bare-except - return None - - # Try to tensorize the init, store and compute block with f16 or f32 intrinsics - tensorize_success: bool = False - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - try: - tensorize_init_store_compute() - tensorize_success = True - except: # pylint: disable=bare-except - intrin_group = get_wmma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype="float16", - out_dtype="float16", - trans_b=True, - ) - - if not tensorize_success: - try: - tensorize_init_store_compute() - tensorize_success = True - except: # pylint: disable=bare-except - return None - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) - _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - - return sch if tensorize_success else None - - def apply_config( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - config, - ) -> Optional[tir.Schedule]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group, - ) - - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - - # Start Schedule - # Step 0. Get schedule config. 
- # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - intrin_info = config.intrin_info - warp_row_tiles = config.warp[0] - warp_col_tiles = config.warp[1] - block_row_warps = config.block[0] // warp_row_tiles - block_col_warps = config.block[1] // warp_col_tiles - stage = config.pipeline_stage - use_async = config.use_async - chunk = config.rstep[0] - - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 - - warp_size = 32 - - i_factors, j_factors, k_factors = ( - [None, 1, block_row_warps, warp_row_tiles // micro_size_x], - [1, None, block_col_warps, warp_col_tiles // micro_size_y], - [None, chunk // micro_size_k], - ) - - num_ty = i_factors[2] * j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K]/B[S, K, J] - if not (func.attrs is not None and "dlight.tensorcore_prenormlized" in func.attrs.keys()): - sch = normalize_to_matmul(sch, main_block, ["a", "a", "a"]) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - # plan rasteration - if ( - not isinstance(config.rasterization_plan, NoRasterization) - and sch.get(batch).extent.value == 1 - ): - device_func, invoke_func = config.rasterization_plan.get_code() - factor = config.rasterization_plan.panel_width_ - - # TODO(lei): this is a trick for rasterization implementation - # wait for https://github.com/apache/tvm/pull/16113 to be merged - # require a solution for general block rasterization - factor = 8 # should be divisible by block_idy - if sch.get(block_idy).extent.value % factor == 0: - block_k, block_idy = sch.split(block_idy, factors=[None, factor]) - sch.bind(block_k, "blockIdx.z") - else: - sch.bind(batch, "blockIdx.z") - - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim, vec_len, dtype="float16"): - block_read = sch.cache_read(block, idx, "shared.dyn") - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vec_len]) - - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - offset: int = 0 - if dtype == "float16": - offset = 8 - elif dtype == "int8": - offset = 16 - # todo(lei): the pad value should be varied according to the data type - sch.storage_align(block_read, 0, axis=-2, factor=16, offset=offset) - return block_read - - a_g2s = fetch_to_shared( - block_outer, - 0, - 2, - vec_len=list(config.vectorize.values())[0], - 
dtype=intrin_info.in_dtype, - ) - b_g2s = fetch_to_shared( - block_outer, - 1, - 2, - vec_len=list(config.vectorize.values())[1], - dtype=intrin_info.in_dtype, - ) - - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") - B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") - sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) - - store = sch.cache_write(block_outer, 0, "wmma.accumulator") - sch.reverse_compute_at(store, thread_idy) - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - intrin_group = get_wmma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype=intrin_info.in_dtype, - out_dtype=intrin_info.out_dtype, - trans_b=intrin_info.trans_b, - ) - - try: - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_b"]) - except: # pylint: disable=bare-except - return None - - # Try to tensorize the init, store and compute block with f16 or f32 intrinsics - tensorize_success: bool = False - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - try: - tensorize_init_store_compute() - tensorize_success = True - except: # pylint: disable=bare-except - return None - - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) - _, f1, f2 = sch.split( - fused, factors=[None, warp_size, max(list(config.vectorize.values()))] - ) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - - if stage > 1: - sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, stage - 1]) - sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) - if use_async: - sch.annotate(k0, "software_pipeline_async_stages", [0]) - - return sch if tensorize_success else None - - -class MatmulInt8Tensorization(GPUScheduleRule): - """ - The schedule rule for int8 tensor core matmul computation. - func with attr 'dlight.do_not_tensorize' will not be tensorized. 
- """ - - def apply( # pylint: disable=too-many-locals,missing-docstring - self, - func: tir.PrimFunc, - target: Target, - _: bool, - ) -> Optional[tir.Schedule]: - from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group, - ) - - if not isinstance(func, tir.PrimFunc) or not self.is_target_available(target): - return None - sch = tir.Schedule(func) - root_block = analysis.get_root_block(sch) - blocks = sch.get_child_blocks(root_block) - - if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): - return None - - reduction_blocks = get_reduction_blocks(sch, blocks) - if reduction_blocks is None: - return None - - main_block = reduction_blocks[0] - block_stmt = sch.get(main_block) - index_maps = get_index_map(block_stmt) - if index_maps is None: - return None - matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps - - # Start Schedule - # Step 0. Get schedule config. - # NOTE: we can analyze the config by the hardware spec in the future - - # tensor core intrinsic size - micro_size_x = 16 - micro_size_y = 16 - micro_size_k = 16 - - warp_size = 32 - vector_size = 4 - - i_factors, j_factors, k_factors = ( - [None, 1, 4, 2], - [1, None, 4, 2], - [None, 1], - ) - - num_ty = i_factors[2] * j_factors[2] - x_pad_factor = i_factors[2] * i_factors[3] - y_pad_factor = j_factors[2] * j_factors[3] - k_pad_factor = k_factors[1] - - # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] - block = sch.reindex(main_block, ("read", 0)) - sch.transform_layout(block, ("write", 0), a_index_map) - block = sch.reindex(main_block, ("read", 1)) - sch.transform_layout(block, ("write", 0), b_index_map) - block = sch.reindex(main_block, ("write", 0)) - sch.transform_layout(block, ("read", 0), c_index_map) - sch.transform_block_layout(main_block, matmul_index_map) - - # Step 2. Padding for dynamic shape kernels - sch.pad_einsum( - main_block, - [ - 1, - micro_size_x * x_pad_factor, - micro_size_y * y_pad_factor, - micro_size_k * k_pad_factor, - ], - ) - - # Step 3. 
Schedule matmul to use tensor core - block = main_block - - batch, i, j, k = sch.get_loops(block) - - # inner loops for tensor core computation - i, i_inner = sch.split(i, factors=[None, micro_size_x]) - j, j_inner = sch.split(j, factors=[None, micro_size_y]) - k, k_inner = sch.split(k, factors=[None, micro_size_k]) - - sch.reorder(i, j, k, i_inner, j_inner, k_inner) - - block_inner = block - block_outer = sch.blockize(i_inner) - - i0, i1, i2, i3 = sch.split(i, factors=i_factors) - j0, j1, j2, j3 = sch.split(j, factors=j_factors) - k0, k1 = sch.split(k, k_factors) - sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) - sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) - sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) - sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) - - sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) - - block_idx = sch.fuse(i0, j0) - block_idy = sch.fuse(i1, j1) - thread_idy = sch.fuse(j2, i2) - sch.bind(batch, "blockIdx.z") - sch.bind(block_idx, "blockIdx.x") - sch.bind(block_idy, "blockIdx.y") - sch.bind(thread_idy, "threadIdx.y") - - def fetch_to_shared(block, idx, ndim): - block_read = sch.cache_read(block, idx, "shared.dyn") - sch.compute_at(block_read, k0) - fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) - - _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) - - sch.bind(f_2, "threadIdx.x") - sch.bind(f_1, "threadIdx.y") - sch.vectorize(f_3) - - sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16) - sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) - sch.annotate(block_read, "double_buffer_scope", 0) - return block_read - - a_g2s = fetch_to_shared(block_outer, 0, 2) - b_g2s = fetch_to_shared(block_outer, 1, 2) - - auto_inline_producers(sch, a_g2s) - auto_inline_producers(sch, b_g2s) - - # create read cache to load matrix from shared memory to wmma fragments - A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") - B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") - sch.compute_at(A_mat, k1) - sch.compute_at(B_mat, k1) - - # create write cache to store matrix from wmma fragments to shared memory and global memory - accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") - sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) - - store = sch.cache_write(block_outer, 0, "wmma.accumulator") - sch.reverse_compute_at(store, thread_idy) - sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) - - # split the store loop to match hardware intrinsic pattern - i, j = sch.get_loops(store)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - - block_init_c = sch.decompose_reduction(block_outer, k0) - block_init_c_inner = sch.get_child_blocks(block_init_c)[0] - - # Tensorization by hardware intrinsics - intrin_group = get_wmma_intrin_group( - load_scope="shared.dyn", - store_scope="shared.dyn", - in_dtype="int8", - out_dtype="int32", - trans_b=True, - ) - - try: - i, j = sch.get_loops(A_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_a"]) - - i, j = sch.get_loops(B_mat)[-2:] - i0, i1 = sch.split(i, factors=[None, 16]) - j0, j1 = sch.split(j, factors=[None, 16]) - sch.reorder(i0, j0, i1, j1) - sch.unroll(i0) - sch.unroll(j0) - sch.tensorize(i1, intrin_group["load_b"]) - except: # pylint: 
disable=bare-except - return None - - def tensorize_init_store_compute(): - sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) - sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) - sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) - - try: - tensorize_init_store_compute() - except: # pylint: disable=bare-except - return None - - auto_inline_consumer_chain(sch, accumulator_shared_to_global) - - fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) - _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) - sch.bind(f1, "threadIdx.x") - sch.vectorize(f2) - - return sch - +from .matmul_mma import MatmulTensorizationMMA +from .matmul_wmma import MatmulInt8Tensorization, MatmulTensorizationWMMA, MatmulTensorizationLegacy +from functools import reduce class Matmul(GPUScheduleRule): """The schedule rule for matmul-like computation""" @@ -1434,8 +117,8 @@ def apply( # pylint: disable=too-many-locals,missing-docstring # If any value of I, J, K is fixed and less than this threshold, # tensorization rule will not be applied. minimal_tensorize_threshold = 64 - - if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: + block_stmt = sch.get(main_block) + if target.kind.name == "cuda" and utils.get_sm_version(target) >= 70: apply_tensorization: bool = True # the batch dimension is not taken into consideration. for item_var in block_stmt.iter_vars[1:]: @@ -1448,8 +131,12 @@ def apply( # pylint: disable=too-many-locals,missing-docstring in_dtype, out_dtype = get_in_out_dtypes(block_stmt) if in_dtype == "int8" and out_dtype == "int32": tensorize_sch = MatmulInt8Tensorization().apply(func, target, _) + elif utils.get_sm_version(target) >= 80: + # For A100(sm_80) or more advanced gpu, use MMA tensorization. + tensorize_sch = MatmulTensorizationMMA().apply(func, target, _) else: - tensorize_sch = MatmulWMMATensorization().apply(func, target, _) + # For other GPUs, use WMMA tensorization. + tensorize_sch = MatmulTensorizationWMMA().apply(func, target, _) if tensorize_sch is not None: return tensorize_sch @@ -1528,8 +215,31 @@ def _cooperative_fetch(index, vec_len): auto_inline_producers(sch, main_block) auto_inline_consumer_chain(sch, l2g) - sch.decompose_reduction(main_block, ko) + + # Step 4. Check if there are unbound blocks. Execute fallback scheduling to them. + def is_scheduled(block: tir.schedule.BlockRV) -> bool: + loops = sch.get_loops(block) + loop_kinds = {sch.get(loop).kind for loop in loops} + return loop_kinds != {ForKind.SERIAL} + + blocks = sch.get_child_blocks(root_block) + max_threads_per_block = utils.max_threads_per_block(target) + for block in blocks: + if is_scheduled(block): + continue + # no axis of the block is bound to thread or block + s_loops = sch.get_loops(block) + bx, tx = sch.split( + sch.fuse(*s_loops), + factors=[ + None, + 256, + ], + ) + sch.bind(bx, "blockIdx.x") + sch.bind(tx, "threadIdx.x") + return sch def apply_config( # pylint: disable=too-many-locals,missing-docstring diff --git a/python/tvm/dlight/gpu/matmul_analysis.py b/python/tvm/dlight/gpu/matmul_analysis.py new file mode 100644 index 000000000000..14cce6e7c663 --- /dev/null +++ b/python/tvm/dlight/gpu/matmul_analysis.py @@ -0,0 +1,374 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from dataclasses import dataclass +from enum import Enum +from typing import Dict, List, Optional, Set, Tuple + +from tvm import tir +from tvm.ir import Range +from tvm.tir import IterVar, PrimExpr, Var +from tvm.tir.analysis import undefined_vars +from tvm.tir.schedule.schedule import BlockRV + + +def _is_one(x: PrimExpr) -> bool: + return isinstance(x, tir.IntImm) and x.value == 1 + + +def _collect_producers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for producer in sch.get_producers(block): + result.append(producer) + result.extend(_collect_producers(sch, producer)) + return result + + +def _collect_consumers(sch: tir.Schedule, block: tir.schedule.BlockRV): + result = [] + for consumer in sch.get_consumers(block): + result.append(consumer) + result.extend(_collect_consumers(sch, consumer)) + return result + + +def auto_inline_producers( + sch: tir.Schedule, + block: tir.schedule.BlockRV, + skip_blocks: Optional[List[tir.schedule.BlockRV]] = None, +): + skip_blocks = skip_blocks or [] + while True: + inlined_cnt = 0 + producers = _collect_producers(sch, block) + for producer in producers: + if any(sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks): + continue + try: + sch.compute_inline(producer) + inlined_cnt += 1 + except: # pylint: disable=bare-except + continue + if inlined_cnt == 0: + return + + +def auto_inline_consumers( + sch: tir.Schedule, + block: tir.schedule.BlockRV, +): + while True: + inlined_cnt = 0 + consumers = _collect_consumers(sch, block) + for consumer in consumers: + try: + sch.compute_inline(consumer) + inlined_cnt += 1 + except: # pylint: disable=bare-except + continue + for consumer in consumers: + try: + sch.reverse_compute_inline(consumer) + inlined_cnt += 1 + except: # pylint: disable=bare-except + continue + if inlined_cnt == 0: + return + + +def auto_inline_consumer_chain( + sch: tir.Schedule, + block: tir.schedule.BlockRV, +): + auto_inline_consumers(sch, block) + remaining_consumers = sch.get_consumers(block) + + if len(remaining_consumers) != 0: + # Some blocks have failed to be inlined to the producer cache-write stage. + # This could be due to another producer block that has not been scheduled. + for c in remaining_consumers: + for p in sch.get_producers(c): + if sch.get(p) != sch.get(block): + sch.compute_inline(p) + + # Try inlining into the cache-write stage again, this time it should succeed. + auto_inline_consumers(sch, block) + + +class IterKind(Enum): + """Iter kinds for GEMM-liked programs. + We can simplify the computation to C[S, I, J] += A[S, I, K] * B[S, J, K], + where `I, J, K` are fundamental axes for gemm and `S` represents all + other spatial axes (e.g. batches) + kIter_S: spatial axes + kIter_I: I axes + kIter_J: J axes + kIter_K: K axes + kIter_T: trivial axes (i.e. 
with extent 1) + """ + + kIter_S = 0 + kIter_I = 1 + kIter_J = 2 + kIter_K = 3 + kIter_T = 4 + + +@dataclass +class IterTrait: + kind: IterKind + extent: PrimExpr + + +def make_iter_fusion_index_map( + traits: List[IterTrait], + kind_order: List[IterKind], +) -> tir.IndexMap: + fused_iters: Dict[IterKind, PrimExpr] = {} + input_iters: List[tir.Var] = [] + for i, trait in enumerate(traits): + v_i = tir.Var(f"i{i}", trait.extent.dtype) + input_iters.append(v_i) + if trait.kind == IterKind.kIter_T: + continue + if trait.kind not in kind_order: + raise ValueError(f"Unknown iter kind {trait.kind}") + if trait.kind in fused_iters: + fused_iters[trait.kind] = fused_iters[trait.kind] * trait.extent + v_i + else: + fused_iters[trait.kind] = v_i + + final_indices: List[tir.PrimExpr] = [ + fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order + ] + + return tir.IndexMap(input_iters, final_indices, None) + + +def detect_iter_traits(block: tir.Block) -> Optional[Tuple[List[IterTrait]]]: + """Detect iter traits based on the pattern C[S, I, J] += A[S, I, K] * B[S, J, K] + + Parameters + ---------- + block : tir.Block + The block to be analyzed + + Returns + ------- + traits : Optional[Tuple[List[IterTrait]]] + The detected iter traits for axes in A, B and C. None if the block + does not match the pattern. + + """ + + if len(block.reads) != 2 or len(block.writes) != 1: + return None + + def get_access_axes(region: List[Range]) -> Set[Var]: + axes: Set[Var] = set() + for r in region: + if not r.extent: + raise ValueError("Expect elemwise block access") + axes = axes.union(set(undefined_vars(r.min))) + return axes + + try: + A_axes = get_access_axes(block.reads[0].region) + B_axes = get_access_axes(block.reads[1].region) + C_axes = get_access_axes(block.writes[0].region) + except ValueError: + return None + + traits: Dict[Var, IterTrait] = {} + for iter_var in block.iter_vars: + var = iter_var.var + kind: IterKind + if _is_one(iter_var.dom.extent): + kind = IterKind.kIter_T + elif iter_var.iter_type == iter_var.DataPar: + if var in A_axes and var in B_axes and var in C_axes: + kind = IterKind.kIter_S + elif var in A_axes and var in C_axes: + kind = IterKind.kIter_I + elif var in B_axes and var in C_axes: + kind = IterKind.kIter_J + else: + return None + elif iter_var.iter_type == tir.IterVar.CommReduce: + if var in A_axes and var in B_axes and var not in C_axes: + kind = IterKind.kIter_K + else: + return None + else: + return None + traits[var] = IterTrait(kind, iter_var.dom.extent) + + # A Gemm-kernel requires have I, J and K axes + gemm_traits = {IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K} + if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits: + return None + + A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes] + B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes] + C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes] + block_traits = [traits[i.var] for i in block.iter_vars] + return A_traits, B_traits, C_traits, block_traits + + +def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]: + """Get index maps for the block + + Parameters + ---------- + block : tir.Block + The block to be analyzed + + Returns + ------- + index_maps : Optional[Tuple[tir.IndexMap]] + The index maps for the block, or None if the block is not a gemm-liked kernel + """ + traits = detect_iter_traits(block) + if traits is None: + 
return None + A_traits, B_traits, C_traits, block_traits = traits + + A_index_map = make_iter_fusion_index_map( + A_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_K] + ) + B_index_map = make_iter_fusion_index_map( + B_traits, [IterKind.kIter_S, IterKind.kIter_J, IterKind.kIter_K] + ) + C_index_map = make_iter_fusion_index_map( + C_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J] + ) + matmul_index_map = make_iter_fusion_index_map( + block_traits, [IterKind.kIter_S, IterKind.kIter_I, IterKind.kIter_J, IterKind.kIter_K] + ) + + return matmul_index_map, A_index_map, B_index_map, C_index_map + + +def get_reduction_blocks(sch, blocks) -> Optional[List[BlockRV]]: + # Get the main computation block + def is_reduction(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.CommReduce, IterVar.DataPar} + + def is_spatial(block: BlockRV) -> bool: + block_stmt = sch.get(block) + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + return iter_types == {IterVar.DataPar} + + # NOTE: We assume there is only one reduction block in the function + # all blocks are required to be spatial or reduction + if not all([is_reduction(block) or is_spatial(block) for block in blocks]): + return None + + # There is only one reduction block + reduction_blocks = [block for block in blocks if is_reduction(block)] + if len(reduction_blocks) != 1: + return None + + return reduction_blocks + + +def get_in_out_dtypes(block: tir.Block) -> Tuple[str]: + """ + Detect In/Out data types for the given block based on the analysis if read/write buffers. + """ + assert len(block.reads) > 0 and len(block.writes) > 0 + in_dtype = block.reads[0].buffer.dtype + out_dtype = block.writes[0].buffer.dtype + return (in_dtype, out_dtype) + + +def get_dequantize_block(sch, blocks) -> Optional[BlockRV]: + # check at least two input and one output + # at lease one input has uint dtype, and the output dtype is float + def is_dequantize(block: BlockRV) -> bool: + block_stmt = sch.get(block) + if len(block_stmt.reads) < 2: + return False + has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) + if not has_uint_input: + return False + if len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype): + return False + return True + + dequantize_blocks = [block for block in blocks if is_dequantize(block)] + return dequantize_blocks[0] if len(dequantize_blocks) == 1 else None + + +def is_identity_or_transpose_block(block_stmt: tir.Block) -> bool: + iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars} + if iter_types != {IterVar.DataPar}: + return False, False + if not isinstance(block_stmt.body, tir.BufferStore): + return False, False + if not isinstance(block_stmt.body.value, tir.BufferLoad): + return False, False + + def get_access_vars(region: List[Range]) -> List[Var]: + axes: List[Var] = [] + for r in region: + if not _is_one(r.extent): + return None + axes.extend(undefined_vars(r.min)) + # remove trivial axis + trivial_vars = set( + iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent) + ) + axes = [axis for axis in axes if axis not in trivial_vars] + # remove duplicate axis + axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] + return axes + + lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] + rhs_access_vars = 
get_access_vars(block_stmt.writes[0].region)[-2:] + is_identity = list(lhs_access_vars) == list(rhs_access_vars) + is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set( + rhs_access_vars + ) + return is_identity, is_transpose + + +def is_identity_block(block_stmt: tir.Block) -> bool: + return is_identity_or_transpose_block(block_stmt)[0] + + +def is_transpose_block(block_stmt: tir.Block) -> bool: + return is_identity_or_transpose_block(block_stmt)[1] + + +def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV]): + result_blocks = [] + for block in blocks: + if not is_transpose_block(sch.get(block)): + result_blocks.append(block) + continue + try: + sch.compute_inline(block) + except: + try: + sch.reverse_compute_inline(block) + except: + result_blocks.append(block) + return result_blocks diff --git a/python/tvm/dlight/gpu/matmul_mma.py b/python/tvm/dlight/gpu/matmul_mma.py new file mode 100644 index 000000000000..b980d28fcaf9 --- /dev/null +++ b/python/tvm/dlight/gpu/matmul_mma.py @@ -0,0 +1,323 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +from typing import Literal, Optional + +from tvm import tir +from tvm.target import Target + +from ..base import ScheduleRule, analysis +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_producers, + get_dequantize_block, + get_index_map, + get_reduction_blocks, + inline_transpose_block, + is_identity_block, + is_transpose_block, +) + + +class MatmulTensorizationMMA(ScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. 
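+
+    Schedule outline (per the constants set in `apply`): each threadblock
+    computes a 128x128 output tile over reduction chunks of 32, split across a
+    2x2 grid of warps. Operands are staged through "shared.dyn" and
+    ldmatrix-layout warp fragments, the k-loop is software-pipelined with async
+    copies, and the MMA intrinsics are selected via `get_mma_intrin_group`.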
+ """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + # We first inline all transpose blocks for later analysis of transposed A and B + blocks = inline_transpose_block(sch, blocks) + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + dequantize_block = get_dequantize_block(sch, blocks) + + main_block = reduction_blocks[0] + main_block_stmt = sch.get(main_block) + + # Supported data types: + # fp16, fp16, fp16: fp16 precision + # fp16, fp16, fp32: fp16 mixed precision + dtype_a = main_block_stmt.reads[0].buffer.dtype + dtype_b = main_block_stmt.reads[1].buffer.dtype + dtype_c = main_block_stmt.writes[0].buffer.dtype + if dtype_a != dtype_b: + return None + + # Get index maps + index_maps = get_index_map(main_block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # Tensorization by hardware intrinsics + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_mma_intrin_group, + shared_16x16_to_ldmatrix_32x8_layout, + ) + + # tile size + block_m, block_n, block_k = 128, 128, 32 + + # tensor core intrinsic size + micro_size_m, micro_size_n, micro_size_k = 16, 16, 16 + + # thread size + # thread_x == warp_size + thread_z, thread_y, thread_x = 2, 2, 32 + + vector_size = 8 + unroll_depth = 4 + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + is_transpose_a = is_transpose_block(sch.get(block)) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + is_transpose_b = is_identity_block(sch.get(block)) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + batch, i, j, k = sch.get_loops(main_block) + + swizzle_factor_for_l2_m = [1, None] + swizzle_factor_for_l2_n = [1, None] + # swizzle_factor_for_l2_m = [4, None] + # swizzle_factor_for_l2_n = [4, None] + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + swizzle_factor_for_l2_m[0] * block_m, + swizzle_factor_for_l2_n[0] * block_n, + block_k, + ], + ) + + # Step 3. 
Reorder loops for tiling + + # Step 3.1 inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_m]) + j, j_inner = sch.split(j, factors=[None, micro_size_n]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = main_block + block_outer = sch.blockize(i_inner) + + # Step 3.2 outer loops for tiling + # split factors for i, j, and k + micro_block_cnt_in_warp_m = block_m // thread_z // micro_size_m + micro_block_cnt_in_warp_n = block_n // thread_y // micro_size_n + micro_block_cnt_in_warp_k = block_k // micro_size_k + + i_factors = swizzle_factor_for_l2_m + [thread_z, micro_block_cnt_in_warp_m] + j_factors = swizzle_factor_for_l2_n + [thread_y, micro_block_cnt_in_warp_n] + k_factors = [None, micro_block_cnt_in_warp_k] + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, factors=k_factors) + + sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3) + + block_axis = sch.fuse(batch, i0, j0, i1, j1) + sch.bind(block_axis, "blockIdx.x") + + sch.bind(i2, "threadIdx.z") + sch.bind(j2, "threadIdx.y") + + # Step 4. Read/write to shared mem and register + def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], is_transpose): + # 1) Read to shared memory + block_read_smem = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn") + sch.compute_at(block_read_smem, k0) + auto_inline_producers( + sch, block_read_smem, [dequantize_block] if dequantize_block else [] + ) + + # For transposed read, we directly load transposed tensor from global + # Then use ldmatrix.trans to handle transpose later + if (tensor_name == "A" and is_transpose) or (tensor_name == "B" and not is_transpose): + # specifical handle transpose read (for NN matmul or TT matmul) + v0, v1 = sch.get_loops(block_read_smem)[-2:] + sch.reorder(v1, v0) + sch.transform_layout(block_read_smem, ("write", 0), lambda b, i, j: (b, j, i)) + + # bind loops + fused = sch.fuse(*sch.get_loops(block_read_smem)[-2:]) + f0, f1, f2, f3, f4 = sch.split(fused, [None, thread_z, thread_y, thread_x, vector_size]) + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + + # swizzling + sch.annotate(block_read_smem, ann_key="permuted_layout", ann_val=1) + + # 2) Read to register + block_read_reg = sch.cache_read(block_outer, read_buffer_idx, "warp") + sch.compute_at(block_read_reg, k1) + + # bind_loops + micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n + micro_size_1, micro_size_2 = ( + (micro_size_spatial, micro_size_k) + if not is_transpose + else (micro_size_k, micro_size_spatial) + ) + v00, v01 = sch.split(sch.get_loops(block_read_reg)[-2], [None, micro_size_1]) + v10, v11 = sch.split(sch.get_loops(block_read_reg)[-1], [None, micro_size_2]) + sch.reorder(v00, v10, v01, v11) + + # reorder read axis to match the layout of ldmatrix + sch.transform_layout( + block_read_reg, + ("write", 0), + lambda v0, v1, v2: ( + v0, + v1 // micro_size_1, + v2 // micro_size_2, + *shared_16x16_to_ldmatrix_32x8_layout(v1 % micro_size_1, v2 % micro_size_2), + ), + ) + + # swizzling + mma_read_block = sch.blockize(sch.get_loops(block_read_reg)[-2]) + sch.annotate(mma_read_block, ann_key="permuted_layout", ann_val=1) + + return block_read_smem, block_read_reg + + block_read_a, block_read_reg_a = fetch_input(block_outer, 0, "A", is_transpose_a) + block_read_b, block_read_reg_b = fetch_input(block_outer, 1, 
"B", is_transpose_b) + + # Write to register, and then smem + def store_output(block_outer, write_buffer_idx): + # 1) Write to shared memory + block_write_smem = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") + sch.reverse_compute_at(block_write_smem, block_axis) + auto_inline_consumer_chain(sch, block_write_smem) + + # bind loops + fused = sch.fuse(*sch.get_loops(block_write_smem)[-2:]) + f0, f1, f2, f3, f4 = sch.split(fused, [None, thread_z, thread_y, thread_x, vector_size]) + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + + # swizzling + sch.annotate(block_write_smem, ann_key="permuted_layout", ann_val=1) + + # 2) Write to register + block_write_reg = sch.cache_write(block_outer, write_buffer_idx, "warp") + + # bind loops + v0, v1, v2 = sch.get_loops(block_write_reg)[-3:] + v11, v12, v13 = sch.split(v1, factors=[thread_z, None, micro_size_m]) + v21, v22, v23 = sch.split(v2, factors=[thread_y, None, micro_size_n]) + sch.reorder(v11, v21, v12, v22, v13, v23) + sch.bind(v11, "threadIdx.z") + sch.bind(v21, "threadIdx.y") + + # reorder write axis to match the layout of ldmatrix + sch.transform_layout( + block_write_reg, + ("read", 0), + lambda v0, v1, v2: ( + v0, + v1 // micro_size_m, + v2 // micro_size_n, + *shared_16x16_to_ldmatrix_32x8_layout(v1 % micro_size_m, v2 % micro_size_n), + ), + ) + + # swizzling + mma_read_block = sch.blockize(sch.get_loops(block_write_reg)[-2]) + sch.annotate(mma_read_block, ann_key="permuted_layout", ann_val=1) + + return block_write_smem, block_write_reg + + block_write_smem, block_write_reg = store_output(block_outer, 0) + + # Step 5. Schedule tensor core computation + block_init = sch.decompose_reduction(block_outer, k0) + block_init_inner = sch.get_child_blocks(block_init)[0] + + # unroll k + # Profiling result shows unrolling k0 is not helpful on A100 + # sch.unroll(k0) + # k00, k01 = sch.split(k0, factors=[None, 8]) + # sch.unroll(k01) + + intrin_group = get_mma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype=str(dtype_a), + out_dtype=str(dtype_c), + trans_a=is_transpose_a, + trans_b=is_transpose_b, + ) + + sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(block_read_reg_a)[-2], intrin_group["load_a"]) + sch.tensorize(sch.get_loops(block_read_reg_b)[-2], intrin_group["load_b"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + sch.tensorize(sch.get_loops(block_write_reg)[-2], intrin_group["store"]) + + # Step 6. Async pipeline + sch.annotate(k0, ann_key="software_pipeline_stage", ann_val=[0, 0, 3]) + sch.annotate(k0, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) + sch.annotate(k0, ann_key="software_pipeline_async_stages", ann_val=[0]) + + # Step 7. Handle dequantize block + # Now we just add a dummy kernel to compute dequantize + if dequantize_block is not None: + auto_inline_producers(sch, dequantize_block) + loops = sch.get_loops(dequantize_block) + loop = sch.fuse(*loops) + v0, v1, v2, v3 = sch.split(loop, [None, 128, 2, 4]) + sch.bind(v0, "blockIdx.x") + sch.bind(v1, "threadIdx.x") + sch.unroll(v2) + sch.vectorize(v3) + return sch diff --git a/python/tvm/dlight/gpu/matmul_wmma.py b/python/tvm/dlight/gpu/matmul_wmma.py new file mode 100644 index 000000000000..62776a6e0663 --- /dev/null +++ b/python/tvm/dlight/gpu/matmul_wmma.py @@ -0,0 +1,685 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring, invalid-name +"""A GEMM schedule rule for GPU operators.""" +import math +from typing import Literal, Optional + +from tvm import DataType, tir +from tvm.target import Target +from tvm.tir.stmt import ForKind + +from ..base import ScheduleRule, analysis +from .matmul_analysis import ( + auto_inline_consumer_chain, + auto_inline_consumers, + auto_inline_producers, + get_index_map, + get_reduction_blocks, +) + + +class MatmulTensorizationWMMA(ScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + block_m = 128 + block_n = 128 + block_k = 32 + + # tensor core intrinsic size + micro_size_m = 16 + micro_size_n = 16 + micro_size_k = 16 + + thread_z = 2 + thread_y = 2 + warp_size = 32 + thread_cnt = thread_y * thread_z * warp_size + + vector_size = 8 + unroll_depth = 256 + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. 
Padding for dynamic shape kernels + + # # Step 2.1 Swizzle for l2, for better performance on inputs exceeding l2 size + # # Get input shape + batch, i, j, k = sch.get_loops(main_block) + # input_b, input_m, input_n, input_k = [sch.get(loop).extent for loop in [batch, i, j, k]] + + # # Get input/output dtype + dtype_a, dtype_b = [DataType(region.buffer.dtype) for region in sch.get(main_block).reads] + dtype_c = DataType(sch.get(main_block).writes[0].buffer.dtype) + # dtype_a_bytes, dtype_b_bytes = [math.ceil(d.bits / 8) for d in [dtype_a, dtype_b]] + + # # Get l2 size + # l2_size = target.l2_cache_size_bytes + + # # Analyse swizzle factor + # def get_swizzle_factor(l2_size, input_k, dtype_bytes, input_spatial, block_size): + # if l2_size != 0 and isinstance(input_k, (int, tir.IntImm)): + # # div by 3: suppose the two inputs and the output uses the same amount of l2 + # swizzle_factor = l2_size / 3 / int(input_k) / dtype_bytes / block_size + # # optimization: try find the best swizzle factor (aka the least additional padding) + # if isinstance(input_spatial, (int, tir.IntImm)): + # block_cnt = math.ceil(int(input_spatial) / block_size) + # swizzle_factor = math.ceil(block_cnt / math.ceil(block_cnt / swizzle_factor)) + # else: + # swizzle_factor = math.floor(swizzle_factor) + # return [None, swizzle_factor] + # else: + # return [4, None] + + # swizzle_factor_m = get_swizzle_factor(l2_size, input_k, dtype_a_bytes, input_m, block_m) + # swizzle_factor_n = get_swizzle_factor(l2_size, input_k, dtype_b_bytes, input_n, block_n) + + swizzle_factor_m = [4, None] + swizzle_factor_n = [4, None] + + # Step 2.2 Add padding + sch.pad_einsum( + main_block, + [ + 1, + (swizzle_factor_m[0] or swizzle_factor_m[1]) * block_m, + (swizzle_factor_n[0] or swizzle_factor_n[1]) * block_n, + block_k, + ], + ) + + # Step 3. Reorder loops for tiling + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_m]) + j, j_inner = sch.split(j, factors=[None, micro_size_n]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = main_block + block_outer = sch.blockize(i_inner) + + # split factors for i, j, and k + in_wrap_block_cnt_m = block_m // thread_z // micro_size_m + in_wrap_block_cnt_n = block_n // thread_y // micro_size_n + in_wrap_block_cnt_k = block_k // micro_size_k + + i_factors = swizzle_factor_m + [thread_z, in_wrap_block_cnt_m] + j_factors = swizzle_factor_n + [thread_y, in_wrap_block_cnt_n] + k_factors = [None, in_wrap_block_cnt_k] + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, factors=k_factors) + + sch.reorder(i0, j0, i1, j1, k0, i2, j2, k1, i3, j3) + block_axis = sch.fuse(batch, i0, j0, i1, j1) + + sch.bind(block_axis, "blockIdx.x") + sch.bind(i2, "threadIdx.z") + sch.bind(j2, "threadIdx.y") + + # Step 4. 
Read to/write from shared mem, and from/to wmma fragments + def fetch_input(block_outer, read_buffer_idx, tensor_name: Literal["A", "B"], wmma_name): + block_read = sch.cache_read(block_outer, read_buffer_idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-2:]) + + f0, f1, f2, f3, f4 = sch.split( + fused, [None, thread_z, thread_y, warp_size, vector_size] + ) + + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) + + auto_inline_producers(sch, block_read) + + wmma_read = sch.cache_read(block_outer, read_buffer_idx, wmma_name) + sch.compute_at(wmma_read, k1) + + micro_size_spatial = micro_size_m if tensor_name == "A" else micro_size_n + v0, v1 = sch.get_loops(wmma_read)[-2:] + sch.split(v0, factors=[None, micro_size_spatial]) + + return wmma_read + + wmma_read_a = fetch_input( + block_outer, 0, [block_m, block_k, micro_size_m, micro_size_k], "wmma.matrix_a" + ) + wmma_read_b = fetch_input( + block_outer, 1, [block_n, block_k, micro_size_n, micro_size_k], "wmma.matrix_b" + ) + + def store_output(block_outer, write_buffer_idx, wmma_name): + block_write = sch.cache_write(block_outer, write_buffer_idx, "shared.dyn") + sch.reverse_compute_at(block_write, block_axis) + + fused = sch.fuse(*sch.get_loops(block_write)[-2:]) + + f0, f1, f2, f3, f4 = sch.split( + fused, [None, thread_z, thread_y, warp_size, vector_size] + ) + + sch.bind(f1, "threadIdx.z") + sch.bind(f2, "threadIdx.y") + sch.bind(f3, "threadIdx.x") + sch.vectorize(f4) + # sch.storage_align(block_write, 0, axis=-2, factor=128, offset=16) + + auto_inline_consumer_chain(sch, block_write) + + wmma_store = sch.cache_write(block_outer, write_buffer_idx, wmma_name) + v0, v1 = sch.get_loops(wmma_store)[-2:] + v00, v01, v02 = sch.split(v0, factors=[thread_z, None, micro_size_m]) + v10, v11, v12 = sch.split(v1, factors=[thread_y, None, micro_size_n]) + sch.reorder(v00, v10, v01, v11, v02, v12) + sch.bind(v00, "threadIdx.z") + sch.bind(v10, "threadIdx.y") + return wmma_store + + wmma_store = store_output(block_outer, 0, "wmma.accumulator") + + block_init = sch.decompose_reduction(block_outer, k0) + block_init_inner = sch.get_child_blocks(block_init)[0] + + # unroll k + sch.unroll(k0) + + # Step 5. Schedule tensor core computation + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group, + ) + + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype=str(dtype_a), + out_dtype=str(dtype_c), + trans_b=True, + ) + + sch.tensorize(sch.get_loops(block_init_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(wmma_read_a)[-2], intrin_group["load_a"]) + sch.tensorize(sch.get_loops(wmma_read_b)[-2], intrin_group["load_b"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + sch.tensorize(sch.get_loops(wmma_store)[-2], intrin_group["store"]) + + return sch + + +class MatmulInt8Tensorization(ScheduleRule): + """ + The schedule rule for int8 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. 
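+
+    Lowers to the wmma intrinsics with int8 inputs and int32 accumulation
+    (`trans_b=True`), staging operands through "shared.dyn". If the load,
+    init, store, or compute loops cannot be tensorized, `apply` returns None.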
+ """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group, + ) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + micro_size_x = 16 + micro_size_y = 16 + micro_size_k = 16 + + warp_size = 32 + vector_size = 4 + + i_factors, j_factors, k_factors = ( + [None, 1, 4, 2], + [1, None, 4, 2], + [None, 1], + ) + + num_ty = i_factors[2] * j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. 
Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) + sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) + sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) + sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + + sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16) + sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) + sch.annotate(block_read, "double_buffer_scope", 0) + return block_read + + a_g2s = fetch_to_shared(block_outer, 0, 2) + b_g2s = fetch_to_shared(block_outer, 1, 2) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") + B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") + sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) + + store = sch.cache_write(block_outer, 0, "wmma.accumulator") + sch.reverse_compute_at(store, thread_idy) + sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype="int8", + out_dtype="int32", + trans_b=True, + ) + + try: + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_b"]) + except: # pylint: 
disable=bare-except + return None + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + try: + tensorize_init_store_compute() + except: # pylint: disable=bare-except + return None + + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) + _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + return sch + + +class MatmulTensorizationLegacy(ScheduleRule): + """ + The schedule rule for float16 tensor core matmul computation. + func with attr 'dlight.do_not_tensorize' will not be tensorized. + """ + + def apply( # pylint: disable=too-many-locals,missing-docstring + self, + func: tir.PrimFunc, + target: Target, + _: bool, + ) -> Optional[tir.Schedule]: + from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel + get_wmma_intrin_group, + ) + + sch = tir.Schedule(func) + root_block = analysis.get_root_block(sch) + blocks = sch.get_child_blocks(root_block) + + if func.attrs is not None and "dlight.do_not_tensorize" in func.attrs.keys(): + return None + + reduction_blocks = get_reduction_blocks(sch, blocks) + if reduction_blocks is None: + return None + + main_block = reduction_blocks[0] + block_stmt = sch.get(main_block) + index_maps = get_index_map(block_stmt) + if index_maps is None: + return None + matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps + + # Start Schedule + # Step 0. Get schedule config. + # NOTE: we can analyze the config by the hardware spec in the future + + # tensor core intrinsic size + micro_size_x = 16 + micro_size_y = 16 + micro_size_k = 16 + + warp_size = 32 + vector_size = 4 + + i_factors, j_factors, k_factors = ( + [None, 1, 4, 2], + [1, None, 4, 2], + [None, 4], + ) + + num_ty = i_factors[2] * j_factors[2] + x_pad_factor = i_factors[2] * i_factors[3] + y_pad_factor = j_factors[2] * j_factors[3] + k_pad_factor = k_factors[1] + + # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, J, K] + block = sch.reindex(main_block, ("read", 0)) + sch.transform_layout(block, ("write", 0), a_index_map) + block = sch.reindex(main_block, ("read", 1)) + sch.transform_layout(block, ("write", 0), b_index_map) + block = sch.reindex(main_block, ("write", 0)) + sch.transform_layout(block, ("read", 0), c_index_map) + sch.transform_block_layout(main_block, matmul_index_map) + + # Step 2. Padding for dynamic shape kernels + sch.pad_einsum( + main_block, + [ + 1, + micro_size_x * x_pad_factor, + micro_size_y * y_pad_factor, + micro_size_k * k_pad_factor, + ], + ) + + # Step 3. 
Schedule matmul to use tensor core + block = main_block + + batch, i, j, k = sch.get_loops(block) + + # inner loops for tensor core computation + i, i_inner = sch.split(i, factors=[None, micro_size_x]) + j, j_inner = sch.split(j, factors=[None, micro_size_y]) + k, k_inner = sch.split(k, factors=[None, micro_size_k]) + + sch.reorder(i, j, k, i_inner, j_inner, k_inner) + + block_inner = block + block_outer = sch.blockize(i_inner) + + i0, i1, i2, i3 = sch.split(i, factors=i_factors) + j0, j1, j2, j3 = sch.split(j, factors=j_factors) + k0, k1 = sch.split(k, k_factors) + sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) + sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) + sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) + sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) + + sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) + + block_idx = sch.fuse(i0, j0) + block_idy = sch.fuse(i1, j1) + thread_idy = sch.fuse(j2, i2) + sch.bind(batch, "blockIdx.z") + sch.bind(block_idx, "blockIdx.x") + sch.bind(block_idy, "blockIdx.y") + sch.bind(thread_idy, "threadIdx.y") + + def fetch_to_shared(block, idx, ndim): + block_read = sch.cache_read(block, idx, "shared.dyn") + sch.compute_at(block_read, k0) + fused = sch.fuse(*sch.get_loops(block_read)[-ndim:]) + + _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, warp_size, vector_size]) + + sch.bind(f_2, "threadIdx.x") + sch.bind(f_1, "threadIdx.y") + sch.vectorize(f_3) + + sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8) + sch.annotate(block_read, "tir.manifest_shared_memory_local_stage", 1) + sch.annotate(block_read, "double_buffer_scope", 0) + return block_read + + a_g2s = fetch_to_shared(block_outer, 0, 2) + b_g2s = fetch_to_shared(block_outer, 1, 2) + + auto_inline_producers(sch, a_g2s) + auto_inline_producers(sch, b_g2s) + + # create read cache to load matrix from shared memory to wmma fragments + A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a") + B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b") + sch.compute_at(A_mat, k1) + sch.compute_at(B_mat, k1) + + # create write cache to store matrix from wmma fragments to shared memory and global memory + accumulator_shared_to_global = sch.cache_write(block_outer, 0, "shared.dyn") + sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4) + + store = sch.cache_write(block_outer, 0, "wmma.accumulator") + sch.reverse_compute_at(store, thread_idy) + sch.reverse_compute_at(accumulator_shared_to_global, thread_idy) + + # split the store loop to match hardware intrinsic pattern + i, j = sch.get_loops(store)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + + block_init_c = sch.decompose_reduction(block_outer, k0) + block_init_c_inner = sch.get_child_blocks(block_init_c)[0] + + # Tensorization by hardware intrinsics + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype="float16", + out_dtype="float32", + trans_b=True, + ) + + try: + i, j = sch.get_loops(A_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_a"]) + + i, j = sch.get_loops(B_mat)[-2:] + i0, i1 = sch.split(i, factors=[None, 16]) + j0, j1 = sch.split(j, factors=[None, 16]) + sch.reorder(i0, j0, i1, j1) + sch.unroll(i0) + sch.unroll(j0) + sch.tensorize(i1, intrin_group["load_b"]) + except: # pylint: 
disable=bare-except + return None + + # Try to tensorize the init, store and compute block with f16 or f32 intrinsics + tensorize_success: bool = False + + def tensorize_init_store_compute(): + sch.tensorize(sch.get_loops(block_init_c_inner)[-2], intrin_group["init"]) + sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"]) + sch.tensorize(sch.get_loops(block_inner)[-3], intrin_group["compute"]) + + try: + tensorize_init_store_compute() + tensorize_success = True + except: # pylint: disable=bare-except + intrin_group = get_wmma_intrin_group( + load_scope="shared.dyn", + store_scope="shared.dyn", + in_dtype="float16", + out_dtype="float16", + trans_b=True, + ) + + if not tensorize_success: + try: + tensorize_init_store_compute() + tensorize_success = True + except: # pylint: disable=bare-except + return None + auto_inline_consumer_chain(sch, accumulator_shared_to_global) + + fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:]) + _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size]) + sch.bind(f1, "threadIdx.x") + sch.vectorize(f2) + + return sch if tensorize_success else None diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py index 4f2df5cfa0c9..d03d876595d3 100644 --- a/python/tvm/dlight/gpu/utils.py +++ b/python/tvm/dlight/gpu/utils.py @@ -89,3 +89,11 @@ def suggest_threads_per_block( results[dynamic[0]] *= threads return results + + +def get_sm_version(target: Target) -> int: + if target.kind.name != "cuda": + return -1 + arch = target.arch + sm_version = arch.replace("sm_", "") + return int(sm_version) if sm_version.isdigit() else -1 diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 74b0bd2ba4e1..7d56500fd2fc 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1349,10 +1349,18 @@ def func( ] = None, *, is_size_var: bool = False, + min_value: Optional[int] = None, ) -> PrimExpr: if isinstance(expr, str): expr = float(expr) - return getattr(_ffi_api, name)(expr, is_size_var) + + if min_value is not None: + assert is_size_var, "min_value is only valid for SizeVar" + else: + # set the default min_value for SizeVar + min_value = 0 + + return getattr(_ffi_api, name)(expr, is_size_var, min_value) return func diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index 5547ef82d7a8..29e913d3b565 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -507,7 +507,7 @@ def size_var(name="size", dtype="int32", span=None): var : SizeVar The result symbolic shape variable. """ - return tvm.tir.SizeVar(name, dtype, span) + return tvm.tir.SizeVar(name, dtype, span=span) def thread_axis(dom=None, tag="", name="", span=None): diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py index fad9fca083a1..7a1f6b170b9d 100644 --- a/python/tvm/tir/expr.py +++ b/python/tvm/tir/expr.py @@ -381,13 +381,18 @@ class SizeVar(Var): dtype : Union[str, ir.Type] The data type + min_value : Optional[int] + The minimum value of the SizeVar. Used to assist subsequent analysis. + span : Optional[Span] The location of this expression in the source code. 
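+
+    Examples
+    --------
+    A size variable carrying a known lower bound, for example a dynamic
+    sequence length that is at least 1:
+
+    .. code-block:: python
+
+        n = tvm.tir.SizeVar("n", "int64", min_value=1)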
""" # pylint: disable=super-init-not-called - def __init__(self, name: str, dtype: Union[str, ir.Type], span: Optional[Span] = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.SizeVar, name, dtype, span) # type: ignore + def __init__(self, name, dtype, min_value=0, span=None): + self.__init_handle_by_constructor__( + _ffi_api.SizeVar, name, dtype, min_value, span # type: ignore + ) @tvm._ffi.register_object("tir.IterVar") diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py index ac7f7cf7d4bd..37194264fdfd 100644 --- a/python/tvm/tir/tensor_intrin/cuda.py +++ b/python/tvm/tir/tensor_intrin/cuda.py @@ -611,6 +611,15 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: row, col = T.meta_var(index_map_rev(tx, local_id)) C[row, col] = C_warp[tx, local_id] + with T.block("root"): + T.reads(C_warp[0:WARP_SIZE, 0:local_size]) + T.writes(C[0:M_DIM, 0:N_DIM]) + + for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): + for local_id in T.serial(local_size): + row, col = T.meta_var(index_map_rev(tx, local_id)) + C[row, col] = C_warp[tx, local_id] + return mma_store_desc, mma_store_impl @@ -732,7 +741,7 @@ def get_mma_intrin_group( load_scope = "_dyn" if load_scope == "shared.dyn" else "" load_a_intrin = f"mma_ldmatrix_{in_dtype}_a{trans_a}{load_scope}" load_b_intrin = f"mma_ldmatrix_{in_dtype}_b{trans_b}{load_scope}" - indexmap_a = shared_16x16_to_mma_32x8_layout if in_dtype == "float16" else shared_32x16_to_mma_32x16_layout + # e.g. mma_f16f16f32_trans_a_trans_b trans_a_str = trans_a + "_a" if trans_a != "" else "" trans_b_str = trans_b + "_b" if trans_b != "" else "" diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index 3e5b8834ebca..d555f9671c93 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -219,7 +219,9 @@ bool Analyzer::CanProve(const PrimExpr& expr, ProofStrength strength) { lower_bound = 0; } if (pos_diff) { + // VLOG(0) << pos_diff; IntSet iset = this->int_set(this->Simplify(pos_diff.value())); + // VLOG(0) << iset.min() << ", " << iset.max(); if (iset.HasLowerBound()) { ConstIntBound relaxed_lower_bound = this->const_int_bound(this->Simplify(iset.min())); if (relaxed_lower_bound->min_value >= lower_bound) return true; @@ -290,8 +292,14 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu return PackedFunc( [self](TVMArgs args, TVMRetValue* ret) { *ret = self->canonical_simplify(args[0]); }); } else if (name == "int_set") { - return PackedFunc( - [self](TVMArgs args, TVMRetValue* ret) { *ret = self->int_set(args[0], args[1]); }); + return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { + auto dom_map = args[1].operator Optional>(); + if (dom_map) { + *ret = self->int_set(args[0], dom_map.value()); + } else { + *ret = self->int_set(args[0]); + } + }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue* ret) { if (args[1].IsObjectRef()) { diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index d6554fc37103..3ee3fab7384f 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -347,11 +347,12 @@ ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread, DataType dtype = DataType(min.dtype().code(), bits, 1); n->vars = {Var("v", dtype)}; n->doms = {Range::FromMinExtent(min, extent)}; - n->f_make_for_loop = [annotations, thread, dtype](Array vars, Array doms, - Stmt body) -> For { + n->f_make_for_loop = [annotations, thread, dtype, min, extent](Array vars, Array doms, + Stmt body) -> 
For { ICHECK_EQ(vars.size(), 1); ICHECK_EQ(doms.size(), 1); - IterVar iter_var(Range(nullptr), Var("iter", dtype), IterVarType::kThreadIndex, thread); + IterVar iter_var(Range::FromMinExtent(min, extent), Var("iter", dtype), + IterVarType::kThreadIndex, thread); return For(vars[0], doms[0]->min, doms[0]->extent, ForKind::kThreadBinding, body, iter_var, annotations.value_or(Map())); }; diff --git a/src/target/tag.cc b/src/target/tag.cc index 9caeec3b9205..d8f341351d21 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -155,8 +155,6 @@ TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2050", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/tesla-c2070", "sm_20", 49152, 32768); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a100", "sm_80", 49152, 65536) .with_config("l2_cache_size_bytes", Integer(41943040)); -TVM_REGISTER_CUDA_TAG("nvidia/nvidia-h100", "sm_90", 49152, 65536) - .with_config("l2_cache_size_bytes", Integer(52428800)); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a40", "sm_86", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a30", "sm_80", 49152, 65536); TVM_REGISTER_CUDA_TAG("nvidia/nvidia-a10", "sm_86", 49152, 65536); diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index 2bd1e0608374..8e52c5412aef 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -521,17 +521,22 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) { Var new_loop_var = Downcast(VisitExpr(op->loop_var)); PrimExpr min = VisitExpr(op->min); PrimExpr extent = VisitExpr(op->extent); + Optional thread_binding = NullOpt; + if (op->thread_binding.defined()) { + thread_binding = VisitIterVar(op->thread_binding.value()); + } is_enabled_ = is_enabled; Stmt new_body = VisitStmt(op->body); if (!new_loop_var.same_as(op->loop_var) || !min.same_as(op->min) || !extent.same_as(op->extent) || - !new_body.same_as(op->body)) { + !new_body.same_as(op->body) || !thread_binding.same_as(op->thread_binding)) { For new_for = GetRef(op); auto* n = new_for.CopyOnWrite(); n->loop_var = new_loop_var; n->min = cast(new_loop_var.dtype(), min); n->extent = cast(new_loop_var.dtype(), extent); + n->thread_binding = std::move(thread_binding); n->body = new_body; return std::move(new_for); } else { diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 41500051fa89..59b0189a7004 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -121,27 +121,30 @@ TVM_REGISTER_GLOBAL("tir.Var").set_body_typed([](String name_hint, runtime::TVMA TVM_REGISTER_NODE_TYPE(VarNode); // SizeVar -SizeVar::SizeVar(String name_hint, DataType dtype, Span span) { +SizeVar::SizeVar(String name_hint, DataType dtype, int64_t min_value, Span span) { auto n = make_object(); n->name_hint = std::move(name_hint); n->type_annotation = GetTypeFromRuntimeDataType(dtype); n->dtype = std::move(dtype); + n->min_value = min_value; n->span = std::move(span); data_ = std::move(n); } -SizeVar::SizeVar(String name_hint, Type type_annotation, Span span) { +SizeVar::SizeVar(String name_hint, Type type_annotation, int64_t min_value, Span span) { auto n = make_object(); n->name_hint = std::move(name_hint); n->dtype = GetRuntimeDataType(type_annotation); n->type_annotation = std::move(type_annotation); + n->min_value = min_value; n->span = std::move(span); data_ = std::move(n); } -TVM_REGISTER_GLOBAL("tir.SizeVar").set_body_typed([](String s, DataType t, Span span) { - return SizeVar(s, t, span); -}); +TVM_REGISTER_GLOBAL("tir.SizeVar") + .set_body_typed([](String s, DataType t, int64_t min_value, Span span) { + return 
SizeVar(s, t, min_value, span); + }); TVM_REGISTER_NODE_TYPE(SizeVarNode); diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index fc388b004843..17e941a59bfa 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -82,6 +82,8 @@ class NotInSameScopeError : public ScheduleError { arith::Analyzer* analyzer) { for (const StmtSRefNode* p = loop_sref.get();; p = p->parent) { if (const ForNode* loop = p->StmtAs()) { + VLOG(0) << "Bind " << loop->loop_var << " to min=" << loop->min + << ", max=" << loop->extent; analyzer->Bind(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); } else if (p != scope_root_sref.get()) { throw NotInSameScopeError(self->mod, block_sref, loop_sref); @@ -258,6 +260,7 @@ class ScopeReconstructor : private StmtMutator { iter_values.reserve(n_iters); PrimExpr predicate = const_true(); for (int i = 0; i < n_iters; ++i) { + VLOG(0) << "visiting: " << i << " " << iter_doms[i].dom << " " << iter_doms[i].bound; Range iter_dom = iter_doms[i].dom.CoverRange(block_->iter_vars[i]->dom); if (preserve_unit_loops || !is_one(iter_dom->extent)) { int bits = std::max(iter_dom->min.dtype().bits(), iter_dom->extent.dtype().bits()); @@ -265,22 +268,32 @@ class ScopeReconstructor : private StmtMutator { loop_vars.push_back(var); loop_extents.push_back(analyzer->Simplify(iter_dom->extent)); iter_values.push_back(iter_dom->min + var); - analyzer->Bind(var, Range::FromMinExtent(IntImm(var.dtype(), 0), iter_dom->extent)); + auto range = Range::FromMinExtent(IntImm(var.dtype(), 0), iter_dom->extent); + VLOG(0) << "Bind: " << var << " to " << range; + analyzer->Bind(var, range); } else { iter_values.push_back(iter_dom->min); } const arith::IntSet& pred_bound = iter_doms[i].bound; if (!pred_bound.IsNothing()) { // NOTE: Apply strong analyzer proofs to get rid of symbolic bound + VLOG(0) << iter_values[i]; if (pred_bound.HasLowerBound()) { PrimExpr lower_bound = iter_values[i] >= pred_bound.min(); - if (!analyzer->CanProve(lower_bound, arith::ProofStrength::kSymbolicBound)) { + VLOG(0) << lower_bound; + auto res = analyzer->CanProve(lower_bound, arith::ProofStrength::kSymbolicBound); + VLOG(0) << res; + if (!res) { predicate = predicate && lower_bound; } } if (pred_bound.HasUpperBound()) { PrimExpr upper_bound = iter_values[i] < pred_bound.max() + 1; - if (!analyzer->CanProve(upper_bound, arith::ProofStrength::kSymbolicBound)) { + VLOG(0) << upper_bound; + VLOG(0) << analyzer->int_set(analyzer->Simplify(iter_values[i])); + auto res = analyzer->CanProve(upper_bound, arith::ProofStrength::kSymbolicBound); + VLOG(0) << res; + if (!res) { predicate = predicate && upper_bound; } } @@ -738,6 +751,10 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s /*required_regions=*/std::move(required_regions), /*analyzer=*/analyzer); // Step 6. 
Create the new scope according to the iteration domain + VLOG(0) << block->iter_vars; + for (const BlockVarDomainInfo& info : iter_doms) { + VLOG(0) << info.dom << " " << info.bound; + } reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms), /*analyzer=*/analyzer, /*preserve_unit_loops=*/preserve_unit_loops); Block new_scope_root = Downcast(reconstructor(scope_root)); diff --git a/src/tir/schedule/primitive/for_kind.cc b/src/tir/schedule/primitive/for_kind.cc index 9690cd78c82f..9c6be83a6023 100644 --- a/src/tir/schedule/primitive/for_kind.cc +++ b/src/tir/schedule/primitive/for_kind.cc @@ -171,8 +171,7 @@ void ParallelizeComputation(const ScheduleState& self, const StmtSRef& loop_sref ObjectPtr new_loop = make_object(*loop); new_loop->kind = for_kind; if (thread_axis.defined()) { - const String& thread_tag = thread_axis.value(); - new_loop->thread_binding = IterVar(/*dom=*/Range(nullptr), // + new_loop->thread_binding = IterVar(/*dom=*/Range::FromMinExtent(loop->min, loop->extent), // /*var=*/Var(thread_axis.value(), loop->loop_var.dtype()), // /*iter_type=*/kThreadIndex, // /*thread_tag=*/thread_axis.value()); diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index a6b97bf17906..8f347c54e74d 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -399,6 +399,7 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array // Currently, loops not starting with 0 are not supported arith::Analyzer analyzer; CheckLoopStartsWithZero(self, loop_sref, &analyzer); + // should add AddShapeVarBounds // Find the most common dtype DataType dtype; @@ -433,6 +434,8 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const Array &opaque_block_reuse)(std::move(new_stmt)); // Step 3. 
Update predicate to guard the loop PrimExpr predicate = substitute_value < loop->extent; + // VLOG(0) << predicate; + // VLOG(0) << analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound); if (!analyzer.CanProve(predicate, arith::ProofStrength::kSymbolicBound)) { new_stmt = BlockPredicateAppender(/*predicate=*/predicate)(std::move(new_stmt)); } diff --git a/src/tir/transforms/lift_thread_binding.cc b/src/tir/transforms/lift_thread_binding.cc index 9d7d455dbaed..6d12604ce1f6 100644 --- a/src/tir/transforms/lift_thread_binding.cc +++ b/src/tir/transforms/lift_thread_binding.cc @@ -130,7 +130,7 @@ class ThreadBindingLifter : public StmtExprMutator { for (const auto& [iter_var, annotation] : it->second) { body = For(iter_var->var, iter_var->dom->min, iter_var->dom->extent, ForKind::kThreadBinding, std::move(body), - IterVar(Range(nullptr), Var(iter_var->thread_tag, iter_var->var->dtype), + IterVar(iter_var->dom, Var(iter_var->thread_tag, iter_var->var->dtype), kThreadIndex, iter_var->thread_tag), annotation); } diff --git a/src/tir/transforms/memhammer_coalesce.cc b/src/tir/transforms/memhammer_coalesce.cc index 5ca20f57aa78..98a502eeb0b3 100644 --- a/src/tir/transforms/memhammer_coalesce.cc +++ b/src/tir/transforms/memhammer_coalesce.cc @@ -128,7 +128,8 @@ Stmt SplitBindVectorize(const Stmt& stmt, const ConstraintSet& constraints) { body = For(new_loop_vars.back(), 0, vector_len, ForKind::kVectorized, std::move(body)); for (int i = n - 2; i >= 1; i--) { body = For(new_loop_vars[i], 0, factors[i], ForKind::kThreadBinding, std::move(body), - IterVar(Range(nullptr), Var(thread_axis[i - 1]), kThreadIndex, thread_axis[i - 1])); + IterVar(Range::FromMinExtent(0, factors[i]), Var(thread_axis[i - 1]), kThreadIndex, + thread_axis[i - 1])); } return For(new_loop_vars[0], 0, factors[0], ForKind::kSerial, std::move(body)); } diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py index 82f481da469d..130215a82ee8 100644 --- a/tests/python/dlight/test_gpu_matmul.py +++ b/tests/python/dlight/test_gpu_matmul.py @@ -674,5 +674,110 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)), # fmt: on +class TestLastPadding(BaseBeforeAfter): + # fmt: off + + @T.prim_func + def before( + A: T.Buffer((T.int64(8), T.int64(1023), T.int64(32000)), "float32"), + B: T.Buffer((T.int64(4096), T.int64(32000)), "float32"), + out: T.Buffer((T.int64(8), T.int64(1024), T.int64(4096)), "float32"), + ): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + C = T.alloc_buffer((T.int64(8), T.int64(1023), T.int64(4096)), "float32") + for i0, i1, i2, k in T.grid(T.int64(8), T.int64(1023), T.int64(4096), T.int64(32000)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k]) + T.writes(C[v_i0, v_i1, v_i2]) + with T.init(): + C[v_i0, v_i1, v_i2] = T.float16(0) + C[v_i0, v_i1, v_i2] = C[v_i0, v_i1, v_i2] + A[v_i0, v_i1, v_k] * B[v_i2, v_k] + for i, j, k in T.grid(T.int64(8), T.int64(1024), T.int64(4096)): + with T.block("out"): + v_i, v_j, v_k = T.axis.remap("SSS", [i, j, k]) + T.reads(C[v_i, v_j, v_k]) + T.writes(out[v_i, v_j, v_k]) + out[v_i, v_j, v_k] = T.if_then_else(v_j != T.int64(1023), C[v_i, v_j, v_k], T.float16(0)) + + @T.prim_func + def expected(A: T.Buffer((T.int64(8), T.int64(1023), T.int64(32000)), "float32"), B: T.Buffer((T.int64(4096), T.int64(32000)), "float32"), out: T.Buffer((T.int64(8), T.int64(1024), T.int64(4096)), "float32")): + 
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + C_reindex_pad = T.alloc_buffer((T.int64(1), T.int64(8192), T.int64(4096))) + C_reindex_pad_local = T.alloc_buffer((T.int64(1), T.int64(8192), T.int64(4096)), scope="local") + A_reindex_pad_shared = T.alloc_buffer((T.int64(1), T.int64(8192), T.int64(32000)), scope="shared") + B_reindex_shared = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(32000)), scope="shared") + for ax0_ax2_0_fused in T.thread_binding(T.int64(64), thread="blockIdx.y"): + for ax1_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"): + for ax2_1 in T.thread_binding(T.int64(1), thread="vthread.y"): + for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.x"): + for ax2_2 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax1_2 in T.thread_binding(T.int64(8), thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_3_init, ax1_3_0_init in T.grid(T.int64(4), T.int64(2)): + for ax1_3_1_init in T.vectorized(T.int64(2)): + with T.block("NT_matmul_init"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(8192), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0_init * T.int64(2) + ax1_3_1_init) + v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3_init) + T.reads() + T.writes(C_reindex_pad_local[T.int64(0), v1, v2]) + C_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0) + for ax3_0 in range(T.int64(2000)): + for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in range(T.int64(2)): + for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): + with T.block("A_reindex_pad_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(8192), ax1_0 * T.int64(32) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) + v2 = T.axis.spatial(T.int64(32000), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(32) + ax0_ax1_ax2_fused_1 * T.int64(4) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) + T.reads(A[v1 // T.int64(1023), v1 % T.int64(1023), v2]) + T.writes(A_reindex_pad_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + A_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < T.int64(8184), A[v1 // T.int64(1023), v1 % T.int64(1023), v2], T.float32(0)) + for ax0_ax1_ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"): + for ax0_ax1_ax2_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"): + for ax0_ax1_ax2_fused_2 in range(T.int64(4)): + for ax0_ax1_ax2_fused_3 in T.vectorized(T.int64(2)): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) // T.int64(16)) + v2 = T.axis.spatial(T.int64(32000), ax3_0 * T.int64(16) + (ax0_ax1_ax2_fused_0 * T.int64(64) + ax0_ax1_ax2_fused_1 * T.int64(8) + ax0_ax1_ax2_fused_2 * T.int64(2) + ax0_ax1_ax2_fused_3) % T.int64(16)) + T.reads(B[v1, v2]) + T.writes(B_reindex_shared[v0, v1, v2]) + T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) + B_reindex_shared[v0, v1, v2] = B[v1, v2] + for ax3_1, ax2_3, ax1_3_0 in 
T.grid(T.int64(16), T.int64(4), T.int64(2)): + for ax1_3_1 in T.vectorized(T.int64(2)): + with T.block("NT_matmul_update"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(8192), ax1_0 * T.int64(32) + ax1_1 * T.int64(32) + ax1_2 * T.int64(4) + ax1_3_0 * T.int64(2) + ax1_3_1) + v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(4) + ax2_3) + v3 = T.axis.reduce(T.int64(32000), ax3_0 * T.int64(16) + ax3_1) + T.reads(C_reindex_pad_local[T.int64(0), v1, v2], A_reindex_pad_shared[T.int64(0), v1, v3], B_reindex_shared[T.int64(0), v2, v3]) + T.writes(C_reindex_pad_local[T.int64(0), v1, v2]) + C_reindex_pad_local[T.int64(0), v1, v2] = C_reindex_pad_local[T.int64(0), v1, v2] + A_reindex_pad_shared[T.int64(0), v1, v3] * B_reindex_shared[T.int64(0), v2, v3] + for ax0, ax1, ax2_0 in T.grid(T.int64(1), T.int64(4), T.int64(2)): + for ax2_1_1 in T.vectorized(T.int64(2)): + with T.block("C_reindex_pad_local"): + v0 = T.axis.spatial(T.int64(1), ax0) + v1 = T.axis.spatial(T.int64(8192), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1) + v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1) + T.reads(C_reindex_pad_local[v0, v1, v2]) + T.writes(C_reindex_pad[v0, v1, v2]) + C_reindex_pad[v0, v1, v2] = C_reindex_pad_local[v0, v1, v2] + for i_j_k_fused_0 in T.thread_binding(T.int64(32768), thread="blockIdx.x"): + for i_j_k_fused_1 in T.thread_binding(T.int64(1024), thread="threadIdx.x"): + with T.block("out"): + v_i = T.axis.spatial(T.int64(8), (i_j_k_fused_0 * T.int64(1024) + i_j_k_fused_1) // T.int64(4194304)) + v_j = T.axis.spatial(T.int64(1024), (i_j_k_fused_0 * T.int64(1024) + i_j_k_fused_1) % T.int64(4194304) // T.int64(4096)) + v_k = T.axis.spatial(T.int64(4096), (i_j_k_fused_0 * T.int64(1024) + i_j_k_fused_1) % T.int64(4096)) + T.reads(C_reindex_pad[T.int64(0), v_i * T.int64(1023) + v_j, v_k]) + T.writes(out[v_i, v_j, v_k]) + out[v_i, v_j, v_k] = T.if_then_else(v_j != T.int64(1023), C_reindex_pad[T.int64(0), v_i * T.int64(1023) + v_j, v_k], T.float32(0)) + # fmt: on + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py b/tests/python/dlight/test_gpu_matmul_tensorize.py deleted file mode 100644 index 72ffb307194a..000000000000 --- a/tests/python/dlight/test_gpu_matmul_tensorize.py +++ /dev/null @@ -1,702 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=missing-docstring -import pytest - -import tvm.testing -from tvm import dlight as dl -from tvm.script import ir as I -from tvm.script import tir as T -from tvm.target import Target - - -class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): - @pytest.fixture - def transform(self): - def transform(mod): - with Target("nvidia/geforce-rtx-2080-ti"): - return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod) - - return transform - - -class TestMatmulTensorize(BaseBeforeAfter): - # fmt: off - - @T.prim_func - def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): - T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) - # with T.block("root"): - for i, j, k in T.grid(256, 256, 256): - with T.block("compute"): - v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) - T.reads(X[v_i, v_k], W[v_j, v_k]) - T.writes(compute[v_i, v_j]) - with T.init(): - compute[v_i, v_j] = T.float16(0) - compute[v_i, v_j] = compute[v_i, v_j] + X[v_i, v_k] * W[v_j, v_k] - - @T.prim_func - def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") - W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") - X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.matrix_a") - W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.matrix_b") - compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", scope="shared.dyn") - compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256, 256), "float16", scope="wmma.accumulator") - for ax0 in T.thread_binding(1, thread="blockIdx.z"): - for ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.x"): - for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"): - for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): - for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): - with T.block("compute_o_init"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) - v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) - T.reads() - T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("compute_init_o"): - v1_i_init_o = T.axis.spatial(1, 0) - v2_i_init_o = T.axis.spatial(1, 0) - T.reads() - T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) - T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) - for ax3_0_0 in range(4, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): - for ax0_ax1_fused_0 in range(4): - for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_3 in 
T.vectorized(4): - with T.block("X_reindex_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) - v2 = T.axis.spatial(256, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) - T.reads(X[v1, v2]) - T.writes(X_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) - X_reindex_shared_dyn[v0, v1, v2] = X[v1, v2] - for ax0_ax1_fused_0 in range(4): - for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("W_reindex_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) - v2 = T.axis.spatial(256, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) - T.reads(W[v1, v2]) - T.writes(W_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) - W_reindex_shared_dyn[v0, v1, v2] = W[v1, v2] - for ax3_0_1 in range(4, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): - for ax0_0 in T.unroll(2): - for ax1_0 in T.unroll(1): - with T.block("X_reindex_shared.dyn_wmma.matrix_a_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) - v2_o = T.axis.spatial(16, ax3_0_0 * 4 + ax3_0_1 + ax1_0) - T.reads(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) - C = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") - for ax0_0 in T.unroll(2): - for ax1_0 in T.unroll(1): - with T.block("W_reindex_shared.dyn_wmma.matrix_b_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) - v2_o = T.axis.spatial(16, ax3_0_0 * 4 + ax3_0_1 + ax1_0) - T.reads(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) - C = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", 
offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") - for ax1_0_3, ax2_0_3 in T.grid(2, 2): - with T.block("compute_o_update"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) - v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) - v3_o = T.axis.reduce(16, ax3_0_0 * 4 + ax3_0_1) - T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("compute_o"): - v1_i_o = T.axis.spatial(1, 0) - v2_i_o = T.axis.spatial(1, 0) - v3_i_o = T.axis.reduce(1, 0) - T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) - B = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) - C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) - T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) - for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("compute_reindex_shared.dyn_wmma.accumulator_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) - v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) - T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) - C = T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) - 
T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") - for ax0_ax1_fused_0 in range(8): - for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(4): - with T.block("compute_reindex_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) - v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) - T.reads(compute_reindex_shared_dyn[v0, v1, v2]) - T.writes(compute[v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) - compute[v1, v2] = compute_reindex_shared_dyn[v0, v1, v2] - - # fmt: on - - -class TestMatmulTensorizeTooSmall(BaseBeforeAfter): - # fmt: off - - @T.prim_func - def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.handle): - T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) - m = T.int32() - X = T.match_buffer(var_X, (m, 256), "float16") - compute = T.match_buffer(var_compute, (m, 15)) - # with T.block("root"): - for i, j, k in T.grid(m, 15, 256): - with T.block("compute"): - v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) - T.reads(X[v_i, v_k], W[v_j, v_k]) - T.writes(compute[v_i, v_j]) - with T.init(): - compute[v_i, v_j] = T.float32(0) - compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("float32", X[v_i, v_k]) * T.Cast("float32", W[v_j, v_k]) - - @T.prim_func - def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.handle): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - m = T.int32() - X = T.match_buffer(var_X, (m, 256), "float16") - compute = T.match_buffer(var_compute, (m, 15)) - # with T.block("root"): - compute_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 64), scope="local") - X_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 256), "float16", scope="shared") - W_reindex_pad_shared = T.alloc_buffer((1, 64, 256), "float16", scope="shared") - for ax0_ax2_0_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax1_0 in T.thread_binding((m + 31) // 32, thread="blockIdx.x"): - for ax2_1 in T.thread_binding(1, thread="vthread.y"): - for ax1_1 in T.thread_binding(1, thread="vthread.x"): - for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): - for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax2_3_init, ax1_3_0_init in T.grid(4, 2): - for ax1_3_1_init in T.vectorized(2): - with T.block("compute_init"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init) - v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) - T.reads() - T.writes(compute_reindex_pad_local[0, v1, v2]) - compute_reindex_pad_local[0, v1, v2] = T.float32(0) - for ax3_0 in range(16): - for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): - for ax0_ax1_ax2_fused_2 in range(2): - for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("X_reindex_pad_shared"): - v0 = T.axis.spatial(1, 0) - v1 
= T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) - v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) - T.reads(X[v1, v2]) - T.writes(X_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) - X_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, X[v1, v2], T.float16(0)) - for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): - for ax0_ax1_ax2_fused_2 in range(4): - for ax0_ax1_ax2_fused_3 in T.vectorized(2): - with T.block("W_reindex_pad_shared"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(64, (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) - v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) - T.reads(W[v1, v2]) - T.writes(W_reindex_pad_shared[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) - W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 15, W[v1, v2], T.float16(0)) - for ax3_1, ax2_3, ax1_3_0 in T.grid(16, 4, 2): - for ax1_3_1 in T.vectorized(2): - with T.block("compute_update"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1) - v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3) - v3 = T.axis.reduce(256, ax3_0 * 16 + ax3_1) - T.reads(compute_reindex_pad_local[0, v1, v2], X_reindex_pad_shared[0, v1, v3], W_reindex_pad_shared[0, v2, v3]) - T.writes(compute_reindex_pad_local[0, v1, v2]) - compute_reindex_pad_local[0, v1, v2] = compute_reindex_pad_local[0, v1, v2] + T.Cast("float32", X_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", W_reindex_pad_shared[0, v2, v3]) - for ax0, ax1, ax2_0 in T.grid(1, 4, 2): - for ax2_1_1 in T.vectorized(2): - with T.block("compute_reindex_pad_local"): - v0 = T.axis.spatial(1, ax0) - v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) - v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) - T.reads(compute_reindex_pad_local[v0, v1, v2]) - T.writes(compute[v1, v2]) - if v1 < m and v2 < 15: - compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2] - # fmt: on - - -class TestMatmulTensorizeEpilogue(BaseBeforeAfter): - # fmt: off - - @T.prim_func - def before(lv686: T.Buffer((T.int32(4096), T.int32(256)), "uint32"), lv687: T.Buffer((T.int32(4096), T.int32(64)), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle): - T.func_attr({"tir.noalias": T.bool(True)}) - n = T.int32() - lv42 = T.match_buffer(p_lv42, (T.int32(1), n, T.int32(2048)), "float16") - lv3 = T.match_buffer(p_lv3, (T.int32(1), n, T.int32(4096)), "float16") - p_output0_intermediate = T.match_buffer(p_output0, (T.int32(1), n, T.int32(4096)), "float16") - # with T.block("root"): - p_output0_intermediate_1 = T.alloc_buffer((T.int32(4096), T.int32(2048)), "float16") - var_NT_matmul_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16") - var_T_divide_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16") - for i, j in T.grid(T.int32(4096), T.int32(2048)): - with T.block("decode"): - v_i, v_j = T.axis.remap("SS", [i, j]) - T.reads(lv686[v_i, v_j // T.int32(8)], lv687[v_i, v_j // T.int32(32)]) - T.writes(p_output0_intermediate_1[v_i, v_j]) 
- p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v_i, v_j // T.int32(8)], T.Cast("uint32", v_j % T.int32(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v_i, v_j // T.int32(32)] - for i0, i1, i2, k in T.grid(T.int32(1), n, T.int32(4096), T.int32(2048)): - with T.block("NT_matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(lv42[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) - T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - with T.init(): - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) - var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv42[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] - for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)): - with T.block("T_divide"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(lv3[v_ax0, v_ax1, v_ax2]) - T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2]) - var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * T.float16(0.5) - for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)): - with T.block("T_add"): - v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) - T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) - T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) - p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] - - @T.prim_func - def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle): - T.func_attr({"global_symbol": "fused_fused_decode3_fused_NT_matmul6_divide1_add1", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - n = T.int32() - lv42 = T.match_buffer(p_lv42, (1, n, 2048), "float16") - lv3 = T.match_buffer(p_lv3, (1, n, 4096), "float16") - p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") - # with T.block("root"): - lv42_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="shared.dyn") - p_output0_intermediate_1_reindex_shared_dyn = T.alloc_buffer((1, 4096, 2048), "float16", scope="shared.dyn") - lv42_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="wmma.matrix_a") - p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 2048), "float16", scope="wmma.matrix_b") - var_NT_matmul_intermediate_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="shared.dyn") - var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="wmma.accumulator") - for ax0 in T.thread_binding(1, thread="blockIdx.z"): - for ax1_0_0_ax2_0_0_fused in T.thread_binding((n + 127) // 128, thread="blockIdx.x"): - for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"): - for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): - for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): - with T.block("NT_matmul_o_init"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) - v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) - T.reads() - 
T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("NT_matmul_init_o"): - v1_i_init_o = T.axis.spatial(1, 0) - v2_i_init_o = T.axis.spatial(1, 0) - T.reads() - T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) - T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) - for ax3_0_0 in range(32, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): - for ax0_ax1_fused_0 in range(4): - for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("lv42_reindex_pad_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) - v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) - T.reads(lv42[v0, v1, v2]) - T.writes(lv42_reindex_pad_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) - lv42_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < n, lv42[v0, v1, v2], T.float16(0)) - for ax0_ax1_fused_0 in range(4): - for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("p_output0_intermediate_1_reindex_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) - v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) - T.reads(lv686[v1, v2 // 8], lv687[v1, v2 // 32]) - T.writes(p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) - p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v1, v2 // 32] - for ax3_0_1 in range(4, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): - for ax0_0 in T.unroll(2): - for ax1_0 in T.unroll(1): - with T.block("lv42_reindex_pad_shared.dyn_wmma.matrix_a_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) - v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0) - T.reads(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = 
T.match_buffer(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) - C = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") - for ax0_0 in T.unroll(2): - for ax1_0 in T.unroll(1): - with T.block("p_output0_intermediate_1_reindex_shared.dyn_wmma.matrix_b_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) - v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0) - T.reads(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) - C = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") - for ax1_0_3, ax2_0_3 in T.grid(2, 2): - with T.block("NT_matmul_o_update"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) - v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) - v3_o = T.axis.reduce(128, ax3_0_0 * 4 + ax3_0_1) - T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("NT_matmul_o"): - v1_i_o = T.axis.spatial(1, 0) - v2_i_o = T.axis.spatial(1, 0) - v3_i_o = T.axis.reduce(1, 0) - T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), 
scope="wmma.matrix_a", offset_factor=16) - B = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) - C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) - T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) - for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn_wmma.accumulator_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) - v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) - T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) - C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) - T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") - for ax0_ax1_fused_0 in range(8): - for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(4): - with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) - v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) - T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]) - T.writes(p_output0_intermediate[0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) - if v1 < n: - p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2] - # fmt: on - - -class TestMatmulInt8Tensorize(BaseBeforeAfter): - # fmt: off - @T.prim_func - def before(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): - T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) - # with T.block("root"): - for i, j, r in T.grid(256, 256, 256): - with 
T.block("compute"): - v_i, v_j, v_k = T.axis.remap("SSR", [i, j, r]) - T.reads(X[v_i, v_k], W[v_j, v_k]) - T.writes(compute[v_i, v_j]) - with T.init(): - compute[v_i, v_j] = 0 - compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("int32", X[v_i, v_k]) * T.Cast("int32", W[v_j, v_k]) - - @T.prim_func - def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): - T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - # with T.block("root"): - X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") - W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") - X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), "int8", scope="wmma.matrix_a") - W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), "int8", scope="wmma.matrix_b") - compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int32", scope="shared.dyn") - compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256, 256), "int32", scope="wmma.accumulator") - for ax0 in T.thread_binding(1, thread="blockIdx.z"): - for ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.x"): - for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"): - for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): - for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): - with T.block("compute_o_init"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) - v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) - T.reads() - T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("compute_init_o"): - v1_i_init_o = T.axis.spatial(1, 0) - v2_i_init_o = T.axis.spatial(1, 0) - T.reads() - T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) - T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) - for ax3_0_0 in T.serial(16, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): - for ax0_ax1_fused_0 in range(1): - for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("X_reindex_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) - v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) - T.reads(X[v1, v2]) - T.writes(X_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) - X_reindex_shared_dyn[v0, v1, v2] = X[v1, v2] - for ax0_ax1_fused_0 in range(1): - for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for 
ax0_ax1_fused_3 in T.vectorized(4): - with T.block("W_reindex_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) - v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) - T.reads(W[v1, v2]) - T.writes(W_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) - W_reindex_shared_dyn[v0, v1, v2] = W[v1, v2] - for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): - for ax0_0 in T.unroll(2): - for ax1_0 in T.unroll(1): - with T.block("X_reindex_shared.dyn_wmma.matrix_a_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) - v2_o = T.axis.spatial(16, ax3_0_0 + ax1_0) - T.reads(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) - C = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") - for ax0_0 in T.unroll(2): - for ax1_0 in T.unroll(1): - with T.block("W_reindex_shared.dyn_wmma.matrix_b_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) - v2_o = T.axis.spatial(16, ax3_0_0 + ax1_0) - T.reads(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) - C = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") - for ax1_0_3, ax2_0_3 in T.grid(2, 2): - with T.block("compute_o_update"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) - v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) - v3_o = T.axis.reduce(16, ax3_0_0 + ax3_0_1) - T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], 
W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("compute_o"): - v1_i_o = T.axis.spatial(1, 0) - v2_i_o = T.axis.spatial(1, 0) - v3_i_o = T.axis.reduce(1, 0) - T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) - B = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) - C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) - T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) - for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("compute_reindex_shared.dyn_wmma.accumulator_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) - v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) - T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) - C = T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) - T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") - for ax0_ax1_fused_0 in range(8): - for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(4): - with T.block("compute_reindex_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) - v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) - T.reads(compute_reindex_shared_dyn[v0, v1, 
v2]) - T.writes(compute[v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) - compute[v1, v2] = compute_reindex_shared_dyn[v0, v1, v2] - # fmt: on - - -class TestMatmulInt8Tensorize3d2dDyn(BaseBeforeAfter): - # fmt: off - @T.prim_func - def before(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.handle): - T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)}) - m = T.int32() - A = T.match_buffer(var_A, (1, m, 22016), "int8") - matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") - # with T.block("root"): - for i0, i1, i2, k in T.grid(1, m, 4096, 22016): - with T.block("matmul"): - v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) - T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k]) - T.writes(matmul_1[v_i0, v_i1, v_i2]) - with T.init(): - matmul_1[v_i0, v_i1, v_i2] = 0 - matmul_1[v_i0, v_i1, v_i2] = matmul_1[v_i0, v_i1, v_i2] + T.Cast("int32", A[v_i0, v_i1, v_k]) * T.Cast("int32", B[v_i2, v_k]) - - @T.prim_func - def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.handle): - T.func_attr({"op_pattern": 4, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) - m = T.int32() - A = T.match_buffer(var_A, (1, m, 22016), "int8") - matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") - # with T.block("root"): - A_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 22016), "int8", scope="shared.dyn") - B_reindex_shared_dyn = T.alloc_buffer((1, 4096, 22016), "int8", scope="shared.dyn") - A_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (m + 127) // 128 * 128, 22016), "int8", scope="wmma.matrix_a") - B_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 22016), "int8", scope="wmma.matrix_b") - matmul_1_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 4096), "int32", scope="shared.dyn") - matmul_1_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1, (m + 127) // 128 * 128, 4096), "int32", scope="wmma.accumulator") - for ax0 in T.thread_binding(1, thread="blockIdx.z"): - for ax1_0_0_ax2_0_0_fused in T.thread_binding((m + 127) // 128, thread="blockIdx.x"): - for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"): - for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): - for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): - with T.block("matmul_o_init"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) - v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) - T.reads() - T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("matmul_init_o"): - v1_i_init_o = T.axis.spatial(1, 0) - v2_i_init_o = T.axis.spatial(1, 0) - T.reads() - T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - C = T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) - T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) - for ax3_0_0 in T.serial(1376, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): - for ax0_ax1_fused_0 in range(1): - for ax0_ax1_fused_1 in 
T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("A_reindex_pad_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) - v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) - T.reads(A[v0, v1, v2]) - T.writes(A_reindex_pad_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) - A_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < m, A[v0, v1, v2], T.int8(0)) - for ax0_ax1_fused_0 in range(1): - for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): - for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_3 in T.vectorized(4): - with T.block("B_reindex_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) - v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) - T.reads(B[v1, v2]) - T.writes(B_reindex_shared_dyn[v0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) - B_reindex_shared_dyn[v0, v1, v2] = B[v1, v2] - for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): - for ax0_0 in T.unroll(2): - for ax1_0 in T.unroll(1): - with T.block("A_reindex_pad_shared.dyn_wmma.matrix_a_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) - v2_o = T.axis.spatial(1376, ax3_0_0 + ax1_0) - T.reads(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A_1 = T.match_buffer(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) - C = T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "row_major") - for ax0_0 in T.unroll(2): - for ax1_0 in T.unroll(1): - with T.block("B_reindex_shared.dyn_wmma.matrix_b_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) - v2_o = T.axis.spatial(1376, ax3_0_0 + ax1_0) - T.reads(B_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A_1 = T.match_buffer(B_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", 
"A_s1"), scope="shared.dyn", offset_factor=16) - C = T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) - T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "col_major") - for ax1_0_3, ax2_0_3 in T.grid(2, 2): - with T.block("matmul_o_update"): - v0_o = T.axis.spatial(1, ax0) - v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) - v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) - v3_o = T.axis.reduce(1376, ax3_0_0 + ax3_0_1) - T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - with T.block("matmul_o"): - v1_i_o = T.axis.spatial(1, 0) - v2_i_o = T.axis.spatial(1, 0) - v3_i_o = T.axis.reduce(1, 0) - T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) - T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A_1 = T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) - B_1 = T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) - C = T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) - T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A_1.data, A_1.elem_offset // A_1.strides[0] // 16 * (A_1.strides[0] // 16) + A_1.elem_offset % A_1.strides[0] // 16, B_1.data, B_1.elem_offset // B_1.strides[0] // 16 * (B_1.strides[0] // 16) + B_1.elem_offset % B_1.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) - for ax0_0, ax1_0 in T.grid(2, 2): - with T.block("matmul_1_reindex_pad_shared.dyn_wmma.accumulator_o"): - v0_o = T.axis.spatial(1, 0) - v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) - v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) - T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - T.writes(matmul_1_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) - A_1 = 
T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) - C = T.match_buffer(matmul_1_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) - T.tvm_store_matrix_sync(A_1.data, 16, 16, 16, A_1.elem_offset // A_1.strides[0] // 16 * (A_1.strides[0] // 16) + A_1.elem_offset % A_1.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") - for ax0_ax1_fused_0 in range(8): - for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_fused_2 in T.vectorized(4): - with T.block("matmul_1_reindex_pad_shared.dyn"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) - v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) - T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2]) - T.writes(matmul_1[0, v1, v2]) - T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) - if v1 < m: - matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2] - # fmt: on - - -if __name__ == "__main__": - tvm.testing.main() diff --git a/tests/python/dlight/test_gpu_matmul_tensorize_mma.py b/tests/python/dlight/test_gpu_matmul_tensorize_mma.py new file mode 100644 index 000000000000..7d7cab25b528 --- /dev/null +++ b/tests/python/dlight/test_gpu_matmul_tensorize_mma.py @@ -0,0 +1,779 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +import pytest + +import tvm.testing +from tvm import dlight as dl +from tvm.script import tir as T +from tvm.target import Target + + +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + @pytest.fixture + def transform(self): + def transform(mod): + with Target("nvidia/nvidia-a100"): + return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod) + + return transform + + +class TestNTMatmulMixedPrecision(BaseBeforeAfter): + # fmt: off + @T.prim_func + def before(p_A: T.handle, p_B: T.handle, p_O: T.handle): + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(128), T.int64(128)), "float16") + B = T.match_buffer(p_B, (T.int64(128), T.int64(128)), "float16") + O = T.match_buffer(p_O, (b, T.int64(128), T.int64(128)), "float16") + var_matmul_intermediate = T.alloc_buffer((b, T.int64(128), T.int64(128))) + for i0, i1, i2, k in T.grid(b, T.int64(128), T.int64(128), T.int64(128)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", A[v_i0, v_i1, v_k]) * T.Cast("float32", B[v_i2, v_k]) + for i0, i1, i2 in T.grid(b, T.int64(128), T.int64(128)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + O[v_i0, v_i1, v_i2] = T.Cast("float16", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def expected(p_A: T.handle, B: T.Buffer((T.int64(128), T.int64(128)), "float16"), p_O: T.handle): + T.func_attr({"tir.is_scheduled": 1}) + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(128), T.int64(128)), "float16") + O = T.match_buffer(p_O, (b, T.int64(128), T.int64(128)), "float16") + # with T.block("root"): + A_reindex_shared_dyn = T.alloc_buffer((T.int64(1), b * T.int64(128), T.int64(128)), "float16", scope="shared.dyn") + A_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), b * T.int64(8), T.int64(8), T.int64(32), T.int64(8)), "float16", scope="warp") + B_reindex_shared_dyn = T.alloc_buffer((T.int64(1), T.int64(128), T.int64(128)), "float16", scope="shared.dyn") + B_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(32), T.int64(8)), "float16", scope="warp") + var_matmul_intermediate_reindex_shared_dyn = T.alloc_buffer((T.int64(1), b * T.int64(128), T.int64(128)), scope="shared.dyn") + var_matmul_intermediate_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), b * T.int64(8), T.int64(8), T.int64(32), T.int64(8)), scope="warp") + for ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused in T.thread_binding(b, thread="blockIdx.x"): + for ax1_0_2_init in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0_2_init in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_0_3_init, ax2_0_3_init in T.grid(T.int64(4), T.int64(4)): + with T.block("matmul_o_init"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(b * T.int64(8), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(8) + ax1_0_2_init * T.int64(4) + ax1_0_3_init) + v2_o = T.axis.spatial(T.int64(8), ax2_0_2_init * T.int64(4) + ax2_0_3_init) + T.reads() + T.writes(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("matmul_init_o"): + v1_i_init_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_init_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads() + 
T.writes(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + C_warp = T.match_buffer(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=1) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.mma_fill("float32", 8, C_warp.data, C_warp.elem_offset) + for ax3_0_0 in T.serial(T.int64(4), annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 3]}): + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("A_reindex_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(b * T.int64(128), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(32)) + v2 = T.axis.spatial(T.int64(128), ax3_0_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(32)) + T.reads(A[v1 // T.int64(128), v1 % T.int64(128), v2]) + T.writes(A_reindex_shared_dyn[v0, v1, v2]) + T.block_attr({"permuted_layout": 1}) + A_reindex_shared_dyn[v0, v1, v2] = A[v1 // T.int64(128), v1 % T.int64(128), v2] + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("B_reindex_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(128), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(32)) + v2 = T.axis.spatial(T.int64(128), ax3_0_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(32)) + T.reads(B[v1, v2]) + T.writes(B_reindex_shared_dyn[v0, v1, v2]) + T.block_attr({"permuted_layout": 1}) + B_reindex_shared_dyn[v0, v1, v2] = B[v1, v2] + for ax1_0_2 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax3_0_1 in range(T.int64(2)): + for ax0_0, ax1_0 in T.grid(T.int64(4), T.int64(1)): + with T.block("A_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(T.int64(8) * b, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(8) + ax1_0_2 * T.int64(4) + ax0_0) + v2_o = T.axis.spatial(T.int64(8), ax3_0_0 * T.int64(2) + ax3_0_1 + ax1_0) + T.reads(A_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(A_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.block_attr({"permuted_layout": 1}) + 
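+ # The "permuted_layout" block attribute marks these shared.dyn accesses for
+ # the swizzled-layout rewrite applied later in lowering (to avoid
+ # shared-memory bank conflicts); it does not change the indices shown here.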
with T.block("A_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(A_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(A_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + warp = T.match_buffer(A_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + shared = T.match_buffer(A_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + T.int64(8) * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * T.int64(16), 1), shared.strides[0] * (tx % T.int64(16)) + T.int64(8) * (tx // T.int64(16))) + for ax0_0, ax1_0 in T.grid(T.int64(4), T.int64(1)): + with T.block("B_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(T.int64(8), ax2_0_2 * T.int64(4) + ax0_0) + v2_o = T.axis.spatial(T.int64(8), ax3_0_0 * T.int64(2) + ax3_0_1 + ax1_0) + T.reads(B_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(B_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.block_attr({"permuted_layout": 1}) + with T.block("B_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(B_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(B_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + warp = T.match_buffer(B_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + shared = T.match_buffer(B_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + T.int64(8) * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * T.int64(16), 1), shared.strides[0] * T.int64(8) * (tx // T.int64(16)) + shared.strides[0] * (tx % T.int64(8)) + T.int64(8) * (tx % T.int64(16) // T.int64(8))) + for ax1_0_3, ax2_0_3 in T.grid(T.int64(4), T.int64(4)): + with T.block("matmul_o_update"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(b * T.int64(8), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(8) + ax1_0_2 * T.int64(4) + ax1_0_3) + v2_o = T.axis.spatial(T.int64(8), ax2_0_2 * T.int64(4) + ax2_0_3) + v3_o = T.axis.reduce(T.int64(8), ax3_0_0 * T.int64(2) + ax3_0_1) + T.reads(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, 
T.int64(0):T.int64(32), T.int64(0):T.int64(8)], A_reindex_shared_dyn_warp[T.int64(0), v1_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], B_reindex_shared_dyn_warp[T.int64(0), v2_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("matmul_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v3_i_o = T.axis.reduce(T.int64(1), T.int64(0)) + T.reads(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], A_reindex_shared_dyn_warp[T.int64(0), v1_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], B_reindex_shared_dyn_warp[T.int64(0), v2_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + A_1 = T.match_buffer(A_reindex_shared_dyn_warp[T.int64(0), v1_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + B_1 = T.match_buffer(B_reindex_shared_dyn_warp[T.int64(0), v2_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + C = T.match_buffer(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_1.data, A_1.elem_offset + tx * T.int64(8), B_1.data, B_1.elem_offset + tx * T.int64(8), C.data, C.elem_offset + tx * T.int64(8), T.bool(False)) + T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_1.data, A_1.elem_offset + tx * T.int64(8), B_1.data, B_1.elem_offset + tx * T.int64(8) + T.int64(4), C.data, C.elem_offset + tx * T.int64(8) + T.int64(4), T.bool(False)) + for ax0 in range(T.int64(1)): + for ax1_0 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_1, ax2_1 in T.grid(T.int64(4), T.int64(4)): + with T.block("var_matmul_intermediate_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), ax0) + v1_o = T.axis.spatial(T.int64(8) * b, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(8) + ax1_0 * T.int64(4) + ax1_1) + v2_o = T.axis.spatial(T.int64(8), ax2_0 * T.int64(4) + ax2_1) + T.reads(var_matmul_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(var_matmul_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.block_attr({"permuted_layout": 1}) + with T.block("var_matmul_intermediate_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(var_matmul_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(var_matmul_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + C_warp = T.match_buffer(var_matmul_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), 
T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=1) + C = T.match_buffer(var_matmul_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=1) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for local_id in range(T.int64(8)): + C[T.int64(8) * (local_id % T.int64(4) // T.int64(2)) + tx // T.int64(4), T.int64(8) * (local_id // T.int64(4)) + tx % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = C_warp[tx, local_id] + for ax0_ax1_fused_0 in range(T.int64(16)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("var_matmul_intermediate_reindex_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(b * T.int64(128), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(128)) + v2 = T.axis.spatial(T.int64(128), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(128)) + T.reads(var_matmul_intermediate_reindex_shared_dyn[v0, v1, v2]) + T.writes(O[v1 // T.int64(128), v1 % T.int64(128), v2]) + T.block_attr({"permuted_layout": 1}) + if v1 // T.int64(128) < b: + O[v1 // T.int64(128), v1 % T.int64(128), v2] = T.Cast("float16", var_matmul_intermediate_reindex_shared_dyn[v0, v1, v2]) + # fmt: on + + +class TestTNMatmulMixedPrecision(BaseBeforeAfter): + # fmt: off + @T.prim_func + def before(p_A: T.handle, p_B: T.handle, p_O: T.handle): + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(128), T.int64(128)), "float16") + B = T.match_buffer(p_B, (T.int64(128), T.int64(128)), "float16") + O = T.match_buffer(p_O, (b, T.int64(128), T.int64(128)), "float16") + var_matmul_intermediate = T.alloc_buffer((b, T.int64(128), T.int64(128))) + for i0, i1, i2, k in T.grid(b, T.int64(128), T.int64(128), T.int64(128)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", A[v_i0, v_k, v_i1]) * T.Cast("float32", B[v_k, v_i2]) + for i0, i1, i2 in T.grid(b, T.int64(128), T.int64(128)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + O[v_i0, v_i1, v_i2] = T.Cast("float16", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def expected(p_A: T.handle, B: T.Buffer((T.int64(128), T.int64(128)), "float16"), p_O: T.handle): + T.func_attr({"tir.is_scheduled": 1}) + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(128), T.int64(128)), "float16") + O = T.match_buffer(p_O, (b, T.int64(128), T.int64(128)), "float16") + A_reindex_shared_dyn = T.alloc_buffer((T.int64(1), T.int64(128), b * T.int64(128)), "float16", scope="shared.dyn") + A_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), T.int64(8), b * T.int64(8), T.int64(32), T.int64(8)), "float16", scope="warp") + B_reindex_shared_dyn = 
T.alloc_buffer((T.int64(1), T.int64(128), T.int64(128)), "float16", scope="shared.dyn") + B_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), T.int64(8), T.int64(8), T.int64(32), T.int64(8)), "float16", scope="warp") + var_matmul_intermediate_reindex_shared_dyn = T.alloc_buffer((T.int64(1), b * T.int64(128), T.int64(128)), scope="shared.dyn") + var_matmul_intermediate_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), b * T.int64(8), T.int64(8), T.int64(32), T.int64(8)), scope="warp") + for ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused in T.thread_binding(b, thread="blockIdx.x"): + for ax1_0_2_init in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0_2_init in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_0_3_init, ax2_0_3_init in T.grid(T.int64(4), T.int64(4)): + with T.block("matmul_o_init"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(b * T.int64(8), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(8) + ax1_0_2_init * T.int64(4) + ax1_0_3_init) + v2_o = T.axis.spatial(T.int64(8), ax2_0_2_init * T.int64(4) + ax2_0_3_init) + T.reads() + T.writes(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("matmul_init_o"): + v1_i_init_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_init_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads() + T.writes(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + C_warp = T.match_buffer(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=1) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.mma_fill("float32", 8, C_warp.data, C_warp.elem_offset) + for ax3_0_0 in T.serial(T.int64(4), annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 3]}): + for ax1_ax0_fused_0 in range(T.int64(4)): + for ax1_ax0_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax1_ax0_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_ax0_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax1_ax0_fused_4 in T.vectorized(T.int64(8)): + with T.block("A_reindex_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(b * T.int64(128), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(128) + (ax1_ax0_fused_0 * T.int64(1024) + ax1_ax0_fused_1 * T.int64(512) + ax1_ax0_fused_2 * T.int64(256) + ax1_ax0_fused_3 * T.int64(8) + ax1_ax0_fused_4) % T.int64(128)) + v2 = T.axis.spatial(T.int64(128), ax3_0_0 * T.int64(32) + (ax1_ax0_fused_0 * T.int64(1024) + ax1_ax0_fused_1 * T.int64(512) + ax1_ax0_fused_2 * T.int64(256) + ax1_ax0_fused_3 * T.int64(8) + ax1_ax0_fused_4) // T.int64(128)) + T.reads(A[v1 // T.int64(128), v2, v1 % T.int64(128)]) + T.writes(A_reindex_shared_dyn[v0, v2, v1]) + T.block_attr({"permuted_layout": 1}) + A_reindex_shared_dyn[v0, v2, v1] = A[v1 // T.int64(128), v2, v1 % T.int64(128)] + for ax1_ax0_fused_0 in range(T.int64(4)): + for ax1_ax0_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax1_ax0_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_ax0_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax1_ax0_fused_4 in T.vectorized(T.int64(8)): + with T.block("B_reindex_shared.dyn"): + v0 = 
T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(128), (ax1_ax0_fused_0 * T.int64(1024) + ax1_ax0_fused_1 * T.int64(512) + ax1_ax0_fused_2 * T.int64(256) + ax1_ax0_fused_3 * T.int64(8) + ax1_ax0_fused_4) % T.int64(128)) + v2 = T.axis.spatial(T.int64(128), ax3_0_0 * T.int64(32) + (ax1_ax0_fused_0 * T.int64(1024) + ax1_ax0_fused_1 * T.int64(512) + ax1_ax0_fused_2 * T.int64(256) + ax1_ax0_fused_3 * T.int64(8) + ax1_ax0_fused_4) // T.int64(128)) + T.reads(B[v2, v1]) + T.writes(B_reindex_shared_dyn[v0, v2, v1]) + T.block_attr({"permuted_layout": 1}) + B_reindex_shared_dyn[v0, v2, v1] = B[v2, v1] + for ax1_0_2 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax3_0_1 in range(T.int64(2)): + for ax0_0, ax1_0 in T.grid(T.int64(1), T.int64(4)): + with T.block("A_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(T.int64(8), ax3_0_0 * T.int64(2) + ax3_0_1 + ax0_0) + v2_o = T.axis.spatial(T.int64(8) * b, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(8) + ax1_0_2 * T.int64(4) + ax1_0) + T.reads(A_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(A_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.block_attr({"permuted_layout": 1}) + with T.block("A_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(A_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(A_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + warp = T.match_buffer(A_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + shared = T.match_buffer(A_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_ldmatrix("float16", T.bool(True), 4, ".b16", warp.data, warp.elem_offset + T.int64(8) * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * T.int64(16), 1), shared.strides[0] * T.int64(8) * (tx // T.int64(16)) + shared.strides[0] * (tx % T.int64(8)) + T.int64(8) * (tx % T.int64(16) // T.int64(8))) + for ax0_0, ax1_0 in T.grid(T.int64(1), T.int64(4)): + with T.block("B_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(T.int64(8), ax3_0_0 * T.int64(2) + ax3_0_1 + ax0_0) + v2_o = T.axis.spatial(T.int64(8), ax2_0_2 * T.int64(4) + ax1_0) + T.reads(B_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(B_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.block_attr({"permuted_layout": 1}) + with T.block("B_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(B_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * 
T.int64(16) + T.int64(16)]) + T.writes(B_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + warp = T.match_buffer(B_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + shared = T.match_buffer(B_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_ldmatrix("float16", T.bool(True), 4, ".b16", warp.data, warp.elem_offset + T.int64(8) * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * T.int64(16), 1), shared.strides[0] * (tx % T.int64(16)) + T.int64(8) * (tx // T.int64(16))) + for ax1_0_3, ax2_0_3 in T.grid(T.int64(4), T.int64(4)): + with T.block("matmul_o_update"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(b * T.int64(8), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(8) + ax1_0_2 * T.int64(4) + ax1_0_3) + v2_o = T.axis.spatial(T.int64(8), ax2_0_2 * T.int64(4) + ax2_0_3) + v3_o = T.axis.reduce(T.int64(8), ax3_0_0 * T.int64(2) + ax3_0_1) + T.reads(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], A_reindex_shared_dyn_warp[T.int64(0), v3_o, v1_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], B_reindex_shared_dyn_warp[T.int64(0), v3_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("matmul_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v3_i_o = T.axis.reduce(T.int64(1), T.int64(0)) + T.reads(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], A_reindex_shared_dyn_warp[T.int64(0), v3_o, v1_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], B_reindex_shared_dyn_warp[T.int64(0), v3_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + A_1 = T.match_buffer(A_reindex_shared_dyn_warp[T.int64(0), v3_o, v1_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + B_1 = T.match_buffer(B_reindex_shared_dyn_warp[T.int64(0), v3_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + C = T.match_buffer(var_matmul_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_1.data, A_1.elem_offset + tx * T.int64(8), B_1.data, B_1.elem_offset + tx * T.int64(8), C.data, C.elem_offset + tx * T.int64(8), T.bool(False)) + T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_1.data, A_1.elem_offset + tx * T.int64(8), B_1.data, B_1.elem_offset + tx * T.int64(8) + T.int64(4), C.data, C.elem_offset + tx * T.int64(8) + T.int64(4), T.bool(False)) 
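+ # Each mma.m16n8k16 covers a 16x8 slice of the 16x16 warp tile, so the
+ # update issues two ptx_mma calls per k-step; the second call is offset by
+ # four elements per thread into the B and C fragments.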
+ for ax0 in range(T.int64(1)): + for ax1_0 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_1, ax2_1 in T.grid(T.int64(4), T.int64(4)): + with T.block("var_matmul_intermediate_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), ax0) + v1_o = T.axis.spatial(T.int64(8) * b, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(8) + ax1_0 * T.int64(4) + ax1_1) + v2_o = T.axis.spatial(T.int64(8), ax2_0 * T.int64(4) + ax2_1) + T.reads(var_matmul_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(var_matmul_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.block_attr({"permuted_layout": 1}) + with T.block("var_matmul_intermediate_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(var_matmul_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(var_matmul_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + C_warp = T.match_buffer(var_matmul_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=1) + C = T.match_buffer(var_matmul_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=1) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for local_id in range(T.int64(8)): + C[T.int64(8) * (local_id % T.int64(4) // T.int64(2)) + tx // T.int64(4), T.int64(8) * (local_id // T.int64(4)) + tx % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = C_warp[tx, local_id] + for ax0_ax1_fused_0 in range(T.int64(16)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("var_matmul_intermediate_reindex_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(b * T.int64(128), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(128)) + v2 = T.axis.spatial(T.int64(128), (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(128)) + T.reads(var_matmul_intermediate_reindex_shared_dyn[v0, v1, v2]) + T.writes(O[v1 // T.int64(128), v1 % T.int64(128), v2]) + T.block_attr({"permuted_layout": 1}) + if v1 // T.int64(128) < b: + O[v1 // T.int64(128), v1 % T.int64(128), v2] = T.Cast("float16", var_matmul_intermediate_reindex_shared_dyn[v0, v1, v2]) + # fmt: on + + +class TestMatmulDecode(BaseBeforeAfter): + # fmt: off + @T.prim_func + def before( + data: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), + scale: T.Buffer((T.int64(4096), T.int64(128)), "float16"), + p_A: T.handle, + p_O: T.handle, + ): + b = 
T.int64() + A = T.match_buffer(p_A, (b, T.int64(512), T.int64(4096)), "float16") + O = T.match_buffer(p_O, (b, T.int64(512), T.int64(4096)), "float16") + B_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + O_intermediate = T.alloc_buffer((b, T.int64(512), T.int64(4096))) + for i, j in T.grid(T.int64(4096), T.int64(4096)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(data[v_i, v_j // T.int64(8)], scale[v_i, v_j // T.int64(32)]) + T.writes(B_intermediate[v_i, v_j]) + B_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(data[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * scale[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(b, T.int64(512), T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(A[v_i0, v_i1, v_k], B_intermediate[v_i2, v_k]) + T.writes(O_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + O_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + O_intermediate[v_i0, v_i1, v_i2] = O_intermediate[v_i0, v_i1, v_i2] + T.Cast("float32", A[v_i0, v_i1, v_k]) * T.Cast("float32", B_intermediate[v_i2, v_k]) + for i0, i1, i2 in T.grid(b, T.int64(512), T.int64(4096)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + T.reads(O_intermediate[v_i0, v_i1, v_i2]) + T.writes(O[v_i0, v_i1, v_i2]) + O[v_i0, v_i1, v_i2] = T.Cast("float16", O_intermediate[v_i0, v_i1, v_i2]) + + @T.prim_func + def expected(data: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), scale: T.Buffer((T.int64(4096), T.int64(128)), "float16"), p_A: T.handle, p_O: T.handle): + T.func_attr({"tir.is_scheduled": 1}) + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(512), T.int64(4096)), "float16") + O = T.match_buffer(p_O, (b, T.int64(512), T.int64(4096)), "float16") + # with T.block("root"): + B_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16") + A_reindex_shared_dyn = T.alloc_buffer((T.int64(1), b * T.int64(512), T.int64(4096)), "float16", scope="shared.dyn") + A_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), b * T.int64(32), T.int64(256), T.int64(32), T.int64(8)), "float16", scope="warp") + B_intermediate_reindex_shared_dyn = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), "float16", scope="shared.dyn") + B_intermediate_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), T.int64(256), T.int64(256), T.int64(32), T.int64(8)), "float16", scope="warp") + O_intermediate_reindex_shared_dyn = T.alloc_buffer((T.int64(1), b * T.int64(512), T.int64(4096)), scope="shared.dyn") + O_intermediate_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), b * T.int64(32), T.int64(256), T.int64(32), T.int64(8)), scope="warp") + for i_j_fused_0 in T.thread_binding(T.int64(16384), thread="blockIdx.x"): + for i_j_fused_1 in T.thread_binding(T.int64(128), thread="threadIdx.x"): + for i_j_fused_2 in T.unroll(T.int64(2)): + for i_j_fused_3 in T.vectorized(T.int64(4)): + with T.block("decode"): + v_i = T.axis.spatial(T.int64(4096), (i_j_fused_0 * T.int64(1024) + i_j_fused_1 * T.int64(8) + i_j_fused_2 * T.int64(4) + i_j_fused_3) // T.int64(4096)) + v_j = T.axis.spatial(T.int64(4096), (i_j_fused_0 * T.int64(1024) + i_j_fused_1 * T.int64(8) + i_j_fused_2 * T.int64(4) + i_j_fused_3) % T.int64(4096)) + T.reads(data[v_i, v_j // T.int64(8)], scale[v_i, v_j // T.int64(32)]) + T.writes(B_intermediate[v_i, v_j]) + B_intermediate[v_i, v_j] = (T.Cast("float16", 
T.bitwise_and(T.shift_right(data[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * scale[v_i, v_j // T.int64(32)] + for ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused in T.thread_binding(b * T.int64(128), thread="blockIdx.x"): + for ax1_0_2_init in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0_2_init in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_0_3_init, ax2_0_3_init in T.grid(T.int64(4), T.int64(4)): + with T.block("NT_matmul_o_init"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(b * T.int64(32), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(8) + ax1_0_2_init * T.int64(4) + ax1_0_3_init) + v2_o = T.axis.spatial(T.int64(256), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(8) + ax2_0_2_init * T.int64(4) + ax2_0_3_init) + T.reads() + T.writes(O_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("NT_matmul_init_o"): + v1_i_init_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_init_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads() + T.writes(O_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + C_warp = T.match_buffer(O_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=1) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.mma_fill("float32", 8, C_warp.data, C_warp.elem_offset) + for ax3_0_0 in T.serial(T.int64(128), annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 3]}): + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("A_reindex_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(b * T.int64(512), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(32)) + v2 = T.axis.spatial(T.int64(4096), ax3_0_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(32)) + T.reads(A[v1 // T.int64(512), v1 % T.int64(512), v2]) + T.writes(A_reindex_shared_dyn[v0, v1, v2]) + T.block_attr({"permuted_layout": 1}) + A_reindex_shared_dyn[v0, v1, v2] = A[v1 // T.int64(512), v1 % T.int64(512), v2] + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("B_intermediate_reindex_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(4096), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * 
T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(32)) + v2 = T.axis.spatial(T.int64(4096), ax3_0_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(32)) + T.reads(B_intermediate[v1, v2]) + T.writes(B_intermediate_reindex_shared_dyn[v0, v1, v2]) + T.block_attr({"permuted_layout": 1}) + B_intermediate_reindex_shared_dyn[v0, v1, v2] = B_intermediate[v1, v2] + for ax1_0_2 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax3_0_1 in range(T.int64(2)): + for ax0_0, ax1_0 in T.grid(T.int64(4), T.int64(1)): + with T.block("A_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(T.int64(8) * (b * T.int64(4)), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(8) + ax1_0_2 * T.int64(4) + ax0_0) + v2_o = T.axis.spatial(T.int64(256), ax3_0_0 * T.int64(2) + ax3_0_1 + ax1_0) + T.reads(A_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(A_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.block_attr({"permuted_layout": 1}) + with T.block("A_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(A_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(A_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + warp = T.match_buffer(A_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + shared = T.match_buffer(A_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + T.int64(8) * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * T.int64(16), 1), shared.strides[0] * (tx % T.int64(16)) + T.int64(8) * (tx // T.int64(16))) + for ax0_0, ax1_0 in T.grid(T.int64(4), T.int64(1)): + with T.block("B_intermediate_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(T.int64(256), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(8) + ax2_0_2 * T.int64(4) + ax0_0) + v2_o = T.axis.spatial(T.int64(256), ax3_0_0 * T.int64(2) + ax3_0_1 + ax1_0) + T.reads(B_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(B_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.block_attr({"permuted_layout": 1}) + with T.block("B_intermediate_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(B_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) 
+ T.int64(16)]) + T.writes(B_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + warp = T.match_buffer(B_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + shared = T.match_buffer(B_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + T.int64(8) * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * T.int64(16), 1), shared.strides[0] * T.int64(8) * (tx // T.int64(16)) + shared.strides[0] * (tx % T.int64(8)) + T.int64(8) * (tx % T.int64(16) // T.int64(8))) + for ax1_0_3, ax2_0_3 in T.grid(T.int64(4), T.int64(4)): + with T.block("NT_matmul_o_update"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(b * T.int64(32), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(8) + ax1_0_2 * T.int64(4) + ax1_0_3) + v2_o = T.axis.spatial(T.int64(256), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(8) + ax2_0_2 * T.int64(4) + ax2_0_3) + v3_o = T.axis.reduce(T.int64(256), ax3_0_0 * T.int64(2) + ax3_0_1) + T.reads(O_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], A_reindex_shared_dyn_warp[T.int64(0), v1_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], B_intermediate_reindex_shared_dyn_warp[T.int64(0), v2_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(O_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("NT_matmul_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v3_i_o = T.axis.reduce(T.int64(1), T.int64(0)) + T.reads(O_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], A_reindex_shared_dyn_warp[T.int64(0), v1_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], B_intermediate_reindex_shared_dyn_warp[T.int64(0), v2_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(O_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + A_1 = T.match_buffer(A_reindex_shared_dyn_warp[T.int64(0), v1_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + B = T.match_buffer(B_intermediate_reindex_shared_dyn_warp[T.int64(0), v2_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + C = T.match_buffer(O_intermediate_reindex_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_1.data, A_1.elem_offset + tx * T.int64(8), B.data, B.elem_offset + tx * T.int64(8), C.data, C.elem_offset + tx * T.int64(8), T.bool(False)) + T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", 
"fp32", A_1.data, A_1.elem_offset + tx * T.int64(8), B.data, B.elem_offset + tx * T.int64(8) + T.int64(4), C.data, C.elem_offset + tx * T.int64(8) + T.int64(4), T.bool(False)) + for ax0 in range(T.int64(1)): + for ax1_0 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_1, ax2_1 in T.grid(T.int64(4), T.int64(4)): + with T.block("O_intermediate_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), ax0) + v1_o = T.axis.spatial(T.int64(8) * (b * T.int64(4)), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(8) + ax1_0 * T.int64(4) + ax1_1) + v2_o = T.axis.spatial(T.int64(256), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(8) + ax2_0 * T.int64(4) + ax2_1) + T.reads(O_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(O_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.block_attr({"permuted_layout": 1}) + with T.block("O_intermediate_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(O_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(O_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + C_warp = T.match_buffer(O_intermediate_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=1) + C = T.match_buffer(O_intermediate_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=1) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for local_id in range(T.int64(8)): + C[T.int64(8) * (local_id % T.int64(4) // T.int64(2)) + tx // T.int64(4), T.int64(8) * (local_id // T.int64(4)) + tx % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = C_warp[tx, local_id] + for ax0_ax1_fused_0 in range(T.int64(16)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("O_intermediate_reindex_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(b * T.int64(512), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(128)) + v2 = T.axis.spatial(T.int64(4096), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(128)) + T.reads(O_intermediate_reindex_shared_dyn[v0, v1, v2]) + T.writes(O[v1 // T.int64(512), v1 % T.int64(512), v2]) + T.block_attr({"permuted_layout": 1}) + if v1 // T.int64(512) < b: + O[v1 // T.int64(512), v1 % T.int64(512), v2] = T.Cast("float16", O_intermediate_reindex_shared_dyn[v0, v1, v2]) 
+ # fmt: on + + +class TestMatmulEpilogue(BaseBeforeAfter): + # fmt: off + @T.prim_func + def before( + B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), + p_A: T.handle, + p_add: T.handle, + p_add1: T.handle, + p_O: T.handle, + ): + T.func_attr({"tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(p_A, (n, T.int64(4096)), "float16") + add = T.match_buffer(p_add, (n, T.int64(4096)), "float16") + add1 = T.match_buffer(p_add1, (n, T.int64(4096)), "float16") + O = T.match_buffer(p_O, (n, T.int64(4096)), "float32") + O_intermediate = T.alloc_buffer((n, T.int64(4096))) + O_intermediate1 = T.alloc_buffer((n, T.int64(4096)), "float16") + O_intermediate2 = T.alloc_buffer((n, T.int64(4096)), "float16") + O_intermediate3 = T.alloc_buffer((n, T.int64(4096)), "float16") + for i0, i1, k in T.grid(n, T.int64(4096), T.int64(4096)): + with T.block("NT_matmul"): + v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k]) + with T.init(): + O_intermediate[v_i0, v_i1] = T.float32(0) + O_intermediate[v_i0, v_i1] = O_intermediate[v_i0, v_i1] + T.Cast("float32", A[v_i0, v_k]) * T.Cast("float32", B[v_i1, v_k]) + for i0, i1 in T.grid(n, T.int64(4096)): + with T.block("compute"): + v_i0, v_i1 = T.axis.remap("SS", [i0, i1]) + O_intermediate1[v_i0, v_i1] = T.Cast("float16", O_intermediate[v_i0, v_i1]) + for ax0, ax1 in T.grid(n, T.int64(4096)): + with T.block("T_add"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + O_intermediate2[v_ax0, v_ax1] = O_intermediate1[v_ax0, v_ax1] + add[v_ax0, v_ax1] + for ax0, ax1 in T.grid(n, T.int64(4096)): + with T.block("T_add_1"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + O_intermediate3[v_ax0, v_ax1] = add1[v_ax0, v_ax1] + O_intermediate2[v_ax0, v_ax1] + for ax0, ax1 in T.grid(n, T.int64(4096)): + with T.block("T_cast"): + v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1]) + O[v_ax0, v_ax1] = T.Cast("float32", O_intermediate3[v_ax0, v_ax1]) + + @T.prim_func + def expected(B: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), p_A: T.handle, p_add: T.handle, p_add1: T.handle, p_O: T.handle): + T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + n = T.int64() + A = T.match_buffer(p_A, (n, T.int64(4096)), "float16") + add = T.match_buffer(p_add, (n, T.int64(4096)), "float16") + add1 = T.match_buffer(p_add1, (n, T.int64(4096)), "float16") + O = T.match_buffer(p_O, (n, T.int64(4096))) + # with T.block("root"): + A_reindex_pad_shared_dyn = T.alloc_buffer((T.int64(1), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(4096)), "float16", scope="shared.dyn") + A_reindex_pad_shared_dyn_warp = T.alloc_buffer((T.int64(1), (n + T.int64(127)) // T.int64(128) * T.int64(8), T.int64(256), T.int64(32), T.int64(8)), "float16", scope="warp") + B_reindex_shared_dyn = T.alloc_buffer((T.int64(1), T.int64(4096), T.int64(4096)), "float16", scope="shared.dyn") + B_reindex_shared_dyn_warp = T.alloc_buffer((T.int64(1), T.int64(256), T.int64(256), T.int64(32), T.int64(8)), "float16", scope="warp") + O_intermediate_reindex_pad_shared_dyn = T.alloc_buffer((T.int64(1), (n + T.int64(127)) // T.int64(128) * T.int64(128), T.int64(4096)), scope="shared.dyn") + O_intermediate_reindex_pad_shared_dyn_warp = T.alloc_buffer((T.int64(1), (n + T.int64(127)) // T.int64(128) * T.int64(8), T.int64(256), T.int64(32), T.int64(8)), scope="warp") + for ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused in T.thread_binding((n + T.int64(127)) // T.int64(128) * T.int64(32), thread="blockIdx.x"): + for ax1_0_2_init in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0_2_init in 
T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_0_3_init, ax2_0_3_init in T.grid(T.int64(4), T.int64(4)): + with T.block("NT_matmul_o_init"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(8), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(8) + ax1_0_2_init * T.int64(4) + ax1_0_3_init) + v2_o = T.axis.spatial(T.int64(256), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(8) + ax2_0_2_init * T.int64(4) + ax2_0_3_init) + T.reads() + T.writes(O_intermediate_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("NT_matmul_init_o"): + v1_i_init_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_init_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads() + T.writes(O_intermediate_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + C_warp = T.match_buffer(O_intermediate_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=1) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.mma_fill("float32", 8, C_warp.data, C_warp.elem_offset) + for ax3_0_0 in T.serial(T.int64(128), annotations={"software_pipeline_async_stages": [0], "software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 3]}): + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("A_reindex_pad_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(32)) + v2 = T.axis.spatial(T.int64(4096), ax3_0_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(32)) + T.reads(A[v1, v2]) + T.writes(A_reindex_pad_shared_dyn[v0, v1, v2]) + T.block_attr({"permuted_layout": 1}) + A_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < n, A[v1, v2], T.float16(0)) + for ax0_ax1_fused_0 in range(T.int64(4)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("B_reindex_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial(T.int64(4096), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(32)) + v2 = T.axis.spatial(T.int64(4096), ax3_0_0 * T.int64(32) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(32)) + T.reads(B[v1, v2]) + 
T.writes(B_reindex_shared_dyn[v0, v1, v2]) + T.block_attr({"permuted_layout": 1}) + B_reindex_shared_dyn[v0, v1, v2] = B[v1, v2] + for ax1_0_2 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax3_0_1 in range(T.int64(2)): + for ax0_0, ax1_0 in T.grid(T.int64(4), T.int64(1)): + with T.block("A_reindex_pad_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(T.int64(8) * ((n + T.int64(127)) // T.int64(128)), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(8) + ax1_0_2 * T.int64(4) + ax0_0) + v2_o = T.axis.spatial(T.int64(256), ax3_0_0 * T.int64(2) + ax3_0_1 + ax1_0) + T.reads(A_reindex_pad_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(A_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.block_attr({"permuted_layout": 1}) + with T.block("A_reindex_pad_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(A_reindex_pad_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(A_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + warp = T.match_buffer(A_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + shared = T.match_buffer(A_reindex_pad_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + T.int64(8) * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * T.int64(16), 1), shared.strides[0] * (tx % T.int64(16)) + T.int64(8) * (tx // T.int64(16))) + for ax0_0, ax1_0 in T.grid(T.int64(4), T.int64(1)): + with T.block("B_reindex_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial(T.int64(256), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(8) + ax2_0_2 * T.int64(4) + ax0_0) + v2_o = T.axis.spatial(T.int64(256), ax3_0_0 * T.int64(2) + ax3_0_1 + ax1_0) + T.reads(B_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(B_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.block_attr({"permuted_layout": 1}) + with T.block("B_reindex_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(B_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.writes(B_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + warp = T.match_buffer(B_reindex_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + shared = T.match_buffer(B_reindex_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * 
T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), "float16", strides=("shared_s0", "shared_s1"), scope="shared.dyn", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_ldmatrix("float16", T.bool(False), 4, ".b16", warp.data, warp.elem_offset + T.int64(8) * tx, T.tvm_access_ptr(T.type_annotation("float16"), shared.data, shared.elem_offset, shared.strides[0] * T.int64(16), 1), shared.strides[0] * T.int64(8) * (tx // T.int64(16)) + shared.strides[0] * (tx % T.int64(8)) + T.int64(8) * (tx % T.int64(16) // T.int64(8))) + for ax1_0_3, ax2_0_3 in T.grid(T.int64(4), T.int64(4)): + with T.block("NT_matmul_o_update"): + v0_o = T.axis.spatial(T.int64(1), T.int64(0)) + v1_o = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(8), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(8) + ax1_0_2 * T.int64(4) + ax1_0_3) + v2_o = T.axis.spatial(T.int64(256), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(8) + ax2_0_2 * T.int64(4) + ax2_0_3) + v3_o = T.axis.reduce(T.int64(256), ax3_0_0 * T.int64(2) + ax3_0_1) + T.reads(O_intermediate_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], A_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], B_reindex_shared_dyn_warp[T.int64(0), v2_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(O_intermediate_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + with T.block("NT_matmul_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v3_i_o = T.axis.reduce(T.int64(1), T.int64(0)) + T.reads(O_intermediate_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], A_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], B_reindex_shared_dyn_warp[T.int64(0), v2_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(O_intermediate_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + A_1 = T.match_buffer(A_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + B_1 = T.match_buffer(B_reindex_shared_dyn_warp[T.int64(0), v2_o, v3_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), "float16", scope="warp", offset_factor=16) + C = T.match_buffer(O_intermediate_reindex_pad_shared_dyn_warp[T.int64(0), v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=16) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_1.data, A_1.elem_offset + tx * T.int64(8), B_1.data, B_1.elem_offset + tx * T.int64(8), C.data, C.elem_offset + tx * T.int64(8), T.bool(False)) + T.ptx_mma("float32", "m16n8k16", "row", "col", "fp16", "fp16", "fp32", A_1.data, A_1.elem_offset + tx * T.int64(8), B_1.data, B_1.elem_offset + tx * T.int64(8) + T.int64(4), C.data, C.elem_offset + tx * T.int64(8) + T.int64(4), T.bool(False)) + for ax0 in range(T.int64(1)): + for ax1_0 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax2_0 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax1_1, ax2_1 in T.grid(T.int64(4), 
T.int64(4)): + with T.block("O_intermediate_reindex_pad_shared.dyn_warp_o"): + v0_o = T.axis.spatial(T.int64(1), ax0) + v1_o = T.axis.spatial(T.int64(8) * ((n + T.int64(127)) // T.int64(128)), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(8) + ax1_0 * T.int64(4) + ax1_1) + v2_o = T.axis.spatial(T.int64(256), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(8) + ax2_0 * T.int64(4) + ax2_1) + T.reads(O_intermediate_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(O_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + T.block_attr({"permuted_layout": 1}) + with T.block("O_intermediate_reindex_pad_shared.dyn_warp_o"): + v1_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + v2_i_o = T.axis.spatial(T.int64(1), T.int64(0)) + T.reads(O_intermediate_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)]) + T.writes(O_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)]) + C_warp = T.match_buffer(O_intermediate_reindex_pad_shared_dyn_warp[v0_o, v1_o, v2_o, T.int64(0):T.int64(32), T.int64(0):T.int64(8)], (T.int64(32), T.int64(8)), scope="warp", offset_factor=1) + C = T.match_buffer(O_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * T.int64(16):v1_o * T.int64(16) + T.int64(16), v2_o * T.int64(16):v2_o * T.int64(16) + T.int64(16)], (T.int64(16), T.int64(16)), strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=1) + for tx in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for local_id in range(T.int64(8)): + C[T.int64(8) * (local_id % T.int64(4) // T.int64(2)) + tx // T.int64(4), T.int64(8) * (local_id // T.int64(4)) + tx % T.int64(4) * T.int64(2) + local_id % T.int64(2)] = C_warp[tx, local_id] + for ax0_ax1_fused_0 in range(T.int64(16)): + for ax0_ax1_fused_1 in T.thread_binding(T.int64(2), thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(T.int64(2), thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(T.int64(32), thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(T.int64(8)): + with T.block("O_intermediate_reindex_pad_shared.dyn"): + v0 = T.axis.spatial(T.int64(1), T.int64(0)) + v1 = T.axis.spatial((n + T.int64(127)) // T.int64(128) * T.int64(128), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // T.int64(32) * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) // T.int64(128)) + v2 = T.axis.spatial(T.int64(4096), ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % T.int64(32) * T.int64(128) + (ax0_ax1_fused_0 * T.int64(1024) + ax0_ax1_fused_1 * T.int64(512) + ax0_ax1_fused_2 * T.int64(256) + ax0_ax1_fused_3 * T.int64(8) + ax0_ax1_fused_4) % T.int64(128)) + T.reads(add1[v1, v2], O_intermediate_reindex_pad_shared_dyn[v0, v1, v2], add[v1, v2]) + T.writes(O[v1, v2]) + T.block_attr({"permuted_layout": 1}) + if v1 < n: + O[v1, v2] = T.Cast("float32", add1[v1, v2] + (T.Cast("float16", O_intermediate_reindex_pad_shared_dyn[v0, v1, v2]) + add[v1, v2])) + # fmt: on + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/dlight/test_gpu_matmul_tensorize_numeric.py b/tests/python/dlight/test_gpu_matmul_tensorize_numeric.py new file mode 100644 index 000000000000..c1bb555c6f12 --- /dev/null +++ 
b/tests/python/dlight/test_gpu_matmul_tensorize_numeric.py @@ -0,0 +1,219 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +from typing import Callable, List, Literal +import numpy as np +import pytest +from tvm.ir.module import IRModule + +import tvm.testing +from tvm import dlight as dl, tir +from tvm.script import tir as T +from tvm.target import Target + + +def do_numeric_test( + before: tir.PrimFunc, + input_np: List[np.ndarray], + compute: Callable[[List[np.ndarray]], np.ndarray], + target: Target, + dev: tvm.runtime.Device, + rule_name: Literal["mma", "wmma"] = "mma", + atol: float = 1e-3, + rtol: float = 1e-3, +): + before_mod = IRModule.from_expr(before.without_attr("global_symbol")) + rule = ( + dl.gpu.MatmulTensorizationMMA() if rule_name == "mma" else dl.gpu.MatmulTensorizationWMMA() + ) + with target: + after_mod = dl.ApplyDefaultSchedule(rule)(before_mod) + after = after_mod["main"] + # build + with tvm.transform.PassContext(config={"tir.use_async_copy": 1}): + ex = tvm.build(after, target=target) + + tvm_input = [tvm.nd.array(x, dev) for x in input_np] + ex(*tvm_input) + tvm_result_np = tvm_input[-1].numpy() + + np_result = compute(*input_np[:-1]) + assert np.allclose(np_result, tvm_result_np, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("rule_name", ["mma", "wmma"]) +def test_nt_matmul_mixed_precision(rule_name): + @T.prim_func + def before(p_A: T.handle, p_B: T.handle, p_O: T.handle): + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(256), T.int64(256)), "float16") + B = T.match_buffer(p_B, (T.int64(256), T.int64(256)), "float16") + O = T.match_buffer(p_O, (b, T.int64(256), T.int64(256)), "float16") + var_matmul_intermediate = T.alloc_buffer((b, T.int64(256), T.int64(256))) + for i0, i1, i2, k in T.grid(b, T.int64(256), T.int64(256), T.int64(256)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", A[v_i0, v_i1, v_k]) * T.Cast("float32", B[v_i2, v_k]) + for i0, i1, i2 in T.grid(b, T.int64(256), T.int64(256)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + O[v_i0, v_i1, v_i2] = T.Cast("float16", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + b = 2 + inputs = [ + np.random.normal(size=(b, 256, 256)).astype(np.float16), + np.random.normal(size=(256, 256)).astype(np.float16), + np.zeros((b, 256, 256), dtype=np.float16), + ] + np_compute = lambda x, y: np.matmul(x, y.T) + do_numeric_test(before, inputs, np_compute, Target("cuda"), tvm.cuda(), rule_name) + + +@pytest.mark.parametrize("rule_name", ["mma", "wmma"]) +def 
test_batched_nt_matmul_mixed_precision(rule_name): + @T.prim_func + def before(p_A: T.handle, p_B: T.handle, p_O: T.handle): + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(256), T.int64(256)), "float16") + B = T.match_buffer(p_B, (b, T.int64(256), T.int64(256)), "float16") + O = T.match_buffer(p_O, (b, T.int64(256), T.int64(256)), "float16") + var_matmul_intermediate = T.alloc_buffer((b, T.int64(256), T.int64(256))) + for i0, i1, i2, k in T.grid(b, T.int64(256), T.int64(256), T.int64(256)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", A[v_i0, v_i1, v_k]) * T.Cast("float32", B[v_i0, v_i2, v_k]) + for i0, i1, i2 in T.grid(b, T.int64(256), T.int64(256)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + O[v_i0, v_i1, v_i2] = T.Cast("float16", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + b = 2 + inputs = [ + np.random.normal(size=(b, 256, 256)).astype(np.float16), + np.random.normal(size=(b, 256, 256)).astype(np.float16), + np.zeros((b, 256, 256), dtype=np.float16), + ] + np_compute = lambda x, y: np.matmul(x, y.transpose(0, 2, 1)) + do_numeric_test(before, inputs, np_compute, Target("cuda"), tvm.cuda(), rule_name) + + +@pytest.mark.parametrize("rule_name", ["mma"]) +def test_nn_matmul_mixed_precision(rule_name: Literal['mma']): + @T.prim_func + def before(p_A: T.handle, p_B: T.handle, p_O: T.handle): + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(256), T.int64(256)), "float16") + B = T.match_buffer(p_B, (T.int64(256), T.int64(256)), "float16") + O = T.match_buffer(p_O, (b, T.int64(256), T.int64(256)), "float16") + var_matmul_intermediate = T.alloc_buffer((b, T.int64(256), T.int64(256))) + for i0, i1, i2, k in T.grid(b, T.int64(256), T.int64(256), T.int64(256)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", A[v_i0, v_i1, v_k]) * T.Cast("float32", B[v_k, v_i2]) + for i0, i1, i2 in T.grid(b, T.int64(256), T.int64(256)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + O[v_i0, v_i1, v_i2] = T.Cast("float16", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + b = 2 + inputs = [ + np.random.normal(size=(b, 256, 256)).astype(np.float16), + np.random.normal(size=(256, 256)).astype(np.float16), + np.zeros((b, 256, 256), dtype=np.float16), + ] + np_compute = lambda x, y: np.matmul(x, y) + do_numeric_test(before, inputs, np_compute, Target("cuda"), tvm.cuda(), rule_name) + + +@pytest.mark.parametrize("rule_name", ["mma"]) +def test_tn_matmul_mixed_precision(rule_name): + @T.prim_func + def before(p_A: T.handle, p_B: T.handle, p_O: T.handle): + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(256), T.int64(256)), "float16") + B = T.match_buffer(p_B, (T.int64(256), T.int64(256)), "float16") + O = T.match_buffer(p_O, (b, T.int64(256), T.int64(256)), "float16") + var_matmul_intermediate = T.alloc_buffer((b, T.int64(256), T.int64(256))) + for i0, i1, i2, k in T.grid(b, T.int64(256), T.int64(256), T.int64(256)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = 
T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", A[v_i0, v_k, v_i1]) * T.Cast("float32", B[v_k, v_i2]) + for i0, i1, i2 in T.grid(b, T.int64(256), T.int64(256)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + O[v_i0, v_i1, v_i2] = T.Cast("float16", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + b = 2 + inputs = [ + np.random.normal(size=(b, 256, 256)).astype(np.float16), + np.random.normal(size=(256, 256)).astype(np.float16), + np.zeros((b, 256, 256), dtype=np.float16), + ] + np_compute = lambda x, y: np.matmul(x.transpose(0, 2, 1), y) + do_numeric_test(before, inputs, np_compute, Target("cuda"), tvm.cuda(), rule_name) + + +@pytest.mark.parametrize("rule_name", ["mma"]) +def test_tt_matmul_mixed_precision(rule_name): + @T.prim_func + def before(p_A: T.handle, p_B: T.handle, p_O: T.handle): + b = T.int64() + A = T.match_buffer(p_A, (b, T.int64(256), T.int64(256)), "float16") + B = T.match_buffer(p_B, (T.int64(256), T.int64(256)), "float16") + O = T.match_buffer(p_O, (b, T.int64(256), T.int64(256)), "float16") + var_matmul_intermediate = T.alloc_buffer((b, T.int64(256), T.int64(256))) + for i0, i1, i2, k in T.grid(b, T.int64(256), T.int64(256), T.int64(256)): + with T.block("matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + with T.init(): + var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float32(0) + var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[ + v_i0, v_i1, v_i2 + ] + T.Cast("float32", A[v_i0, v_k, v_i1]) * T.Cast("float32", B[v_i2, v_k]) + for i0, i1, i2 in T.grid(b, T.int64(256), T.int64(256)): + with T.block("compute"): + v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2]) + O[v_i0, v_i1, v_i2] = T.Cast("float16", var_matmul_intermediate[v_i0, v_i1, v_i2]) + + b = 2 + inputs = [ + np.random.normal(size=(b, 256, 256)).astype(np.float16), + np.random.normal(size=(256, 256)).astype(np.float16), + np.zeros((b, 256, 256), dtype=np.float16), + ] + np_compute = lambda x, y: np.matmul(x.transpose(0, 2, 1), y.T) + do_numeric_test(before, inputs, np_compute, Target("cuda"), tvm.cuda(), rule_name) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/dlight/test_gpu_matmul_tensorize_wmma.py b/tests/python/dlight/test_gpu_matmul_tensorize_wmma.py new file mode 100644 index 000000000000..89a3a3832c30 --- /dev/null +++ b/tests/python/dlight/test_gpu_matmul_tensorize_wmma.py @@ -0,0 +1,718 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
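+"""Structural tests for WMMA tensorization in the dlight GPU Matmul rule.
+
+Each case pairs a `before` PrimFunc with the `expected` TVMScript obtained by
+applying dl.ApplyDefaultSchedule(dl.gpu.Matmul()) under an NVIDIA CUDA target
+(here "nvidia/geforce-rtx-2080-ti").
+"""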
+# pylint: disable=missing-docstring +import pytest + +import tvm.testing +from tvm import dlight as dl +from tvm.script import ir as I +from tvm.script import tir as T +from tvm.target import Target + + +class BaseBeforeAfter(tvm.testing.CompareBeforeAfter): + @pytest.fixture + def transform(self): + def transform(mod): + with Target("nvidia/geforce-rtx-2080-ti"): + return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod) + + return transform + + +# class TestMatmulTensorize(BaseBeforeAfter): +def test1(): + # fmt: off + @T.prim_func + def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): + T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) + # with T.block("root"): + for i, j, k in T.grid(256, 256, 256): + with T.block("compute"): + v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) + T.reads(X[v_i, v_k], W[v_j, v_k]) + T.writes(compute[v_i, v_j]) + with T.init(): + compute[v_i, v_j] = T.float16(0) + compute[v_i, v_j] = compute[v_i, v_j] + X[v_i, v_k] * W[v_j, v_k] + + @T.prim_func + def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), "float16"), compute: T.Buffer((256, 256), "float16")): + T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) + X_reindex_shared_dyn = T.alloc_buffer((2, 8, 8, 2, 16, 16), "float16", scope="shared.dyn") + X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((2, 8, 8, 2, 16, 16), "float16", scope="wmma.matrix_a") + W_reindex_shared_dyn = T.alloc_buffer((2, 8, 8, 2, 16, 16), "float16", scope="shared.dyn") + W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((2, 8, 8, 2, 16, 16), "float16", scope="wmma.matrix_b") + compute_reindex_shared_dyn = T.alloc_buffer((2, 2, 8, 8, 16, 16), "float16", scope="shared.dyn") + compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((2, 2, 8, 8, 16, 16), "float16", scope="wmma.accumulator") + for ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused in T.thread_binding(4, thread="blockIdx.x"): + for ax1_0_2_init in T.thread_binding(2, thread="threadIdx.z"): + for ax2_0_2_init in T.thread_binding(2, thread="threadIdx.y"): + for ax1_0_3_init, ax2_0_3_init in T.grid(4, 4): + with T.block("compute_o_init"): + v0_o = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(16, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // 2 * 8 + ax1_0_2_init * 4 + ax1_0_3_init) + v2_o = T.axis.spatial(16, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % 2 * 8 + ax2_0_2_init * 4 + ax2_0_3_init) + T.reads() + T.writes(compute_reindex_shared_dyn_wmma_accumulator[v1_o // 8, v2_o // 8, v1_o % 8, v2_o % 8, 0:16, 0:16]) + with T.block("compute_init_o"): + v1_i_init_o = T.axis.spatial(1, 0) + v2_i_init_o = T.axis.spatial(1, 0) + T.reads() + T.writes(compute_reindex_shared_dyn_wmma_accumulator[v1_o // 8, v2_o // 8, v1_o % 8, v2_o % 8, 0:16, 0:16]) + C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v1_o // 8, v2_o // 8, v1_o % 8, v2_o % 8, 0:16, 0:16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) + T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) + for ax3_0_0 in T.unroll(8): + for ax0_ax1_fused_0 in range(4): + for ax0_ax1_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(8): + with 
T.block("X_reindex_shared.dyn"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(256, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // 2 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 512 + ax0_ax1_fused_2 * 256 + ax0_ax1_fused_3 * 8 + ax0_ax1_fused_4) // 32) + v2 = T.axis.spatial(256, ax3_0_0 * 32 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 512 + ax0_ax1_fused_2 * 256 + ax0_ax1_fused_3 * 8 + ax0_ax1_fused_4) % 32) + T.reads(X[v1, v2]) + T.writes(X_reindex_shared_dyn[v1 // 128, v2 // 32, v1 % 128 // 16, v2 % 32 // 16, v1 % 16, v2 % 16]) + X_reindex_shared_dyn[v1 // 128, v2 // 32, v1 % 128 // 16, v2 % 32 // 16, v1 % 16, v2 % 16] = X[v1, v2] + for ax0_ax1_fused_0 in range(4): + for ax0_ax1_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(8): + with T.block("W_reindex_shared.dyn"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(256, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % 2 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 512 + ax0_ax1_fused_2 * 256 + ax0_ax1_fused_3 * 8 + ax0_ax1_fused_4) // 32) + v2 = T.axis.spatial(256, ax3_0_0 * 32 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 512 + ax0_ax1_fused_2 * 256 + ax0_ax1_fused_3 * 8 + ax0_ax1_fused_4) % 32) + T.reads(W[v1, v2]) + T.writes(W_reindex_shared_dyn[v1 // 128, v2 // 32, v1 % 128 // 16, v2 % 32 // 16, v1 % 16, v2 % 16]) + W_reindex_shared_dyn[v1 // 128, v2 // 32, v1 % 128 // 16, v2 % 32 // 16, v1 % 16, v2 % 16] = W[v1, v2] + for ax1_0_2 in T.thread_binding(2, thread="threadIdx.z"): + for ax2_0_2 in T.thread_binding(2, thread="threadIdx.y"): + for ax3_0_1 in range(2): + for ax0 in range(4): + with T.block("X_reindex_shared.dyn_wmma.matrix_a_o"): + v0_o = T.axis.spatial(2, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // 2) + v1_o = T.axis.spatial(8, ax3_0_0) + v2_o = T.axis.spatial(8, ax1_0_2 * 4 + ax0) + v3_o = T.axis.spatial(2, ax3_0_1) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(X_reindex_shared_dyn[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + A = T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) + C = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") + for ax0 in range(4): + with T.block("W_reindex_shared.dyn_wmma.matrix_b_o"): + v0_o = T.axis.spatial(2, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % 2) + v1_o = T.axis.spatial(8, ax3_0_0) + v2_o = T.axis.spatial(8, ax2_0_2 * 4 + ax0) + v3_o = T.axis.spatial(2, ax3_0_1) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(W_reindex_shared_dyn[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + A = T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) + C = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o, v2_o, 
v3_o, 0:16, 0:16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) + T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") + for ax1_0_3, ax2_0_3 in T.grid(4, 4): + with T.block("compute_o_update"): + v0_o = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(16, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // 2 * 8 + ax1_0_2 * 4 + ax1_0_3) + v2_o = T.axis.spatial(16, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % 2 * 8 + ax2_0_2 * 4 + ax2_0_3) + v3_o = T.axis.reduce(16, ax3_0_0 * 2 + ax3_0_1) + T.reads(compute_reindex_shared_dyn_wmma_accumulator[v1_o // 8, v2_o // 8, v1_o % 8, v2_o % 8, 0:16, 0:16], X_reindex_shared_dyn_wmma_matrix_a[v1_o * 16 // 128, v3_o * 16 // 32, (v1_o * 16 - 128 * (v1_o * 16 // 128)) // 16, (v3_o * 16 - 32 * (v3_o * 16 // 32)) // 16, 0:16, 0:16], W_reindex_shared_dyn_wmma_matrix_b[v2_o * 16 // 128, v3_o * 16 // 32, (v2_o * 16 - 128 * (v2_o * 16 // 128)) // 16, (v3_o * 16 - 32 * (v3_o * 16 // 32)) // 16, 0:16, 0:16]) + T.writes(compute_reindex_shared_dyn_wmma_accumulator[v1_o // 8, v2_o // 8, v1_o % 8, v2_o % 8, 0:16, 0:16]) + with T.block("compute_o"): + v1_i_o = T.axis.spatial(1, 0) + v2_i_o = T.axis.spatial(1, 0) + v3_i_o = T.axis.reduce(1, 0) + T.reads(compute_reindex_shared_dyn_wmma_accumulator[v1_o // 8, v2_o // 8, v1_o % 8, v2_o % 8, 0:16, 0:16], X_reindex_shared_dyn_wmma_matrix_a[v1_o // 8, v3_o // 2, v1_o % 8, v3_o % 2, 0:16, 0:16], W_reindex_shared_dyn_wmma_matrix_b[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, 0:16, 0:16]) + T.writes(compute_reindex_shared_dyn_wmma_accumulator[v1_o // 8, v2_o // 8, v1_o % 8, v2_o % 8, 0:16, 0:16]) + A = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v1_o // 8, v3_o // 2, v1_o % 8, v3_o % 2, 0:16, 0:16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) + B = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v2_o // 8, v3_o // 2, v2_o % 8, v3_o % 2, 0:16, 0:16], (16, 16), "float16", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) + C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v1_o // 8, v2_o // 8, v1_o % 8, v2_o % 8, 0:16, 0:16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) + T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) + for ax0, ax1 in T.grid(1, 1): + for ax2_0 in T.thread_binding(2, thread="threadIdx.z"): + for ax3_0 in T.thread_binding(2, thread="threadIdx.y"): + for ax2_1, ax3_1 in T.grid(4, 4): + with T.block("compute_reindex_shared.dyn_wmma.accumulator_o"): + v0_o = T.axis.spatial(2, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // 2 + ax0) + v1_o = T.axis.spatial(2, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % 2 + ax1) + v2_o = T.axis.spatial(8, ax2_0 * 4 + ax2_1) + v3_o = T.axis.spatial(8, ax3_0 * 4 + ax3_1) + v4_o = T.axis.spatial(1, 0) + v5_o = T.axis.spatial(1, 0) + T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + 
T.writes(compute_reindex_shared_dyn[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16]) + A = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) + C = T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o, v2_o, v3_o, 0:16, 0:16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) + T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") + for ax0_ax1_fused_0 in range(16): + for ax0_ax1_fused_1 in T.thread_binding(2, thread="threadIdx.z"): + for ax0_ax1_fused_2 in T.thread_binding(2, thread="threadIdx.y"): + for ax0_ax1_fused_3 in T.thread_binding(32, thread="threadIdx.x"): + for ax0_ax1_fused_4 in T.vectorized(8): + with T.block("compute_reindex_shared.dyn"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(256, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused // 2 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 512 + ax0_ax1_fused_2 * 256 + ax0_ax1_fused_3 * 8 + ax0_ax1_fused_4) // 128) + v2 = T.axis.spatial(256, ax0_ax1_0_0_ax2_0_0_ax1_0_1_ax2_0_1_fused % 2 * 128 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 512 + ax0_ax1_fused_2 * 256 + ax0_ax1_fused_3 * 8 + ax0_ax1_fused_4) % 128) + T.reads(compute_reindex_shared_dyn[v1 // 128, v2 // 128, v1 % 128 // 16, v2 % 128 // 16, v1 % 16, v2 % 16]) + T.writes(compute[v1, v2]) + compute[v1, v2] = compute_reindex_shared_dyn[v1 // 128, v2 // 128, v1 % 128 // 16, v2 % 128 // 16, v1 % 16, v2 % 16] + # fmt: on + before_mod = tvm.IRModule.from_expr(before.without_attr("global_symbol")) + with Target("nvidia/geforce-rtx-2080-ti"): + after_mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(before_mod) + after_mod.show() + +test1() +exit() + +# class TestMatmulTensorizeTooSmall(BaseBeforeAfter): +# # fmt: off + +# @T.prim_func +# def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.handle): +# T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) +# m = T.int32() +# X = T.match_buffer(var_X, (m, 256), "float16") +# compute = T.match_buffer(var_compute, (m, 15)) +# # with T.block("root"): +# for i, j, k in T.grid(m, 15, 256): +# with T.block("compute"): +# v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k]) +# T.reads(X[v_i, v_k], W[v_j, v_k]) +# T.writes(compute[v_i, v_j]) +# with T.init(): +# compute[v_i, v_j] = T.float32(0) +# compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("float32", X[v_i, v_k]) * T.Cast("float32", W[v_j, v_k]) + +# @T.prim_func +# def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.handle): +# T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) +# m = T.int32() +# X = T.match_buffer(var_X, (m, 256), "float16") +# compute = T.match_buffer(var_compute, (m, 15)) +# # with T.block("root"): +# compute_reindex_pad_local = T.alloc_buffer((1, (m + 31) // 32 * 32, 64), scope="local") +# X_reindex_pad_shared = T.alloc_buffer((1, (m + 31) // 32 * 32, 256), "float16", scope="shared") +# W_reindex_pad_shared = T.alloc_buffer((1, 64, 256), "float16", scope="shared") +# for ax0_ax2_0_fused in T.thread_binding(1, thread="blockIdx.y"): +# for ax1_0 in T.thread_binding((m + 31) // 32, thread="blockIdx.x"): +# for ax2_1 in T.thread_binding(1, thread="vthread.y"): +# for ax1_1 in T.thread_binding(1, 
thread="vthread.x"): +# for ax2_2 in T.thread_binding(16, thread="threadIdx.y"): +# for ax1_2 in T.thread_binding(8, thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): +# for ax2_3_init, ax1_3_0_init in T.grid(4, 2): +# for ax1_3_1_init in T.vectorized(2): +# with T.block("compute_init"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0_init * 2 + ax1_3_1_init) +# v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3_init) +# T.reads() +# T.writes(compute_reindex_pad_local[0, v1, v2]) +# compute_reindex_pad_local[0, v1, v2] = T.float32(0) +# for ax3_0 in range(16): +# for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): +# for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): +# for ax0_ax1_ax2_fused_2 in range(2): +# for ax0_ax1_ax2_fused_3 in T.vectorized(2): +# with T.block("X_reindex_pad_shared"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) +# v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) +# T.reads(X[v1, v2]) +# T.writes(X_reindex_pad_shared[v0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) +# X_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, X[v1, v2], T.float16(0)) +# for ax0_ax1_ax2_fused_0 in T.thread_binding(16, thread="threadIdx.y"): +# for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.x"): +# for ax0_ax1_ax2_fused_2 in range(4): +# for ax0_ax1_ax2_fused_3 in T.vectorized(2): +# with T.block("W_reindex_pad_shared"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial(64, (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16) +# v2 = T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16) +# T.reads(W[v1, v2]) +# T.writes(W_reindex_pad_shared[v0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]}) +# W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 15, W[v1, v2], T.float16(0)) +# for ax3_1, ax2_3, ax1_3_0 in T.grid(16, 4, 2): +# for ax1_3_1 in T.vectorized(2): +# with T.block("compute_update"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3_0 * 2 + ax1_3_1) +# v2 = T.axis.spatial(64, ax2_1 * 64 + ax2_2 * 4 + ax2_3) +# v3 = T.axis.reduce(256, ax3_0 * 16 + ax3_1) +# T.reads(compute_reindex_pad_local[0, v1, v2], X_reindex_pad_shared[0, v1, v3], W_reindex_pad_shared[0, v2, v3]) +# T.writes(compute_reindex_pad_local[0, v1, v2]) +# compute_reindex_pad_local[0, v1, v2] = compute_reindex_pad_local[0, v1, v2] + T.Cast("float32", X_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", W_reindex_pad_shared[0, v2, v3]) +# for ax0, ax1, ax2_0 in T.grid(1, 4, 2): +# for ax2_1_1 in T.vectorized(2): +# with T.block("compute_reindex_pad_local"): +# v0 = T.axis.spatial(1, ax0) +# v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1) +# v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1) +# T.reads(compute_reindex_pad_local[v0, v1, v2]) +# T.writes(compute[v1, v2]) +# if v1 < m and v2 < 15: +# compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2] +# # fmt: on + + +# class TestMatmulTensorizeEpilogue(BaseBeforeAfter): +# # 
fmt: off + +# @T.prim_func +# def before(lv686: T.Buffer((T.int32(4096), T.int32(256)), "uint32"), lv687: T.Buffer((T.int32(4096), T.int32(64)), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle): +# T.func_attr({"tir.noalias": T.bool(True)}) +# n = T.int32() +# lv42 = T.match_buffer(p_lv42, (T.int32(1), n, T.int32(2048)), "float16") +# lv3 = T.match_buffer(p_lv3, (T.int32(1), n, T.int32(4096)), "float16") +# p_output0_intermediate = T.match_buffer(p_output0, (T.int32(1), n, T.int32(4096)), "float16") +# # with T.block("root"): +# p_output0_intermediate_1 = T.alloc_buffer((T.int32(4096), T.int32(2048)), "float16") +# var_NT_matmul_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16") +# var_T_divide_intermediate = T.alloc_buffer((T.int32(1), n, T.int32(4096)), "float16") +# for i, j in T.grid(T.int32(4096), T.int32(2048)): +# with T.block("decode"): +# v_i, v_j = T.axis.remap("SS", [i, j]) +# T.reads(lv686[v_i, v_j // T.int32(8)], lv687[v_i, v_j // T.int32(32)]) +# T.writes(p_output0_intermediate_1[v_i, v_j]) +# p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v_i, v_j // T.int32(8)], T.Cast("uint32", v_j % T.int32(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v_i, v_j // T.int32(32)] +# for i0, i1, i2, k in T.grid(T.int32(1), n, T.int32(4096), T.int32(2048)): +# with T.block("NT_matmul"): +# v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) +# T.reads(lv42[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) +# T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) +# with T.init(): +# var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) +# var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv42[v_i0, v_i1, v_k] * p_output0_intermediate_1[v_i2, v_k] +# for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)): +# with T.block("T_divide"): +# v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) +# T.reads(lv3[v_ax0, v_ax1, v_ax2]) +# T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2]) +# var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] * T.float16(0.5) +# for ax0, ax1, ax2 in T.grid(T.int32(1), n, T.int32(4096)): +# with T.block("T_add"): +# v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2]) +# T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]) +# T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2]) +# p_output0_intermediate[v_ax0, v_ax1, v_ax2] = var_T_divide_intermediate[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2] + +# @T.prim_func +# def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64), "float16"), p_lv42: T.handle, p_lv3: T.handle, p_output0: T.handle): +# T.func_attr({"global_symbol": "fused_fused_decode3_fused_NT_matmul6_divide1_add1", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) +# n = T.int32() +# lv42 = T.match_buffer(p_lv42, (1, n, 2048), "float16") +# lv3 = T.match_buffer(p_lv3, (1, n, 4096), "float16") +# p_output0_intermediate = T.match_buffer(p_output0, (1, n, 4096), "float16") +# # with T.block("root"): +# lv42_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="shared.dyn") +# p_output0_intermediate_1_reindex_shared_dyn = T.alloc_buffer((1, 4096, 2048), "float16", scope="shared.dyn") +# lv42_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (n + 127) // 128 * 128, 2048), "float16", scope="wmma.matrix_a") +# 
p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 2048), "float16", scope="wmma.matrix_b") +# var_NT_matmul_intermediate_reindex_pad_shared_dyn = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="shared.dyn") +# var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1, (n + 127) // 128 * 128, 4096), "float16", scope="wmma.accumulator") +# for ax0 in T.thread_binding(1, thread="blockIdx.z"): +# for ax1_0_0_ax2_0_0_fused in T.thread_binding((n + 127) // 128, thread="blockIdx.x"): +# for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"): +# for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): +# for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): +# with T.block("NT_matmul_o_init"): +# v0_o = T.axis.spatial(1, ax0) +# v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) +# v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) +# T.reads() +# T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# with T.block("NT_matmul_init_o"): +# v1_i_init_o = T.axis.spatial(1, 0) +# v2_i_init_o = T.axis.spatial(1, 0) +# T.reads() +# T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) +# T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) +# for ax3_0_0 in range(32, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): +# for ax0_ax1_fused_0 in range(4): +# for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): +# for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): +# for ax0_ax1_fused_3 in T.vectorized(4): +# with T.block("lv42_reindex_pad_shared.dyn"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) +# v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) +# T.reads(lv42[v0, v1, v2]) +# T.writes(lv42_reindex_pad_shared_dyn[v0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) +# lv42_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < n, lv42[v0, v1, v2], T.float16(0)) +# for ax0_ax1_fused_0 in range(4): +# for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): +# for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): +# for ax0_ax1_fused_3 in T.vectorized(4): +# with T.block("p_output0_intermediate_1_reindex_shared.dyn"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 64) +# v2 = T.axis.spatial(2048, ax3_0_0 * 64 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 64) +# T.reads(lv686[v1, v2 // 
8], lv687[v1, v2 // 32]) +# T.writes(p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) +# p_output0_intermediate_1_reindex_shared_dyn[v0, v1, v2] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv686[v1, v2 // 8], T.Cast("uint32", v2 % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv687[v1, v2 // 32] +# for ax3_0_1 in range(4, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): +# for ax0_0 in T.unroll(2): +# for ax1_0 in T.unroll(1): +# with T.block("lv42_reindex_pad_shared.dyn_wmma.matrix_a_o"): +# v0_o = T.axis.spatial(1, 0) +# v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) +# v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0) +# T.reads(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# T.writes(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A = T.match_buffer(lv42_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) +# C = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) +# T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") +# for ax0_0 in T.unroll(2): +# for ax1_0 in T.unroll(1): +# with T.block("p_output0_intermediate_1_reindex_shared.dyn_wmma.matrix_b_o"): +# v0_o = T.axis.spatial(1, 0) +# v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) +# v2_o = T.axis.spatial(128, ax3_0_0 * 4 + ax3_0_1 + ax1_0) +# T.reads(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# T.writes(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) +# C = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) +# T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") +# for ax1_0_3, ax2_0_3 in T.grid(2, 2): +# with T.block("NT_matmul_o_update"): +# v0_o = T.axis.spatial(1, ax0) +# v1_o = T.axis.spatial((n + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) +# v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) +# v3_o = T.axis.reduce(128, ax3_0_0 * 4 + ax3_0_1) +# T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], 
lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) +# T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# with T.block("NT_matmul_o"): +# v1_i_o = T.axis.spatial(1, 0) +# v2_i_o = T.axis.spatial(1, 0) +# v3_i_o = T.axis.reduce(1, 0) +# T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) +# T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A = T.match_buffer(lv42_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) +# B = T.match_buffer(p_output0_intermediate_1_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) +# C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) +# T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) +# for ax0_0, ax1_0 in T.grid(2, 2): +# with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn_wmma.accumulator_o"): +# v0_o = T.axis.spatial(1, 0) +# v1_o = T.axis.spatial(8 * ((n + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) +# v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) +# T.reads(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# T.writes(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) +# C = T.match_buffer(var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) +# T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") +# for ax0_ax1_fused_0 in range(8): +# for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): +# for ax0_ax1_fused_2 in 
T.vectorized(4): +# with T.block("var_NT_matmul_intermediate_reindex_pad_shared.dyn"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) +# v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) +# T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]) +# T.writes(p_output0_intermediate[0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) +# if v1 < n: +# p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2] +# # fmt: on + + +# class TestMatmulInt8Tensorize(BaseBeforeAfter): +# # fmt: off +# @T.prim_func +# def before(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): +# T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)}) +# # with T.block("root"): +# for i, j, r in T.grid(256, 256, 256): +# with T.block("compute"): +# v_i, v_j, v_k = T.axis.remap("SSR", [i, j, r]) +# T.reads(X[v_i, v_k], W[v_j, v_k]) +# T.writes(compute[v_i, v_j]) +# with T.init(): +# compute[v_i, v_j] = 0 +# compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("int32", X[v_i, v_k]) * T.Cast("int32", W[v_j, v_k]) + +# @T.prim_func +# def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256), "int8"), compute: T.Buffer((256, 256), "int32")): +# T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) +# # with T.block("root"): +# X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") +# W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8", scope="shared.dyn") +# X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), "int8", scope="wmma.matrix_a") +# W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), "int8", scope="wmma.matrix_b") +# compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int32", scope="shared.dyn") +# compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256, 256), "int32", scope="wmma.accumulator") +# for ax0 in T.thread_binding(1, thread="blockIdx.z"): +# for ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.x"): +# for ax1_0_1_ax2_0_1_fused in T.thread_binding(2, thread="blockIdx.y"): +# for ax2_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): +# for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): +# with T.block("compute_o_init"): +# v0_o = T.axis.spatial(1, ax0) +# v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) +# v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) +# T.reads() +# T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# with T.block("compute_init_o"): +# v1_i_init_o = T.axis.spatial(1, 0) +# v2_i_init_o = T.axis.spatial(1, 0) +# T.reads() +# T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) +# T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * 
(C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) +# for ax3_0_0 in T.serial(16, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): +# for ax0_ax1_fused_0 in range(1): +# for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): +# for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): +# for ax0_ax1_fused_3 in T.vectorized(4): +# with T.block("X_reindex_shared.dyn"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) +# v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) +# T.reads(X[v1, v2]) +# T.writes(X_reindex_shared_dyn[v0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) +# X_reindex_shared_dyn[v0, v1, v2] = X[v1, v2] +# for ax0_ax1_fused_0 in range(1): +# for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): +# for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): +# for ax0_ax1_fused_3 in T.vectorized(4): +# with T.block("W_reindex_shared.dyn"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) +# v2 = T.axis.spatial(256, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) +# T.reads(W[v1, v2]) +# T.writes(W_reindex_shared_dyn[v0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) +# W_reindex_shared_dyn[v0, v1, v2] = W[v1, v2] +# for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): +# for ax0_0 in T.unroll(2): +# for ax1_0 in T.unroll(1): +# with T.block("X_reindex_shared.dyn_wmma.matrix_a_o"): +# v0_o = T.axis.spatial(1, 0) +# v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) +# v2_o = T.axis.spatial(16, ax3_0_0 + ax1_0) +# T.reads(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A = T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) +# C = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) +# T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major") +# for ax0_0 in T.unroll(2): +# for ax1_0 in T.unroll(1): +# with T.block("W_reindex_shared.dyn_wmma.matrix_b_o"): +# v0_o = T.axis.spatial(1, 0) +# v1_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) +# v2_o = T.axis.spatial(16, ax3_0_0 + ax1_0) +# T.reads(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, 
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A = T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) +# C = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) +# T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major") +# for ax1_0_3, ax2_0_3 in T.grid(2, 2): +# with T.block("compute_o_update"): +# v0_o = T.axis.spatial(1, ax0) +# v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) +# v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) +# v3_o = T.axis.reduce(16, ax3_0_0 + ax3_0_1) +# T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) +# T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# with T.block("compute_o"): +# v1_i_o = T.axis.spatial(1, 0) +# v2_i_o = T.axis.spatial(1, 0) +# v3_i_o = T.axis.reduce(1, 0) +# T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) +# T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A = T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) +# B = T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) +# C = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) +# T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) +# for ax0_0, ax1_0 in T.grid(2, 2): +# with T.block("compute_reindex_shared.dyn_wmma.accumulator_o"): +# v0_o = T.axis.spatial(1, 0) +# v1_o = T.axis.spatial(16, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) +# v2_o = T.axis.spatial(16, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) +# T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# T.writes(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16]) +# A = T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) +# C = T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) +# T.tvm_store_matrix_sync(A.data, 16, 16, 16, A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") +# for ax0_ax1_fused_0 in range(8): +# for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): +# for ax0_ax1_fused_2 in T.vectorized(4): +# with T.block("compute_reindex_shared.dyn"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) +# v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) +# T.reads(compute_reindex_shared_dyn[v0, v1, v2]) +# T.writes(compute[v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) +# compute[v1, v2] = compute_reindex_shared_dyn[v0, v1, v2] +# # fmt: on + + +# class TestMatmulInt8Tensorize3d2dDyn(BaseBeforeAfter): +# # fmt: off +# @T.prim_func +# def before(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.handle): +# T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)}) +# m = T.int32() +# A = T.match_buffer(var_A, (1, m, 22016), "int8") +# matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") +# # with T.block("root"): +# for i0, i1, i2, k in T.grid(1, m, 4096, 22016): +# with T.block("matmul"): +# v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) +# T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k]) +# T.writes(matmul_1[v_i0, v_i1, v_i2]) +# with T.init(): +# matmul_1[v_i0, v_i1, v_i2] = 0 +# matmul_1[v_i0, v_i1, v_i2] = matmul_1[v_i0, v_i1, v_i2] + T.Cast("int32", A[v_i0, v_i1, v_k]) * T.Cast("int32", B[v_i2, v_k]) + +# @T.prim_func +# def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.handle): +# T.func_attr({"op_pattern": 4, "tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) +# m = T.int32() +# A = T.match_buffer(var_A, (1, m, 22016), "int8") +# matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32") +# # with T.block("root"): +# A_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 22016), "int8", scope="shared.dyn") +# B_reindex_shared_dyn = T.alloc_buffer((1, 4096, 22016), "int8", scope="shared.dyn") +# A_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (m + 127) // 128 * 128, 22016), "int8", scope="wmma.matrix_a") +# B_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 22016), "int8", scope="wmma.matrix_b") +# matmul_1_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128, 4096), "int32", scope="shared.dyn") +# matmul_1_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1, (m + 127) // 128 * 128, 4096), "int32", scope="wmma.accumulator") +# for ax0 in T.thread_binding(1, thread="blockIdx.z"): +# for ax1_0_0_ax2_0_0_fused in T.thread_binding((m + 127) // 128, thread="blockIdx.x"): +# for ax1_0_1_ax2_0_1_fused in T.thread_binding(32, thread="blockIdx.y"): +# for ax2_0_2_ax1_0_2_fused in 
T.thread_binding(16, thread="threadIdx.y"): +# for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2): +# with T.block("matmul_o_init"): +# v0_o = T.axis.spatial(1, ax0) +# v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init) +# v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init) +# T.reads() +# T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# with T.block("matmul_init_o"): +# v1_i_init_o = T.axis.spatial(1, 0) +# v2_i_init_o = T.axis.spatial(1, 0) +# T.reads() +# T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# C = T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) +# T.tvm_fill_fragment(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.float32(0)) +# for ax3_0_0 in T.serial(1376, annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}): +# for ax0_ax1_fused_0 in range(1): +# for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): +# for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): +# for ax0_ax1_fused_3 in T.vectorized(4): +# with T.block("A_reindex_pad_shared.dyn"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) +# v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) +# T.reads(A[v0, v1, v2]) +# T.writes(A_reindex_pad_shared_dyn[v0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) +# A_reindex_pad_shared_dyn[v0, v1, v2] = T.if_then_else(v1 < m, A[v0, v1, v2], T.int8(0)) +# for ax0_ax1_fused_0 in range(1): +# for ax0_ax1_fused_1 in T.thread_binding(16, thread="threadIdx.y"): +# for ax0_ax1_fused_2 in T.thread_binding(32, thread="threadIdx.x"): +# for ax0_ax1_fused_3 in T.vectorized(4): +# with T.block("B_reindex_shared.dyn"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16) +# v2 = T.axis.spatial(22016, ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16) +# T.reads(B[v1, v2]) +# T.writes(B_reindex_shared_dyn[v0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0, "tir.manifest_shared_memory_local_stage": 1}) +# B_reindex_shared_dyn[v0, v1, v2] = B[v1, v2] +# for ax3_0_1 in T.serial(1, annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage": [0, 0, 1]}): +# for ax0_0 in T.unroll(2): +# for ax1_0 in T.unroll(1): +# with T.block("A_reindex_pad_shared.dyn_wmma.matrix_a_o"): +# v0_o = T.axis.spatial(1, 0) +# v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) +# v2_o = T.axis.spatial(1376, ax3_0_0 + ax1_0) +# T.reads(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# 
T.writes(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A_1 = T.match_buffer(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) +# C = T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_a", offset_factor=16) +# T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "row_major") +# for ax0_0 in T.unroll(2): +# for ax1_0 in T.unroll(1): +# with T.block("B_reindex_shared.dyn_wmma.matrix_b_o"): +# v0_o = T.axis.spatial(1, 0) +# v1_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0) +# v2_o = T.axis.spatial(1376, ax3_0_0 + ax1_0) +# T.reads(B_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# T.writes(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A_1 = T.match_buffer(B_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="shared.dyn", offset_factor=16) +# C = T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"), scope="wmma.matrix_b", offset_factor=16) +# T.tvm_load_matrix_sync(C.data, 16, 16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"), A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "col_major") +# for ax1_0_3, ax2_0_3 in T.grid(2, 2): +# with T.block("matmul_o_update"): +# v0_o = T.axis.spatial(1, ax0) +# v1_o = T.axis.spatial((m + 127) // 128 * 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3) +# v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3) +# v3_o = T.axis.reduce(1376, ax3_0_0 + ax3_0_1) +# T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) +# T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# with T.block("matmul_o"): +# v1_i_o = T.axis.spatial(1, 0) +# v2_i_o = T.axis.spatial(1, 0) +# v3_i_o = T.axis.reduce(1, 0) +# T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16]) +# T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A_1 = T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"), scope="wmma.matrix_a", offset_factor=16) +# B_1 = T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 
16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"), scope="wmma.matrix_b", offset_factor=16) +# C = T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16) +# T.tvm_mma_sync(C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16, A_1.data, A_1.elem_offset // A_1.strides[0] // 16 * (A_1.strides[0] // 16) + A_1.elem_offset % A_1.strides[0] // 16, B_1.data, B_1.elem_offset // B_1.strides[0] // 16 * (B_1.strides[0] // 16) + B_1.elem_offset % B_1.strides[0] // 16, C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % C.strides[0] // 16) +# for ax0_0, ax1_0 in T.grid(2, 2): +# with T.block("matmul_1_reindex_pad_shared.dyn_wmma.accumulator_o"): +# v0_o = T.axis.spatial(1, 0) +# v1_o = T.axis.spatial(8 * ((m + 127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0) +# v2_o = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0) +# T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# T.writes(matmul_1_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16]) +# A_1 = T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16) +# C = T.match_buffer(matmul_1_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"), scope="shared.dyn", offset_factor=16) +# T.tvm_store_matrix_sync(A_1.data, 16, 16, 16, A_1.elem_offset // A_1.strides[0] // 16 * (A_1.strides[0] // 16) + A_1.elem_offset % A_1.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major") +# for ax0_ax1_fused_0 in range(8): +# for ax0_ax1_fused_1 in T.thread_binding(32, thread="threadIdx.x"): +# for ax0_ax1_fused_2 in T.vectorized(4): +# with T.block("matmul_1_reindex_pad_shared.dyn"): +# v0 = T.axis.spatial(1, 0) +# v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32) +# v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32) +# T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2]) +# T.writes(matmul_1[0, v1, v2]) +# T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]}) +# if v1 < m: +# matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2] +# # fmt: on + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-base/test_tir_nodes.py b/tests/python/tir-base/test_tir_nodes.py index 49816778f11f..08ff000eb276 100644 --- a/tests/python/tir-base/test_tir_nodes.py +++ b/tests/python/tir-base/test_tir_nodes.py @@ -19,6 +19,8 @@ import tvm from tvm import ir, te +import tvm.testing + def test_const(): x = tvm.tir.const(1, "int32") @@ -364,6 +366,14 @@ def test_vars(): assert isinstance(ptype.element_type, tvm.ir.PrimType) +def test_size_vars(): + x = tvm.tir.SizeVar("x", "int32") + assert x.dtype == "int32" + x = tvm.tir.SizeVar("x", "int64", 1) + assert x.dtype == "int64" + assert x.min_value == 1 + + def 
test_scoped_storage_vars(): dtype = "float" storage_scope = "global.texture" diff --git a/tests/python/tir-transform/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py b/tests/python/tir-transform/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py new file mode 100644 index 000000000000..d704dc243891 --- /dev/null +++ b/tests/python/tir-transform/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py @@ -0,0 +1,341 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring +import numpy as np +import pytest +import tvm +import tvm.testing +from tvm import te +from tvm.testing.tir import mma_schedule +from tvm.tir.tensor_intrin.cuda import ( + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_INTRIN, + LDMATRIX_f16_B_TRANS_INTRIN, + LDMATRIX_i8_A_INTRIN, + LDMATRIX_i8_B_TRANS_INTRIN, + LDMATRIX_i8_B_INTRIN, + MMA_f16f16f16_INTRIN, + MMA_f16f16f16_TRANS_B_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_f16f16f32_TRANS_B_INTRIN, + MMA_fill_16x16_f16_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_fill_16x16_i32_INTRIN, + MMA_i8i8i32_INTRIN, + MMA_i8i8i32_TRANS_B_INTRIN, + MMA_store_16x16_f16_global_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + MMA_store_16x16_i32_global_INTRIN, + shared_16x16_to_ldmatrix_32x8_layout, + shared_16x32_to_ldmatrix_32x16_layout, + shared_32x16_to_ldmatrix_32x16_layout, +) + +M = 4096 +N = 4096 +K = 4096 +measure_perf = False +gflops = (N * M * K) * 2 / 1e9 + + +def matmul(m, n, k, in_dtype, out_dtype, b_transposed): + b_shape = (n, k) if b_transposed else (k, n) + a = te.placeholder((m, k), name="A", dtype=in_dtype) + b = te.placeholder(b_shape, name="B", dtype=in_dtype) + k = te.reduce_axis((0, k), name="k") + + def maybe_cast(v): + if in_dtype != out_dtype: + return tvm.tir.Cast(out_dtype, v) + return v + + def maybe_swap(i, j): + if b_transposed: + return j, i + return i, j + + c = te.compute( + (m, n), + lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]), + name="C", + ) + return (a, b, c) + + +def run_test( + k_inner, + in_dtype, + out_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, +): + sch = mma_schedule( + te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)), + k_inner, + in_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, + ) + + f = tvm.build(sch.mod["main"], target="cuda", name="dense") + + dev = tvm.device("cuda", 0) + + if in_dtype == "float16": + a_np = np.random.normal(size=(M, K)).astype("float16") + + if b_transposed: + b_np = np.random.normal(size=(N, 
K)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype( + out_dtype + ) + else: + b_np = np.random.normal(size=(K, N)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype) + else: + a_np = np.random.randint(-128, 128, (M, K)).astype("int8") + + if b_transposed: + b_np = np.random.randint(-128, 128, (N, K)).astype("int8") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype( + "int32" + ) + else: + b_np = np.random.randint(-128, 128, (K, N)).astype("int8") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32") + + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev) + + f(a, b, c) + + if out_dtype != "float16": + # The numpy reference is computed with fp32 precision (otherwise too slow). + # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation. + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2) + + return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c) + + +@tvm.testing.requires_cuda_compute_version(8) +def test_f16f16f32_m16n16k16(): + def index_map(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 16 + in_dtype = "float16" + out_dtype = "float32" + i_factors, j_factors, k_factors = [4, 8, 2, 4, 1], [1, 64, 2, 1, 2], [128, 2, 1] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_INTRIN, + MMA_f16f16f32_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + ) + + if measure_perf and timer: + print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_TRANS_INTRIN, + MMA_f16f16f32_TRANS_B_INTRIN, + MMA_fill_16x16_f32_INTRIN, + MMA_store_16x16_f32_global_INTRIN, + ) + + if measure_perf and timer: + print("f16f16f32_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) + + +@tvm.testing.requires_cuda_compute_version(8) +def test_f16f16f16_m16n16k16(): + def index_map(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 16 + in_dtype = "float16" + out_dtype = "float16" + i_factors, j_factors, k_factors = [16, 2, 1, 4, 2], [16, 2, 2, 1, 4], [128, 2, 1] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_INTRIN, + MMA_f16f16f16_INTRIN, + MMA_fill_16x16_f16_INTRIN, + MMA_store_16x16_f16_global_INTRIN, + ) + + if measure_perf and timer: + print("f16f16f16_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map, + index_map, + index_map, + LDMATRIX_f16_A_INTRIN, + LDMATRIX_f16_B_TRANS_INTRIN, + MMA_f16f16f16_TRANS_B_INTRIN, + MMA_fill_16x16_f16_INTRIN, + MMA_store_16x16_f16_global_INTRIN, + ) + + if measure_perf and timer: + print("f16f16f16_m16n16k16_trans: %f GFLOPS" % (gflops / (timer().mean))) + + 
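For orientation, the numeric check that `run_test` performs in the new ldmatrix/MMA test reduces to a plain numpy oracle: accumulate in fp32 (int8 inputs are likewise upcast), optionally transpose B, then cast to the output dtype before comparing with a loose tolerance. A minimal standalone sketch of that reference computation follows; the helper name `reference_matmul` and the small shapes are illustrative only and are not part of the patch.

```python
# Sketch of the numpy reference used for the tensorized matmul tests above.
# Assumption: the oracle always accumulates in fp32, mirroring the comment in run_test.
import numpy as np


def reference_matmul(a_np, b_np, out_dtype, b_transposed):
    # Upcast both operands to fp32, transpose B if it is stored as (N, K),
    # then cast the fp32 result to the requested output dtype.
    b_ref = b_np.astype("float32").T if b_transposed else b_np.astype("float32")
    return np.dot(a_np.astype("float32"), b_ref).astype(out_dtype)


# Example: fp16 inputs with a transposed (N, K) B operand, fp32 output.
a = np.random.normal(size=(16, 32)).astype("float16")
b = np.random.normal(size=(64, 32)).astype("float16")
c = reference_matmul(a, b, "float32", b_transposed=True)
assert c.shape == (16, 64)
```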
+@tvm.testing.requires_cuda_compute_version(8) +def test_i8i8i32_m16n16k32(): + def index_map_A(i, j): + return ( + i // 16, + j // 32, + *shared_16x32_to_ldmatrix_32x16_layout(i % 16, j % 32), + ) + + def index_map_B(i, j): + return ( + i // 32, + j // 16, + *shared_32x16_to_ldmatrix_32x16_layout(i % 32, j % 16), + ) + + def index_map_C(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16), + ) + + k_inner = 32 + in_dtype = "int8" + out_dtype = "int32" + i_factors, j_factors, k_factors = [1, 32, 1, 4, 2], [8, 4, 4, 2, 1], [32, 2, 2] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + LDMATRIX_i8_A_INTRIN, + LDMATRIX_i8_B_INTRIN, + MMA_i8i8i32_INTRIN, + MMA_fill_16x16_i32_INTRIN, + MMA_store_16x16_i32_global_INTRIN, + ) + + if measure_perf and timer: + print("i8i8i32_m16n16k32: %f GOPS" % (gflops / (timer().mean))) + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + True, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_A, + index_map_C, + LDMATRIX_i8_A_INTRIN, + LDMATRIX_i8_B_TRANS_INTRIN, + MMA_i8i8i32_TRANS_B_INTRIN, + MMA_fill_16x16_i32_INTRIN, + MMA_store_16x16_i32_global_INTRIN, + ) + + if measure_perf and timer: + print("i8i8i32_m16n16k32_trans: %f GOPS" % (gflops / (timer().mean))) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_schedule_tensorize_mfma_numeric.py b/tests/python/tir-transform/test_tir_schedule_tensorize_mfma_numeric.py new file mode 100644 index 000000000000..8077a603bcf2 --- /dev/null +++ b/tests/python/tir-transform/test_tir_schedule_tensorize_mfma_numeric.py @@ -0,0 +1,314 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=missing-docstring +import tvm +from tvm import te +from tvm.tir.tensor_intrin.rocm import ( + shared_16x4_to_local_64x1_layout_A, + shared_4x16_to_local_64x1_layout_B, + shared_16x16_to_local_64x4_layout_A, + shared_16x16_to_local_64x4_layout_B, + shared_16x16_to_local_64x4_layout_C, + ROCM_MFMA_fill_16x16_f32_INTRIN, + ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN, + ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN, + ROCM_MFMA_f32f32f32_INTRIN, + ROCM_MFMA_STORE_16x16_f32_INTRIN, + ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN, + ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN, + ROCM_MFMA_f16f16f32_INTRIN, + ROCM_MFMA_STORE_16x16_f32_INTRIN, + ROCM_MFMA_fill_16x16_i32_INTRIN, + ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN, + ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN, + ROCM_MFMA_s8s8s32_INTRIN, + ROCM_MFMA_STORE_16x16_s32_INTRIN, +) +import tvm.testing +import numpy as np +from tvm.testing.tir import mfma_schedule + + +M = 1024 +N = 1024 +K = 1024 +measure_perf = False +gflops = (N * M * K) * 2 / 1e9 + + +def matmul(m, n, k, in_dtype, out_dtype, b_transposed): + b_shape = (n, k) if b_transposed else (k, n) + a = te.placeholder((m, k), name="A", dtype=in_dtype) + b = te.placeholder(b_shape, name="B", dtype=in_dtype) + k = te.reduce_axis((0, k), name="k") + + def maybe_cast(v): + if in_dtype != out_dtype: + return tvm.tir.Cast(out_dtype, v) + return v + + def maybe_swap(i, j): + if b_transposed: + return j, i + return i, j + + c = te.compute( + (m, n), + lambda i, j: te.sum(maybe_cast(a[i, k]) * maybe_cast(b[maybe_swap(k, j)]), axis=[k]), + name="C", + ) + return (a, b, c) + + +def run_test( + k_inner, + in_dtype, + out_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, +): + sch = mfma_schedule( + te.create_prim_func(matmul(M, N, K, in_dtype, out_dtype, b_transposed)), + k_inner, + in_dtype, + b_transposed, + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ldmatrix_a_intrin, + ldmatrix_b_intrin, + mma_intrin, + mma_fill_intrin, + mma_store_intrin, + ) + + f = tvm.build(sch.mod["main"], target="rocm", name="dense") + + dev = tvm.device("rocm", 0) + if in_dtype == "float32": + a_np = np.random.uniform(size=(M, K)).astype("float32") + + if b_transposed: + b_np = np.random.uniform(size=(N, K)).astype("float32") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype( + out_dtype + ) + else: + b_np = np.random.uniform(size=(K, N)).astype("float32") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype) + elif in_dtype == "float16": + a_np = np.random.uniform(size=(M, K)).astype("float16") + + if b_transposed: + b_np = np.random.uniform(size=(N, K)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype( + out_dtype + ) + else: + b_np = np.random.uniform(size=(K, N)).astype("float16") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype) + else: + a_np = np.random.randint(-128, 128, (M, K)).astype("int8") + + if b_transposed: + b_np = np.random.randint(-128, 128, (N, K)).astype("int8") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype( + "int32" + ) + else: + b_np = np.random.randint(-128, 128, (K, N)).astype("int8") + c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype("int32") + + a = tvm.nd.array(a_np, dev) + b = tvm.nd.array(b_np, dev) + c = 
tvm.nd.array(np.zeros((M, N), dtype=out_dtype), dev) + + f(a, b, c) + + if in_dtype != "float16": + # The numpy reference is computed with fp32 precision (otherwise too slow). + # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation. + tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2) + + return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c) + + +@tvm.testing.requires_matrixcore +def test_i8i8i32_m16n16k16(): + def index_map_A(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_local_64x4_layout_A(i % 16, j % 16), + ) + + def index_map_B(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_local_64x4_layout_B(i % 16, j % 16), + ) + + def index_map_C(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16), + ) + + k_inner = 16 + in_dtype = "int8" + out_dtype = "int32" + i_factors, j_factors, k_factors = [1, 8, 2, 4, 1], [1, 16, 2, 1, 2], [32, 2, 1] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ROCM_MFMA_LOAD_16x16_A_SHARED_s8_INTRIN, + ROCM_MFMA_LOAD_16x16_B_SHARED_s8_INTRIN, + ROCM_MFMA_s8s8s32_INTRIN, + ROCM_MFMA_fill_16x16_i32_INTRIN, + ROCM_MFMA_STORE_16x16_s32_INTRIN, + ) + + if measure_perf and timer: + print("test_i8i8i32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) + + +@tvm.testing.requires_matrixcore +def test_f16f16f32_m16n16k16(): + def index_map_A(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_local_64x4_layout_A(i % 16, j % 16), + ) + + def index_map_B(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_local_64x4_layout_B(i % 16, j % 16), + ) + + def index_map_C(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16), + ) + + k_inner = 16 + in_dtype = "float16" + out_dtype = "float32" + i_factors, j_factors, k_factors = [1, 8, 2, 4, 1], [1, 16, 2, 1, 2], [32, 2, 1] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ROCM_MFMA_LOAD_16x16_A_SHARED_f16_INTRIN, + ROCM_MFMA_LOAD_16x16_B_SHARED_f16_INTRIN, + ROCM_MFMA_f16f16f32_INTRIN, + ROCM_MFMA_fill_16x16_f32_INTRIN, + ROCM_MFMA_STORE_16x16_f32_INTRIN, + ) + + if measure_perf and timer: + print("f16f16f32_m16n16k16: %f GFLOPS" % (gflops / (timer().mean))) + + +@tvm.testing.requires_matrixcore +def test_f32f32f32_m16n16k4(): + def index_map_A(i, j): + return ( + i // 16, + j // 16, + *shared_16x4_to_local_64x1_layout_A(i % 16, j % 16), + ) + + def index_map_B(i, j): + return ( + i // 16, + j // 16, + *shared_4x16_to_local_64x1_layout_B(i % 16, j % 16), + ) + + def index_map_C(i, j): + return ( + i // 16, + j // 16, + *shared_16x16_to_local_64x4_layout_C(i % 16, j % 16), + ) + + k_inner = 4 + in_dtype = "float32" + out_dtype = "float32" + i_factors, j_factors, k_factors = [4, 2, 1, 4, 2], [4, 2, 2, 1, 4], [128, 2, 1] + + timer = run_test( + k_inner, + in_dtype, + out_dtype, + False, # b_transposed + i_factors, + j_factors, + k_factors, + index_map_A, + index_map_B, + index_map_C, + ROCM_MFMA_LOAD_16x4_A_SHARED_f32_INTRIN, + ROCM_MFMA_LOAD_16x4_B_SHARED_f32_INTRIN, + ROCM_MFMA_f32f32f32_INTRIN, + ROCM_MFMA_fill_16x16_f32_INTRIN, + ROCM_MFMA_STORE_16x16_f32_INTRIN, + ) + + if measure_perf and timer: + print("test_f32f32f32_m16n16k4: %f GFLOPS" % (gflops / (timer().mean))) + + +if __name__ == "__main__": + 
tvm.testing.main() diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 5b3e68e22fa9..7bd24cacfa0a 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -2922,6 +2922,17 @@ def simplify_bracket() -> None: return simplify_bracket +def size_var(): + @T.prim_func + def size_var() -> None: + a = T.int32(is_size_var=True) + b = T.int32(is_size_var=True, min_value=0) + c = T.int32(is_size_var=True, min_value=1) + T.evaluate(a + b * c) + + return size_var + + def var_with_same_name(): @T.prim_func def var_with_same_name(a: T.handle) -> None: @@ -4015,6 +4026,7 @@ def func(): abs, constant_folding, simplify_bracket, + size_var, while_loop, primfunc_with_allocate_annotations, comm_reducer_single_reduce_group,