diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py
index cdb4aa9cfa20..90c585ac8ce1 100644
--- a/python/tvm/tir/schedule/analysis.py
+++ b/python/tvm/tir/schedule/analysis.py
@@ -68,7 +68,7 @@ class TensorizeInfo(Object):
 
 
 def get_tensorize_loop_mapping(
-    sch: Schedule, block: BlockRV, desc_func: PrimFunc
+    sch: Schedule, block: BlockRV, desc_func: PrimFunc, allow_padding: bool = False
 ) -> Optional[TensorizeInfo]:
     """Establish a mapping between loops in a target block and an intrinsic description
@@ -80,13 +80,14 @@
         The target block to match against
     desc_func : PrimFunc
         The prim func describing the computation to be tensorized
-
+    allow_padding : bool
+        Whether to allow padding the block iters to match the intrinsic description
     Returns
     -------
     tensorize_info : Optional[TensorizeInfo]
         TensorizeInfo structure if a valid mapping is found, None otherwise
     """
-    return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func)  # type: ignore
+    return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func, allow_padding)  # type: ignore
 
 
 @tvm._ffi.register_object("tir.schedule.AutoTensorizeMappingInfo")
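
For reference, the new keyword can be exercised from the Python API roughly as in the unit test added at the end of this patch. The snippet below is an illustrative sketch, not part of the patch; it assumes the CUDA WMMA intrinsics are registered and reuses te_workload/create_prim_func exactly as the existing tests do:

from tvm.meta_schedule.testing import te_workload
from tvm.te import create_prim_func
from tvm.tir import Schedule
from tvm.tir.function import TensorIntrin
from tvm.tir.schedule.analysis import get_tensorize_loop_mapping
from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f16_INTRIN

# 127x256x65 matmul: the i and k extents (127 and 65) are not multiples of 16
s = Schedule(create_prim_func(
    te_workload.matmul_relu(n=127, m=256, k=65, in_dtype="float16", out_dtype="float16")))
desc = TensorIntrin.get(WMMA_SYNC_16x16x16_f16f16f16_INTRIN).desc
info = get_tensorize_loop_mapping(s, s.get_block("C"), desc, allow_padding=True)
assert info is not None
print([int(p) for p in info.block_iter_paddings])  # padding per block iter, e.g. [1, 0, 15]
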
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
index 8fcb8fe503b7..6e0f0aaa3126 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -515,7 +515,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
   state->sch->TransformBlockLayout(state->tensor_core_reindex_B, index_map);
   state->sch->TransformBlockLayout(state->block_rv, index_map);
 
-  return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name);
+  return tir::TileWithTensorIntrin(state->sch, state->block_rv, intrin_name,
+                                   /*allow_padding=*/true);
 }
 
 inline std::vector<int> MultiLevelTilingTensorCoreNode::TransformForTensorization(
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index ca45bcac6b34..57165fd08ad4 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -731,10 +731,15 @@ class TensorizeInfoNode : public Object {
   Map<tir::StmtSRef, tir::For> loop_map;
   /*! \brief Maps loops in an intrinsic description to its index, outer to inner */
   Map<tir::For, Integer> desc_loop_indexer;
+  /*! \brief Optional padded extents of the block iters when padding is needed to match the
+   * intrinsic description
+   */
+  Optional<Array<Integer>> block_iter_paddings;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("loop_map", &loop_map);
     v->Visit("desc_loop_indexer", &desc_loop_indexer);
+    v->Visit("block_iter_paddings", &block_iter_paddings);
   }
 
   static constexpr const char* _type_key = "tir.schedule.TensorizeInfo";
@@ -751,11 +756,12 @@
  * \param self The schedule state to be tensorized
  * \param block_sref The target block to match against
  * \param desc_func The prim func describing the computation to be tensorized
+ * \param allow_padding Whether to allow padding the block iters to match the intrinsic description
  * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise
  */
 Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
                                                 const tir::StmtSRef& block_sref,
-                                                const tir::PrimFunc& desc_func);
+                                                const tir::PrimFunc& desc_func, bool allow_padding);
 
 /*!\brief Necessary information used to perform transformations for tensorization */
 class AutoTensorizeMappingInfoNode : public Object {
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 4f78b0c9cd43..b2f879c60e29 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1699,7 +1699,8 @@ TensorIntrinDescInfo ExtractTensorIntrinDescInfo(arith::Analyzer* analyzer,
 
 Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
                                                 const tir::StmtSRef& block_sref,
-                                                const tir::PrimFunc& desc_func) {
+                                                const tir::PrimFunc& desc_func,
+                                                bool allow_padding) {
   arith::Analyzer analyzer;
   const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref);
   // Step 1. Analyze desc_func, extract its block, loops and loop vars
@@ -1732,6 +1733,8 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
   const int n_desc_vars = desc_block->iter_values.size();
   const int offset = n_block_vars - n_desc_vars;
 
+  std::unordered_map<int, int> block_index_to_padding;  // padding of each block iter if necessary
+
   if (offset < 0) {
     return NullOpt;
   }
@@ -1782,10 +1785,11 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
     // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type
     PrimExpr block_bind;
-    for (int i = next_block_ind; i >= 0; --i) {
-      if (iter_types_block[i] == iter_type_desc) {
-        next_block_ind = i - 1;
-        block_bind = block->iter_values[i];
+    int current_block_ind = next_block_ind;
+    for (; current_block_ind >= 0; --current_block_ind) {
+      if (iter_types_block[current_block_ind] == iter_type_desc) {
+        next_block_ind = current_block_ind - 1;
+        block_bind = block->iter_values[current_block_ind];
         break;
       }
     }
@@ -1802,15 +1806,30 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
 
       PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var);
       if (UsesVar(residual,
-                  [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); }))
+                  [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) {
         continue;
+      }
+      // padding is allowed only when the block has trivial bindings
+      if (allow_padding && !is_zero(residual)) {
+        allow_padding = false;
+      }
 
       const IntImmNode* int_block_extent = block_loops[i]->extent.as<IntImmNode>();
 
       // Check divisibility
-      if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) {
+      if (!int_block_extent) {
         return NullOpt;
       }
+      int64_t remainder = int_block_extent->value % int_desc_extent->value;
+      if (remainder != 0) {
+        if (allow_padding) {
+          // If the block loop is not divisible by the desc loop, we pad the block loop to make it
+          // divisible if padding is allowed.
+          block_index_to_padding[current_block_ind] = int_desc_extent->value - remainder;
+        } else {
+          return NullOpt;
+        }
+      }
 
       ret->loop_map.Set(block_loop_sref, GetRef<tir::For>(desc_loop));
       break;
@@ -1820,13 +1839,29 @@ Optional<TensorizeInfo> GetTensorizeLoopMapping(const tir::ScheduleState& self,
   for (int i = 0, n = desc_loops.size(); i < n; ++i) {
     ret->desc_loop_indexer.Set(GetRef<tir::For>(desc_loops[i]), Integer(i));
   }
+  if (!block_index_to_padding.empty()) {
+    if (!allow_padding) {
+      return NullOpt;
+    }
+    Array<Integer> paddings;
+    for (int i = 0, n = block->block->iter_vars.size(); i < n; ++i) {
+      const IterVar& iter_var = block->block->iter_vars[i];
+      if (auto it = block_index_to_padding.find(i); it != block_index_to_padding.end()) {
+        paddings.push_back(IntImm(iter_var->var.dtype(), it->second));
+      } else {
+        paddings.push_back(IntImm(iter_var->var.dtype(), 0));
+      }
+    }
+    ret->block_iter_paddings = std::move(paddings);
+  }
+
   return TensorizeInfo(ret);
 }
 
 TVM_REGISTER_GLOBAL("tir.schedule.IsSpatialPrimFunc").set_body_typed(IsSpatialPrimFunc);
 
 TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping")
-    .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) {
-      return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func);
+    .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func, bool allow_padding) {
+      return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func, allow_padding);
     });
 
 /******** Auto Tensorization ********/
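
To make the divisibility handling above concrete, here is the padding it computes for the workload used in the new tests: a 127x256x65 matmul matched against a 16x16x16 intrinsic. This is a plain-Python illustration of `int_desc_extent->value - remainder`, not code from the patch:

desc_extent = 16
for block_extent in (127, 256, 65):
    remainder = block_extent % desc_extent
    padding = 0 if remainder == 0 else desc_extent - remainder
    print(block_extent, padding)
# prints (127, 1), (256, 0), (65, 15), i.e. block_iter_paddings == [1, 0, 15]
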
diff --git a/src/tir/schedule/transform.cc b/src/tir/schedule/transform.cc
index dfbd3dbcbcc4..b00005c58061 100644
--- a/src/tir/schedule/transform.cc
+++ b/src/tir/schedule/transform.cc
@@ -288,11 +288,15 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
 }
 
 Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
-                                      const String& intrin_name) {
-  Optional<TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
-      sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
+                                      const String& intrin_name, bool allow_padding) {
+  Optional<TensorizeInfo> opt_tensorize_info =
+      GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block_rv),
+                              tir::TensorIntrin::Get(intrin_name)->desc, allow_padding);
   if (!opt_tensorize_info) return NullOpt;
   const tir::TensorizeInfoNode* info = opt_tensorize_info.value().get();
+  if (info->block_iter_paddings.defined()) {
+    sch->PadEinsum(block_rv, info->block_iter_paddings.value());
+  }
   // Construct a mapping from tir loops back to LoopRVs
   Map<tir::StmtSRef, LoopRV> loop2rv;
   {
diff --git a/src/tir/schedule/transform.h b/src/tir/schedule/transform.h
index 4de3685e2482..eb90ca0139bd 100644
--- a/src/tir/schedule/transform.h
+++ b/src/tir/schedule/transform.h
@@ -197,7 +197,7 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
  * block tiled according to the given intrin, NullOpt if a valid loop mapping is not found
  */
 Optional<LoopRV> TileWithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
-                                      const String& intrin_name);
+                                      const String& intrin_name, bool allow_padding = false);
 
 /******** Block mutation ********/
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
index fbb74090b1e5..f7a5ce997edf 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 import tvm
+import tvm.testing
 from tvm import meta_schedule as ms
 from tvm import te
 from tvm.meta_schedule.testing import te_workload
@@ -947,11 +948,145 @@ def test_matmul_relu_non_tensorizable():
     tvm.ir.assert_structural_equal(mod, sch.mod["main"])
 
 
+def test_padded_matmul_relu():
+    # fmt: off
+    @T.prim_func
+    def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 127), "float16"], compute: T.Buffer[(127, 127), "float32"]) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared")
+        C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator")
+        A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared")
+        A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a")
+        B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b")
+        for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"):
+            for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"):
+                for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"):
+                    for ax2_0_0 in T.serial(1):
+                        for ax0_ax1_fused in T.serial(4096):
+                            with T.block("A_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128)
+                                v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
+                                T.reads(A[v0, v1])
+                                T.writes(A_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8})
+                                A_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, A[v0, v1], T.float16(0), dtype="float16")
+                        for ax0_ax1_fused in T.serial(4096):
+                            with T.block("B_reindex_shared"):
+                                v0 = T.axis.spatial(128, ax0_ax1_fused // 32)
+                                v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
+                                T.reads(B[v0, v1])
+                                T.writes(B_reindex_shared[v0, v1])
+                                T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1})
+                                B_reindex_shared[v0, v1] = T.if_then_else(v0 < 127 and v1 < 127, B[v0, v1], T.float16(0), dtype="float16")
+                        for ax2_0_1 in T.serial(4):
+                            for ax0_0, ax1_0 in T.grid(2, 2):
+                                with T.block("A_reindex_shared_wmma.matrix_a_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
+                                    v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0)
+                                    T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("A_reindex_shared_wmma.matrix_a"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0, ax1_0 in T.grid(2, 1):
+                                with T.block("B_reindex_shared_wmma.matrix_b_o"):
+                                    v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0)
+                                    v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused)
+                                    T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"})
+                                    for ax0_1, ax1_1 in T.grid(16, 16):
+                                        with T.block("B_reindex_shared_wmma.matrix_b"):
+                                            v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                            T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                            for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 2, 2, 1):
+                                with T.block("C_o"):
+                                    v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4)
+                                    v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3)
+                                    v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2)
+                                    T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                                    T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1})
+                                    with T.init():
+                                        for ax0_1, ax1_1 in T.grid(16, 16):
+                                            with T.block("C_init"):
+                                                v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1])
+                                                T.reads()
+                                                T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init])
+                                                C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0)
+                                    for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16):
+                                        with T.block("C"):
+                                            v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1])
+                                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i])
+                                            T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                            T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"})
+                                            C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32")
+                    for ax0_0, ax1_0 in T.grid(2, 1):
+                        with T.block("C_reindex_shared_wmma.accumulator_o"):
+                            v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0)
+                            v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused)
+                            T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16])
+                            T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"})
+                            for ax0_1, ax1_1 in T.grid(16, 16):
+                                with T.block("C_reindex_shared_wmma.accumulator"):
+                                    v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1])
+                                    T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
+                                    C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
+                for ax0, ax1 in T.grid(32, 32):
+                    with T.block("C_reindex_shared"):
+                        T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1 < 127)
+                        v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0)
+                        v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1)
+                        T.reads(C_reindex_shared[v0, v1])
+                        T.writes(compute[v0, v1])
+                        T.block_attr({"meta_schedule.cooperative_fetch":4})
+                        compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0))
+    # fmt: on
+
+    decision_0 = [
+        ("SamplePerfectTile", [4, 1, 1, 1, 2]),
+        ("SamplePerfectTile", [2, 2, 2, 1, 1]),
+        ("SamplePerfectTile", [1, 4, 2]),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 3),
+        ("SampleCategorical", 0),
+    ]
+
+    mod = te.create_prim_func(
+        te_workload.matmul_relu(
+            n=127,
+            m=127,
+            k=127,
+            in_dtype="float16",
+            out_dtype="float32",
+        )
+    )
+    actual = ms.TuneContext(
+        mod=mod,
+        target=tvm.target.Target("cuda"),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="shared")]
+        + get_rules("cuda", ms.schedule_rule.AutoInline),
+    ).generate_design_space()
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[padded_matmul_relu_0],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
-    test_matmul_relu()
-    test_matmul_relu_with_fallback()
-    test_conv2d()
-    test_conv2d_more_intrin()
-    test_matmul_relu_pipeline()
-    test_matmul_relu_global()
-    test_matmul_relu_non_tensorizable()
+    tvm.testing.main()
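
The expected module above shows how the padding introduced by this change is materialized: the shared-memory staging copies read out-of-bounds elements as zero via T.if_then_else, and the final write-back is masked with T.where. The standalone TVMScript below isolates that pattern in one dimension; it is an illustrative sketch with made-up buffer names and shapes, not part of the patch:

from tvm.script import tir as T

@T.prim_func
def padded_copy(A: T.Buffer[(127,), "float16"], B: T.Buffer[(128,), "float16"]) -> None:
    for i in T.serial(128):
        with T.block("copy"):
            v = T.axis.spatial(128, i)
            T.reads(A[v])
            T.writes(B[v])
            # reads past the 127-element boundary are replaced by zeros in the padded buffer
            B[v] = T.if_then_else(v < 127, A[v], T.float16(0), dtype="float16")
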
diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py
index 5524abbaf094..242ba5363c7d 100644
--- a/tests/python/unittest/test_tir_schedule_analysis.py
+++ b/tests/python/unittest/test_tir_schedule_analysis.py
@@ -21,7 +21,10 @@
 import tvm.testing
 from tvm.tir.function import TensorIntrin
 from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc
-from tvm.tir.tensor_intrin.cuda import WMMA_SYNC_16x16x16_f16f16f32_INTRIN
+from tvm.tir.tensor_intrin.cuda import (
+    WMMA_SYNC_16x16x16_f16f16f16_INTRIN,
+    WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
+)
 
 
 from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule
@@ -260,6 +263,30 @@ def matmul_16x16x16xf16f16f16_desc(
     assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2)
 
 
+def test_get_tensorize_loop_mapping_padding_matmul():
+    matmul = create_prim_func(
+        te_workload.matmul_relu(
+            n=127,
+            m=256,
+            k=65,
+            in_dtype="float16",
+            out_dtype="float16",
+        )
+    )
+    s = Schedule(matmul)
+    block = s.get_block("C")
+
+    desc = TensorIntrin.get(WMMA_SYNC_16x16x16_f16f16f16_INTRIN).desc
+    info = get_tensorize_loop_mapping(s, block, desc, allow_padding=True)
+    assert info is not None
+    expected_padding = [1, 0, 15]
+    actual_padding = info.block_iter_paddings
+    assert actual_padding is not None
+    assert len(actual_padding) == len(expected_padding)
+    for actual, expected in zip(actual_padding, expected_padding):
+        assert actual == expected
+
+
 def check_index_map(workload, block_name, intrin_name, expected_index_map):
     s = Schedule(workload)
     block = s.get_block(block_name)