From 712212e6a3ec421adc511e1e7b77e14882e394dd Mon Sep 17 00:00:00 2001
From: Junru Shao
Date: Thu, 15 Sep 2022 11:27:00 -0700
Subject: [PATCH] [MetaSchedule][Test] MLT uses SEqual tests

This PR finishes the migration from `check_trace` (string-based equality
check on the TIR trace) to `check_sketch` (SEqual-based equality check on
the TIR itself).

We split the multi-level tiling tests into three files:
- Plain multi-level tiling without any intrinsics
- Multi-level tiling with intrinsics such as VNNI and DP4a
- Multi-level tiling with TensorCore, which requires different handling

In addition, we cleaned up the testing folder and removed several helpers
that are no longer used by the unittests.
---
 .../meta_schedule/testing/schedule_rule.py    |  138 +-
 .../multi_level_tiling_tensor_core.cc         |    4 +-
 src/meta_schedule/utils.h                     |   35 +-
 ...t_meta_schedule_schedule_rule_auto_bind.py |   22 +-
 ...meta_schedule_schedule_rule_auto_inline.py |   19 +-
 ...le_schedule_rule_cross_thread_reduction.py |   17 +-
 .../test_meta_schedule_schedule_rule_mlt.py   |  529 ++++++++
 ..._meta_schedule_schedule_rule_mlt_intrin.py |  418 ++++++
 ...test_meta_schedule_schedule_rule_mlt_tc.py |  957 +++++++++++++
 ...hedule_schedule_rule_multi_level_tiling.py | 1205 -----------------
 10 files changed, 1961 insertions(+), 1383 deletions(-)
 create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
 create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py
 create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
 delete mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py

diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py
index 12ca4200d77a..f14e90b6f0b2 100644
--- a/python/tvm/meta_schedule/testing/schedule_rule.py
+++ b/python/tvm/meta_schedule/testing/schedule_rule.py
@@ -15,122 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
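+# A minimal usage sketch of the `get_rules` helper introduced below, mirroring
+# how the migrated unittests in this patch consume it:
+#
+#   from tvm import meta_schedule as ms
+#   from tvm.meta_schedule.testing.schedule_rule import get_rules
+#
+#   sch_rules = get_rules("cuda", ms.schedule_rule.AutoBind)
+#   rule = get_rules("llvm", ms.schedule_rule.AutoInline)[0]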
"""Default schedule rules""" -from typing import List, Union - -from tvm.meta_schedule.schedule_rule import ( - AutoInline, - MultiLevelTiling, - MultiLevelTilingTensorCore, - ReuseType, - ScheduleRule, -) -from tvm.target import Target - - -def auto_inline(target: Target) -> ScheduleRule: - """Default schedule rules for auto inline""" - if target.kind.name == "llvm": - return AutoInline( - into_producer=False, - into_consumer=True, - inline_const_tensor=True, - disallow_if_then_else=True, - require_injective=True, - require_ordered=True, - disallow_op=["tir.exp"], - ) - if target.kind.name == "cuda": - return AutoInline( - into_producer=True, - into_consumer=True, - inline_const_tensor=True, - disallow_if_then_else=False, - require_injective=False, - require_ordered=False, - disallow_op=None, - ) - raise NotImplementedError(f"{target.kind.name} is not supported") - - -def multi_level_tiling(target: Target) -> ScheduleRule: - """Default schedule rules for with multi-level tiling and reuse""" - if target.kind.name == "llvm": - return MultiLevelTiling( - structure="SSRSRS", - tile_binds=None, - max_innermost_factor=64, - vector_load_lens=None, - reuse_read=None, - reuse_write=ReuseType( - req="may", - levels=[1, 2], - scope="global", - ), - ) - if target.kind.name == "cuda": - return MultiLevelTiling( - structure="SSSRRSRS", - tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], - max_innermost_factor=64, - vector_load_lens=[1, 2, 3, 4, 8, 16], - reuse_read=ReuseType( - req="must", - levels=[4], - scope="shared", - ), - reuse_write=ReuseType( - req="must", - levels=[3], - scope="local", - ), - ) - raise NotImplementedError(f"{target.kind.name} is not supported") - - -def multi_level_tiling_tensor_core( - target: Target, - write_reuse_scope: str = "shared", - in_dtype: Union[str, List[str]] = "float16", - out_dtype: Union[str, List[str]] = "float32", - trans_b: Union[bool, List[bool]] = False, - use_software_pipeline: bool = False, -) -> ScheduleRule: - """Default schedule rules for with multi-level tiling reuse for tensor core""" - assert write_reuse_scope in ["shared", "global"] - if not isinstance(in_dtype, list): - in_dtype = [in_dtype] - if not isinstance(out_dtype, list): - out_dtype = [out_dtype] - if not isinstance(trans_b, list): - trans_b = [trans_b] - - if target.kind.name == "cuda": - from tvm.tir.tensor_intrin import ( # pylint: disable=import-outside-toplevel - cuda, - ) - - intrin_groups = [ - cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b) - for _in_dtype in in_dtype - for _out_dtype in out_dtype - for _trans_b in trans_b - ] - return MultiLevelTilingTensorCore( - intrin_groups=intrin_groups, - structure="SSSRRSRS", - tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"], - max_innermost_factor=4, # 64 // tensor intrin size - vector_load_lens=[1, 2, 3, 4, 8, 16], - reuse_read=ReuseType( - req="must", - levels=[4], - scope="shared", - ), - reuse_write=ReuseType( - req="must" if write_reuse_scope == "shared" else "no", - levels=[2], - scope=write_reuse_scope, - ), - use_software_pipeline=use_software_pipeline, - ) - raise NotImplementedError(f"{target.kind.name} is not supported") +from typing import List, Tuple, Union + +from tvm.meta_schedule import default_config +from tvm.meta_schedule.schedule_rule import ScheduleRule + + +def get_rules(kind: str, types: Union[type, Tuple[type, ...]]) -> List[ScheduleRule]: + """Get default schedule rules""" + # pylint: disable=protected-access + if kind == "llvm": + rules = 
default_config._DefaultLLVM.schedule_rules() + elif kind == "cuda": + rules = default_config._DefaultCUDA.schedule_rules() + elif kind == "tensor_core": + rules = default_config._DefaultCUDATensorCore.schedule_rules() + else: + raise NotImplementedError(f"{kind} is not supported") + # pylint: enable=protected-access + return [rule for rule in rules if isinstance(rule, types)] diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc index 7ddda9b2635b..f7d4cf891aeb 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc @@ -321,7 +321,7 @@ std::vector MultiLevelTilingTensorCoreNode::AddSoftwarePipeline( // Add local stage and double buffering for (int i = 0; i < 2; ++i) { const tir::BlockRV cache_read = state->read_reuse.at(i); - sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Bool(true)); + sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Integer(1)); sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0)); } @@ -529,7 +529,7 @@ inline std::vector MultiLevelTilingTensorCoreNode::TransformForTensorizat state->intrin_group.compute_intrin); state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init, state->intrin_group.init_intrin); - state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true)); + state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Integer(1)); return {std::move(state)}; } diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index ad56fa7f6a52..cf9a32917031 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -77,33 +77,34 @@ class PyLogMessage { // FATAL not included }; - PyLogMessage(const std::string& file, int lineno, PackedFunc logging_func, Level logging_level) { - this->logging_func = logging_func; - this->logging_level = logging_level; - } + explicit PyLogMessage(const char* file, int lineno, PackedFunc logging_func, Level logging_level) + : file_(file), lineno_(lineno), logging_func_(logging_func), logging_level_(logging_level) {} + TVM_NO_INLINE ~PyLogMessage() { - if (this->logging_func.defined()) { - logging_func(static_cast(logging_level), stream_.str()); + if (this->logging_func_.defined()) { + logging_func_(static_cast(logging_level_), stream_.str()); } else { - if (logging_level == Level::INFO) { - LOG(INFO) << stream_.str(); - } else if (logging_level == Level::WARNING) { - LOG(WARNING) << stream_.str(); - } else if (logging_level == Level::ERROR) { - LOG(ERROR) << stream_.str(); - } else if (logging_level == Level::DEBUG) { - DLOG(INFO) << stream_.str(); + if (logging_level_ == Level::INFO) { + runtime::detail::LogMessage(file_, lineno_).stream() << stream_.str(); + } else if (logging_level_ == Level::WARNING) { + runtime::detail::LogMessage(file_, lineno_).stream() << "Warning: " << stream_.str(); + } else if (logging_level_ == Level::ERROR) { + runtime::detail::LogMessage(file_, lineno_).stream() << "Error: " << stream_.str(); + } else if (logging_level_ == Level::DEBUG) { + runtime::detail::LogMessage(file_, lineno_).stream() << "Debug: " << stream_.str(); } else { - LOG(FATAL) << stream_.str(); + runtime::detail::LogFatal(file_, lineno_).stream() << stream_.str(); } } } std::ostringstream& stream() { return stream_; } private: + const char* file_; + int lineno_; std::ostringstream stream_; - PackedFunc logging_func; - Level 
logging_level; + PackedFunc logging_func_; + Level logging_level_; }; /*! \brief The type of the random state */ diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py index 21ad04da473e..a50292df7ae3 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_bind.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.schedule_rule import get_rules from tvm.meta_schedule.testing.space_generation import check_sketches from tvm.script import tir as T from tvm.target import Target @@ -83,12 +84,7 @@ def elementwise_0( mod=mod, target=Target("nvidia/geforce-rtx-3080", host="llvm"), space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ - ms.schedule_rule.AutoBind( - max_threadblocks=256, - thread_extents=[32, 64, 128, 256, 512, 1024], - ) - ], + sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind), task_name="test", ).generate_design_space() check_sketches( @@ -122,12 +118,7 @@ def reduction_loop_only_0( mod=mod, target=Target("nvidia/geforce-rtx-3080", host="llvm"), space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ - ms.schedule_rule.AutoBind( - max_threadblocks=256, - thread_extents=[32, 64, 128, 256, 512, 1024], - ) - ], + sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind), task_name="test", ).generate_design_space() check_sketches( @@ -158,12 +149,7 @@ def zero_dim_add_0( mod=mod, target=Target("nvidia/geforce-rtx-3080", host="llvm"), space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ - ms.schedule_rule.AutoBind( - max_threadblocks=256, - thread_extents=[32, 64, 128, 256, 512, 1024], - ) - ], + sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind), task_name="test", ).generate_design_space() check_sketches( diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py index fcf6a8571b7f..c0801c9d7b5e 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py @@ -16,9 +16,8 @@ # under the License. 
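+# `get_rules(kind, types)` returns the default schedule rules of the given
+# target kind filtered by rule type; the `[0]` in the tests below picks the
+# first (and, in the default rule sets, the only) AutoInline instance.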
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring import tvm -from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply -from tvm.meta_schedule.testing.schedule_rule import auto_inline -from tvm.meta_schedule.tune_context import TuneContext +from tvm import meta_schedule as ms +from tvm.meta_schedule.testing.schedule_rule import get_rules from tvm.script import tir as T from tvm.target import Target @@ -340,10 +339,10 @@ def main(T_full: T.Buffer[(1, 12, 4096), "int64"]) -> None: def _create_context(mod, target, rule): - ctx = TuneContext( + ctx = ms.TuneContext( mod=mod, target=target, - space_generator=PostOrderApply(), + space_generator=ms.space_generator.PostOrderApply(), sch_rules=[rule], task_name="test", ) @@ -356,7 +355,7 @@ def test_inline_consumer_chain(): ctx = _create_context( mod=mod, target=target, - rule=auto_inline(target=target), + rule=get_rules("llvm", ms.schedule_rule.AutoInline)[0], ) (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined) @@ -368,7 +367,7 @@ def test_inline_into_cache(): ctx = _create_context( mod=mod, target=target, - rule=auto_inline(target=target), + rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0], ) (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=MultiLevelTiledConv2DAfterInline) @@ -380,7 +379,7 @@ def test_inline_into_multiple_consumers(): ctx = _create_context( mod=mod, target=target, - rule=auto_inline(target=target), + rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0], ) (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=SoftmaxAfterInline) @@ -392,7 +391,7 @@ def test_inline_pure_spatial(): ctx = _create_context( mod=mod, target=target, - rule=auto_inline(target=target), + rule=get_rules("llvm", ms.schedule_rule.AutoInline)[0], ) (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=AfterPureSpatial) @@ -404,7 +403,7 @@ def test_inline_constant_tensor(): ctx = _create_context( mod=mod, target=target, - rule=auto_inline(target=target), + rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0], ) (space,) = ctx.space_generator.generate_design_space(mod=mod) tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py index ab8df6678b0b..4278638a1aa3 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_cross_thread_reduction.py @@ -19,6 +19,7 @@ import tvm from tvm import meta_schedule as ms from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.schedule_rule import get_rules from tvm.meta_schedule.testing.space_generation import check_sketches from tvm.script import tir as T from tvm.target import Target @@ -283,9 +284,7 @@ def softmax_mn_3( mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ - ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) - ], + sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction), task_name="test", ).generate_design_space() check_sketches( @@ -481,9 +480,7 
@@ def softmax_mn_after_inline_3( mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ - ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) - ], + sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction), task_name="test", ).generate_design_space() check_sketches( @@ -559,9 +556,7 @@ def batch_norm_bmn_1(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "floa mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ - ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) - ], + sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction), task_name="test", ).generate_design_space() check_sketches( @@ -657,9 +652,7 @@ def argmax_1( mod=mod, target=Target("nvidia/geforce-rtx-3090", host="llvm"), space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ - ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512]) - ], + sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction), task_name="test", ).generate_design_space() check_sketches( diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py new file mode 100644 index 000000000000..939ccbe54fa6 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py @@ -0,0 +1,529 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
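+# Plain multi-level tiling, without tensor intrinsics. Each test builds a
+# TuneContext with PostOrderApply and the default MultiLevelTiling rule for
+# its target, generates the design space, and uses `check_sketches` to compare
+# every sketch structurally (SEqual) against an expected TVMScript PrimFunc,
+# together with the sampling decisions that reproduce it.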
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from tvm import meta_schedule as ms +from tvm import te +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.schedule_rule import get_rules +from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.script import tir as T +from tvm.target import Target + + +def test_cpu_matmul(): + @T.prim_func + def cpu_matmul_0( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + C: T.Buffer[(512, 512), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C_global = T.alloc_buffer([512, 512], dtype="float32") + for i0_0, i1_0, i0_1, i1_1 in T.grid(1, 8, 8, 1): + for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(16, 2, 8, 32, 32, 8): + with T.block("C"): + i = T.axis.spatial(512, i0_0 * 512 + i0_1 * 64 + i0_2 * 32 + i0_3) + j = T.axis.spatial(512, i1_0 * 64 + i1_1 * 64 + i1_2 * 8 + i1_3) + k = T.axis.reduce(512, i2_0 * 32 + i2_1) + T.reads(A[i, k], B[k, j]) + T.writes(C_global[i, j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + with T.init(): + C_global[i, j] = T.float32(0) + C_global[i, j] = C_global[i, j] + A[i, k] * B[k, j] + for ax0, ax1 in T.grid(64, 64): + with T.block("C_global"): + v0 = T.axis.spatial(512, i0_1 * 64 + ax0) + v1 = T.axis.spatial(512, i1_0 * 64 + ax1) + T.reads(C_global[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_global[v0, v1] + + @T.prim_func + def cpu_matmul_1( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + C: T.Buffer[(512, 512), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C_global = T.alloc_buffer([512, 512], dtype="float32") + for i0_0, i1_0 in T.grid(1, 8): + for i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(8, 1, 16, 2, 8, 32, 32, 8): + with T.block("C"): + i = T.axis.spatial(512, i0_0 * 512 + i0_1 * 64 + i0_2 * 32 + i0_3) + j = T.axis.spatial(512, i1_0 * 64 + i1_1 * 64 + i1_2 * 8 + i1_3) + k = T.axis.reduce(512, i2_0 * 32 + i2_1) + T.reads(A[i, k], B[k, j]) + T.writes(C_global[i, j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + with T.init(): + C_global[i, j] = T.float32(0) + C_global[i, j] = C_global[i, j] + A[i, k] * B[k, j] + for ax0, ax1 in T.grid(512, 64): + with T.block("C_global"): + v0 = T.axis.spatial(512, ax0) + v1 = T.axis.spatial(512, i1_0 * 64 + ax1) + T.reads(C_global[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_global[v0, v1] + + @T.prim_func + def cpu_matmul_2( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + C: T.Buffer[(512, 512), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + for i0_0, i1_0, i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid( + 1, 8, 8, 1, 16, 2, 8, 32, 32, 8 + ): + with T.block("C"): + i = T.axis.spatial(512, i0_0 * 512 + i0_1 * 64 + i0_2 * 32 + i0_3) + j = T.axis.spatial(512, i1_0 * 64 + i1_1 * 64 + i1_2 * 8 + i1_3) + k = T.axis.reduce(512, i2_0 * 32 + i2_1) + T.reads(A[i, k], B[k, j]) + T.writes(C[i, j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + + decision_0 = [ + ("SamplePerfectTile", [1, 8, 2, 32]), + ("SamplePerfectTile", [8, 1, 8, 8]), + 
("SamplePerfectTile", [16, 32]), + ] + decision_1 = [ + ("SamplePerfectTile", [1, 8, 2, 32]), + ("SamplePerfectTile", [8, 1, 8, 8]), + ("SamplePerfectTile", [16, 32]), + ] + decision_2 = [ + ("SamplePerfectTile", [1, 8, 2, 32]), + ("SamplePerfectTile", [8, 1, 8, 8]), + ("SamplePerfectTile", [16, 32]), + ] + + mod = te.create_prim_func(te_workload.matmul(512, 512, 512)) + actual = ms.TuneContext( + mod=mod, + target=Target("llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=get_rules("llvm", ms.schedule_rule.MultiLevelTiling), + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[cpu_matmul_0, cpu_matmul_1, cpu_matmul_2], + expected_decisions=[decision_0, decision_1, decision_2], + ) + + +def test_cpu_matmul_relu(): + @T.prim_func + def cpu_matmul_relu_0( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + compute: T.Buffer[(512, 512), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C = T.alloc_buffer([512, 512], dtype="float32") + for i0_0, i1_0, i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid( + 256, 4, 1, 4, 64, 1, 32, 8, 2, 1 + ): + with T.block("C"): + i = T.axis.spatial(512, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3) + j = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + i1_2 + i1_3) + k = T.axis.reduce(512, i2_0 * 8 + i2_1) + T.reads(A[i, k], B[k, j]) + T.writes(C[i, j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + for i0, i1 in T.grid(512, 512): + with T.block("compute"): + i0_4, i1_4 = T.axis.remap("SS", [i0, i1]) + T.reads(C[i0_4, i1_4]) + T.writes(compute[i0_4, i1_4]) + compute[i0_4, i1_4] = T.max(C[i0_4, i1_4], T.float32(0)) + + @T.prim_func + def cpu_matmul_relu_1( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + compute: T.Buffer[(512, 512), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C = T.alloc_buffer([512, 512], dtype="float32") + for i0_0, i1_0, i0_1, i1_1 in T.grid(256, 4, 1, 4): + for i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(64, 1, 32, 8, 2, 1): + with T.block("C"): + i = T.axis.spatial(512, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3) + j = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + i1_2 + i1_3) + k = T.axis.reduce(512, i2_0 * 8 + i2_1) + T.reads(A[i, k], B[k, j]) + T.writes(C[i, j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + for ax0, ax1 in T.grid(2, 32): + with T.block("compute"): + i0 = T.axis.spatial(512, i0_0 * 2 + ax0) + i1 = T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + ax1) + T.reads(C[i0, i1]) + T.writes(compute[i0, i1]) + compute[i0, i1] = T.max(C[i0, i1], T.float32(0)) + + @T.prim_func + def cpu_matmul_relu_2( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + compute: T.Buffer[(512, 512), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C = T.alloc_buffer([512, 512], dtype="float32") + for i0_0, i1_0 in T.grid(256, 4): + for i0_1, i1_1, i2_0, i0_2, i1_2, i2_1, i0_3, i1_3 in T.grid(1, 4, 64, 1, 32, 8, 2, 1): + with T.block("C"): + i = T.axis.spatial(512, i0_0 * 2 + i0_1 * 2 + i0_2 * 2 + i0_3) + j = 
T.axis.spatial(512, i1_0 * 128 + i1_1 * 32 + i1_2 + i1_3) + k = T.axis.reduce(512, i2_0 * 8 + i2_1) + T.reads(A[i, k], B[k, j]) + T.writes(C[i, j]) + T.block_attr({"meta_schedule.tiling_structure": "SSRSRS"}) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + for ax0, ax1 in T.grid(2, 128): + with T.block("compute"): + i0 = T.axis.spatial(512, i0_0 * 2 + ax0) + i1 = T.axis.spatial(512, i1_0 * 128 + ax1) + T.reads(C[i0, i1]) + T.writes(compute[i0, i1]) + compute[i0, i1] = T.max(C[i0, i1], T.float32(0)) + + decision_0 = [ + ("SamplePerfectTile", [256, 1, 1, 2]), + ("SamplePerfectTile", [4, 4, 32, 1]), + ("SamplePerfectTile", [64, 8]), + ] + decision_1 = [ + ("SamplePerfectTile", [256, 1, 1, 2]), + ("SamplePerfectTile", [4, 4, 32, 1]), + ("SamplePerfectTile", [64, 8]), + ] + decision_2 = [ + ("SamplePerfectTile", [256, 1, 1, 2]), + ("SamplePerfectTile", [4, 4, 32, 1]), + ("SamplePerfectTile", [64, 8]), + ] + mod = te.create_prim_func(te_workload.matmul_relu(512, 512, 512)) + actual = ms.TuneContext( + mod=mod, + target=Target("llvm"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=get_rules("llvm", ms.schedule_rule.MultiLevelTiling), + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[cpu_matmul_relu_0, cpu_matmul_relu_1, cpu_matmul_relu_2], + expected_decisions=[decision_0, decision_1, decision_2], + ) + + +def test_cuda_matmul(): + @T.prim_func + def cuda_matmul_0( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + C: T.Buffer[(512, 512), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(128, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(8, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(4, thread="threadIdx.x"): + for i2_0 in T.serial(128): + for ax0_ax1_fused in T.serial(256): + with T.block("A_shared"): + v0 = T.axis.spatial( + 512, i0_0_i1_0_fused // 16 * 64 + ax0_ax1_fused // 4 + ) + v1 = T.axis.spatial(512, i2_0 * 4 + ax0_ax1_fused % 4) + T.reads(A[v0, v1]) + T.writes(A_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in T.serial(128): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0 * 4 + ax0_ax1_fused // 32) + v1 = T.axis.spatial( + 512, i0_0_i1_0_fused % 16 * 32 + ax0_ax1_fused % 32 + ) + T.reads(B[v0, v1]) + T.writes(B_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) + B_shared[v0, v1] = B[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(2, 1, 1, 2, 16, 4): + with T.block("C"): + i = T.axis.spatial( + 512, + i0_0_i1_0_fused // 16 * 64 + + i0_1_i1_1_fused // 2 * 16 + + i0_3 * 16 + + i0_4, + ) + j = T.axis.spatial( + 512, + i0_0_i1_0_fused % 16 * 32 + + i0_1_i1_1_fused % 2 * 16 + + i0_2_i1_2_fused * 4 + + i1_3 * 4 + + i1_4, + ) + k = T.axis.reduce(512, i2_0 * 4 + i2_1 * 2 + i2_2) + T.reads(A_shared[i, k], B_shared[k, j]) + T.writes(C_local[i, j]) + T.block_attr( + { + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) + with T.init(): + C_local[i, j] 
= T.float32(0) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(16, 4): + with T.block("C_local"): + v0 = T.axis.spatial( + 512, i0_0_i1_0_fused // 16 * 64 + i0_1_i1_1_fused // 2 * 16 + ax0 + ) + v1 = T.axis.spatial( + 512, + i0_0_i1_0_fused % 16 * 32 + + i0_1_i1_1_fused % 2 * 16 + + i0_2_i1_2_fused * 4 + + ax1, + ) + T.reads(C_local[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_local[v0, v1] + + decision_0 = [ + ("SamplePerfectTile", [8, 4, 1, 1, 16]), + ("SamplePerfectTile", [16, 2, 4, 1, 4]), + ("SamplePerfectTile", [128, 2, 2]), + ("SampleCategorical", 1), + ("SampleCategorical", 0), + ] + mod = te.create_prim_func(te_workload.matmul(512, 512, 512)) + actual = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3080"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=get_rules("cuda", ms.schedule_rule.MultiLevelTiling), + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[cuda_matmul_0], + expected_decisions=[decision_0], + ) + + +def test_cuda_matmul_relu(): + @T.prim_func + def cuda_matmul_relu_0( + A: T.Buffer[(512, 512), "float32"], + B: T.Buffer[(512, 512), "float32"], + compute: T.Buffer[(512, 512), "float32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C = T.alloc_buffer([512, 512], dtype="float32") + C_local = T.alloc_buffer([512, 512], dtype="float32", scope="local") + A_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + B_shared = T.alloc_buffer([512, 512], dtype="float32", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(64, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(64, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(8, thread="threadIdx.x"): + for i2_0 in T.serial(8): + for ax0_ax1_fused in T.serial(4096): + with T.block("A_shared"): + v0 = T.axis.spatial( + 512, i0_0_i1_0_fused // 8 * 64 + ax0_ax1_fused // 64 + ) + v1 = T.axis.spatial(512, i2_0 * 64 + ax0_ax1_fused % 64) + T.reads(A[v0, v1]) + T.writes(A_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 2}) + A_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in T.serial(4096): + with T.block("B_shared"): + v0 = T.axis.spatial(512, i2_0 * 64 + ax0_ax1_fused // 64) + v1 = T.axis.spatial( + 512, i0_0_i1_0_fused % 8 * 64 + ax0_ax1_fused % 64 + ) + T.reads(B[v0, v1]) + T.writes(B_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 4}) + B_shared[v0, v1] = B[v0, v1] + for i2_1, i0_3, i1_3, i2_2, i0_4, i1_4 in T.grid(8, 2, 1, 8, 2, 2): + with T.block("C"): + i = T.axis.spatial( + 512, + i0_0_i1_0_fused // 8 * 64 + + i0_1_i1_1_fused // 8 * 8 + + i0_2_i1_2_fused // 4 * 4 + + i0_3 * 2 + + i0_4, + ) + j = T.axis.spatial( + 512, + i0_0_i1_0_fused % 8 * 64 + + i0_1_i1_1_fused % 8 * 8 + + i0_2_i1_2_fused % 4 * 2 + + i1_3 * 2 + + i1_4, + ) + k = T.axis.reduce(512, i2_0 * 64 + i2_1 * 8 + i2_2) + T.reads(A_shared[i, k], B_shared[k, j]) + T.writes(C_local[i, j]) + T.block_attr( + { + "meta_schedule.thread_extent_high_inclusive": 1024, + "meta_schedule.thread_extent_low_inclusive": 32, + "meta_schedule.tiling_structure": "SSSRRSRS", + } + ) + with T.init(): + C_local[i, j] = T.float32(0) + C_local[i, j] = C_local[i, j] + A_shared[i, k] * B_shared[k, j] + for ax0, ax1 in T.grid(4, 2): + with T.block("C_local"): + v0 = T.axis.spatial( + 512, + i0_0_i1_0_fused // 8 * 64 + + i0_1_i1_1_fused // 8 * 8 + + i0_2_i1_2_fused // 4 * 4 + + 
ax0, + ) + v1 = T.axis.spatial( + 512, + i0_0_i1_0_fused % 8 * 64 + + i0_1_i1_1_fused % 8 * 8 + + i0_2_i1_2_fused % 4 * 2 + + ax1, + ) + T.reads(C_local[v0, v1]) + T.writes(C[v0, v1]) + C[v0, v1] = C_local[v0, v1] + for i0, i1 in T.grid(512, 512): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(C[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) + + decision_0 = [ + ("SamplePerfectTile", [8, 8, 2, 2, 2]), + ("SamplePerfectTile", [8, 8, 4, 1, 2]), + ("SamplePerfectTile", [8, 8, 8]), + ("SampleCategorical", 1), + ("SampleCategorical", 3), + ] + mod = te.create_prim_func(te_workload.matmul_relu(512, 512, 512)) + actual = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3080"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=get_rules("cuda", ms.schedule_rule.MultiLevelTiling), + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[cuda_matmul_relu_0], + expected_decisions=[decision_0], + ) + + +def test_cuda_sum_with_trivial_block_iter(): + @T.prim_func + def sum_with_trivial_block_iter( + A: T.Buffer[(1, 64, 768), "float32"], + B: T.Buffer[(1, 64, 1), "float32"], + ) -> None: + for i0, i1, i2, i3 in T.grid(1, 64, 1, 768): + with T.block("sum"): + ax0, ax1, ax2, k2 = T.axis.remap("SSSR", [i0, i1, i2, i3]) + T.reads(A[ax0, ax1, k2]) + T.writes(B[ax0, ax1, ax2]) + with T.init(): + B[ax0, ax1, ax2] = T.float32(0) + B[ax0, ax1, ax2] = B[ax0, ax1, ax2] + A[ax0, ax1, k2] + + # Expect nothing to happen - the rule is not supposed to be applied in this case + mod = sum_with_trivial_block_iter + (sch,) = ms.TuneContext( + mod=mod, + target=Target("nvidia/geforce-rtx-3080"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=get_rules("cuda", ms.schedule_rule.MultiLevelTiling), + task_name="test", + ).generate_design_space() + assert not sch.trace.simplified(remove_postproc=True).insts + + +if __name__ == "__main__": + test_cpu_matmul() + test_cpu_matmul_relu() + test_cuda_matmul() + test_cuda_matmul_relu() + test_cuda_sum_with_trivial_block_iter() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py new file mode 100644 index 000000000000..38ddb137e108 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_intrin.py @@ -0,0 +1,418 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
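+# Multi-level tiling with tensor intrinsics: VNNI conv2d (NCHWc) on x86
+# cascadelake and DP4A dense on CUDA. Blocks rewritten for tensorization carry
+# the "meta_schedule.auto_tensorize" block attribute; the no-tensorize cases
+# expect a single design-space entry that is structurally equal to the input
+# module, i.e. the rule does not fire.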
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from tvm import meta_schedule as ms +from tvm import te +from tvm.ir import assert_structural_equal +from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.script import tir as T +from tvm.target import Target +from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN +from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN + + +def test_vnni_conv2d_nchwc(): + @T.prim_func + def conv2d_nchwc( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): + with T.block("conv2d_NCHWc_int8"): + ( + n, + oc_chunk, + oh, + ow, + oc_block, + kh, + kw, + ic_outer, + ic_f_inner, + ic_s_inner, + ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + ) * T.cast( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + "int32", + ) + + # fmt: off + @T.prim_func + def vnni_conv2d_nchwc_0(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16], dtype="int32") + for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1): + for i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): + with T.block("conv2d_NCHWc_int8_o"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) + oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) + ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1) + ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) + T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16]) + T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"}) + with T.init(): + for i4_1 in T.serial(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_i_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init]) + conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init] = 0 + for i4_1, i9_1 in T.grid(16, 4): + with T.block("conv2d_NCHWc_int8"): + oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, 
i9_1]) + T.reads(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) + T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i], "int32") + for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 1, 2, 1, 16): + with T.block("conv2d_NCHWc_int8_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(16, i1_0 * 2 + i1_1 + ax1) + v2 = T.axis.spatial(56, i2_0 * 2 + ax2) + v3 = T.axis.spatial(56, i3_0 + ax3) + v4 = T.axis.spatial(16, ax4) + T.reads(conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4]) + T.writes(conv2d_NCHWc_int8[v0, v1, v2, v3, v4]) + conv2d_NCHWc_int8[v0, v1, v2, v3, v4] = conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4] + + @T.prim_func + def vnni_conv2d_nchwc_1(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + conv2d_NCHWc_int8_global = T.alloc_buffer([1, 16, 56, 56, 16], dtype="int32") + for i0_0, i1_0, i2_0, i3_0, i4_0_0 in T.grid(1, 8, 28, 56, 1): + for i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): + with T.block("conv2d_NCHWc_int8_o"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) + oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) + ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1) + ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) + T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, 0 : 16]) + T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"}) + with T.init(): + for i4_1 in T.serial(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_i_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init]) + conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i_init] = 0 + for i4_1, i9_1 in T.grid(16, 4): + with T.block("conv2d_NCHWc_int8"): + oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1]) + T.reads(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) + T.writes(conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8_global[n, oc_chunk, oh, ow, oc_block_i] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 
+ ic_s_inner_i], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i], "int32") + for ax0, ax1, ax2, ax3, ax4 in T.grid(1, 2, 2, 1, 16): + with T.block("conv2d_NCHWc_int8_global"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(16, i1_0 * 2 + ax1) + v2 = T.axis.spatial(56, i2_0 * 2 + ax2) + v3 = T.axis.spatial(56, i3_0 + ax3) + v4 = T.axis.spatial(16, ax4) + T.reads(conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4]) + T.writes(conv2d_NCHWc_int8[v0, v1, v2, v3, v4]) + conv2d_NCHWc_int8[v0, v1, v2, v3, v4] = conv2d_NCHWc_int8_global[v0, v1, v2, v3, v4] + + @T.prim_func + def vnni_conv2d_nchwc_2(placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"]) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0_0, i1_0, i2_0, i3_0, i4_0_0, i0_1, i1_1, i2_1, i3_1, i4_0_1, i5_0, i6_0, i7_0, i8_0, i9_0_0, i0_2, i1_2, i2_2, i3_2, i4_0_2, i5_1, i6_1, i7_1, i8_1, i9_0_1, i0_3, i1_3, i2_3, i3_3, i4_0_3 in T.grid(1, 8, 28, 56, 1, 1, 2, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 2, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 1, 1): + with T.block("conv2d_NCHWc_int8_o"): + n = T.axis.spatial(1, 0) + oc_chunk = T.axis.spatial(16, i1_0 * 2 + i1_1 + i1_2 + i1_3) + oh = T.axis.spatial(56, i2_0 * 2 + i2_1 * 2 + i2_2 + i2_3) + ow = T.axis.spatial(56, i3_3 + i3_0 + i3_1 + i3_2) + oc_block_o = T.axis.spatial(1, 0) + kh = T.axis.reduce(1, 0) + kw = T.axis.reduce(1, 0) + ic_outer = T.axis.reduce(4, i7_0 * 4 + i7_1) + ic_f_inner = T.axis.reduce(4, i8_0 + i8_1) + ic_s_inner_o = T.axis.reduce(1, 0) + T.reads(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 : ic_f_inner * 4 + 4], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, 0 : 16, 0 : 4]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, 0 : 16]) + T.block_attr({"meta_schedule.auto_tensorize":"dot_16x4_vnni"}) + with T.init(): + for i4_1 in T.serial(16): + with T.block("conv2d_NCHWc_int8_init"): + oc_block_i_init = T.axis.spatial(16, i4_1) + T.reads() + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init]) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i_init] = 0 + for i4_1, i9_1 in T.grid(16, 4): + with T.block("conv2d_NCHWc_int8"): + oc_block_i, ic_s_inner_i = T.axis.remap("SR", [i4_1, i9_1]) + T.reads(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i], placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"}) + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block_i] + T.cast(placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner_i], "int32") * T.cast(placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block_i, ic_s_inner_i], "int32") + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [1, 1, 1, 1]), + ("SamplePerfectTile", [8, 2, 1, 1]), + ("SamplePerfectTile", [28, 1, 2, 1]), + ("SamplePerfectTile", [56, 1, 1, 1]), + ("SamplePerfectTile", [1, 1, 1, 1]), + ("SamplePerfectTile", [1, 1]), + ("SamplePerfectTile", [1, 1]), + ("SamplePerfectTile", [1, 4]), + ("SamplePerfectTile", [4, 1]), + ("SamplePerfectTile", [1, 1]), + ] + decision_1 = [ + ("SamplePerfectTile", [1, 1, 1, 1]), + ("SamplePerfectTile", [8, 2, 1, 1]), + ("SamplePerfectTile", [28, 1, 2, 1]), + ("SamplePerfectTile", [56, 1, 1, 1]), + 
("SamplePerfectTile", [1, 1, 1, 1]), + ("SamplePerfectTile", [1, 1]), + ("SamplePerfectTile", [1, 1]), + ("SamplePerfectTile", [1, 4]), + ("SamplePerfectTile", [4, 1]), + ("SamplePerfectTile", [1, 1]), + ] + decision_2 = [ + ("SamplePerfectTile", [1, 1, 1, 1]), + ("SamplePerfectTile", [8, 2, 1, 1]), + ("SamplePerfectTile", [28, 1, 2, 1]), + ("SamplePerfectTile", [56, 1, 1, 1]), + ("SamplePerfectTile", [1, 1, 1, 1]), + ("SamplePerfectTile", [1, 1]), + ("SamplePerfectTile", [1, 1]), + ("SamplePerfectTile", [1, 4]), + ("SamplePerfectTile", [4, 1]), + ("SamplePerfectTile", [1, 1]), + ] + + mod = conv2d_nchwc + target = Target("llvm -mcpu=cascadelake -num-cores=4") + actual = ms.TuneContext( + mod=mod, + target=Target(target), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.MultiLevelTilingWithIntrin( + VNNI_INTRIN, + structure="SSRSRS", + tile_binds=None, + max_innermost_factor=64, + vector_load_lens=None, + reuse_read=None, + reuse_write=ms.schedule_rule.ReuseType(req="may", levels=[1, 2], scope="global"), + ), + ], + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[vnni_conv2d_nchwc_0, vnni_conv2d_nchwc_1, vnni_conv2d_nchwc_2], + expected_decisions=[decision_0, decision_1, decision_2], + ) + + +def _check_dp4a_dense(m, n, k, in_dtype, out_dtype, expected_mods, expected_decisions): + def _dense(m, n, k, in_dtype, out_dtype): + X = te.placeholder((m, k), name="X", dtype=in_dtype) + W = te.placeholder((n, k), name="W", dtype=in_dtype) + ak = te.reduce_axis((0, k), name="k") + matmul = te.compute( + (m, n), + lambda i, j: te.sum( + X[i, ak].astype(out_dtype) * W[j, ak].astype(out_dtype), + axis=ak, + ), + name="compute", + ) + return te.create_prim_func([X, W, matmul]) + + mod = _dense(m, n, k, in_dtype, out_dtype) + actual = ms.TuneContext( + mod=mod, + target=Target("cuda"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + ms.schedule_rule.MultiLevelTilingWithIntrin( + DP4A_INTRIN, + structure="SSSRRSRS", + tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], + max_innermost_factor=64, + vector_load_lens=[1, 2, 3, 4], + reuse_read=ms.schedule_rule.ReuseType(req="must", levels=[4], scope="shared"), + reuse_write=ms.schedule_rule.ReuseType(req="must", levels=[3], scope="local"), + ) + ], + ).generate_design_space() + if expected_mods is None: + assert expected_decisions is None + assert len(actual) == 1 + assert_structural_equal(mod, actual[0].mod["main"]) + else: + check_sketches(mod, actual, expected_mods, expected_decisions) + + +def test_dp4a_dense(): + @T.prim_func + def dp4a_dense_0( + X: T.Buffer[(128, 128), "int8"], + W: T.Buffer[(128, 128), "int8"], + compute: T.Buffer[(128, 128), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + compute_local = T.alloc_buffer([128, 128], dtype="int32", scope="local") + X_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + W_shared = T.alloc_buffer([128, 128], dtype="int8", scope="shared") + for i0_0_i1_0_fused in T.thread_binding(1, thread="blockIdx.x"): + for i0_1_i1_1_fused in T.thread_binding(512, thread="vthread.x"): + for i0_2_i1_2_fused in T.thread_binding(2, thread="threadIdx.x"): + for i2_0_0 in T.serial(1): + for ax0_ax1_fused in T.serial(16384): + with T.block("X_shared"): + v0 = T.axis.spatial(128, ax0_ax1_fused // 128) + v1 = T.axis.spatial(128, ax0_ax1_fused % 128) + T.reads(X[v0, v1]) + T.writes(X_shared[v0, v1]) + 
T.block_attr({"meta_schedule.cooperative_fetch": 1}) + X_shared[v0, v1] = X[v0, v1] + for ax0_ax1_fused in T.serial(16384): + with T.block("W_shared"): + v0 = T.axis.spatial(128, ax0_ax1_fused // 128) + v1 = T.axis.spatial(128, ax0_ax1_fused % 128) + T.reads(W[v0, v1]) + T.writes(W_shared[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch": 1}) + W_shared[v0, v1] = W[v0, v1] + for i2_0_1, i0_3, i1_3, i2_0_2, i0_4, i1_4 in T.grid(1, 2, 4, 32, 2, 1): + with T.block("compute_o"): + i = T.axis.spatial( + 128, + i0_1_i1_1_fused // 32 * 8 + + i0_2_i1_2_fused * 4 + + i0_3 * 2 + + i0_4, + ) + j = T.axis.spatial(128, i1_4 + i0_1_i1_1_fused % 32 * 4 + i1_3) + k_o = T.axis.reduce(32, i2_0_0 * 32 + i2_0_1 * 32 + i2_0_2) + T.reads( + X_shared[i, k_o * 4 : k_o * 4 + 4], + W_shared[j, k_o * 4 : k_o * 4 + 4], + ) + T.writes(compute_local[i, j]) + T.block_attr({"meta_schedule.auto_tensorize": "dp4a"}) + with T.init(): + with T.block("compute_init"): + T.reads() + T.writes(compute_local[i, j]) + compute_local[i, j] = 0 + for i2_1 in T.serial(4): + with T.block("compute"): + k_i = T.axis.reduce(4, i2_1) + T.reads( + compute_local[i, j], + X_shared[i, k_o * 4 + k_i], + W_shared[j, k_o * 4 + k_i], + ) + T.writes(compute_local[i, j]) + T.block_attr({"meta_schedule.tiling_structure": "SSSRRSRS"}) + compute_local[i, j] = compute_local[i, j] + T.cast( + X_shared[i, k_o * 4 + k_i], "int32" + ) * T.cast(W_shared[j, k_o * 4 + k_i], "int32") + for ax0, ax1 in T.grid(4, 4): + with T.block("compute_local"): + v0 = T.axis.spatial( + 128, i0_1_i1_1_fused // 32 * 8 + i0_2_i1_2_fused * 4 + ax0 + ) + v1 = T.axis.spatial(128, i0_1_i1_1_fused % 32 * 4 + ax1) + T.reads(compute_local[v0, v1]) + T.writes(compute[v0, v1]) + compute[v0, v1] = compute_local[v0, v1] + + decision_0 = [ + ("SamplePerfectTile", [1, 16, 2, 2, 2]), + ("SamplePerfectTile", [1, 32, 1, 4, 1]), + ("SamplePerfectTile", [1, 1, 32]), + ("SampleCategorical", 0), + ("SampleCategorical", 0), + ] + _check_dp4a_dense( + m=128, + n=128, + k=128, + in_dtype="int8", + out_dtype="int32", + expected_mods=[dp4a_dense_0], + expected_decisions=[decision_0], + ) + + +def test_dp4a_dense_no_tensorize_1(): + _check_dp4a_dense( + m=128, + n=128, + k=128, + in_dtype="float32", + out_dtype="float32", + expected_mods=None, + expected_decisions=None, + ) + + +def test_dp4a_dense_no_tensorize_2(): + _check_dp4a_dense( + m=127, + n=127, + k=127, + in_dtype="int8", + out_dtype="int32", + expected_mods=None, + expected_decisions=None, + ) + + +if __name__ == "__main__": + test_vnni_conv2d_nchwc() + test_dp4a_dense() + test_dp4a_dense_no_tensorize_1() + test_dp4a_dense_no_tensorize_2() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py new file mode 100644 index 000000000000..fbb74090b1e5 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py @@ -0,0 +1,957 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +import tvm +from tvm import meta_schedule as ms +from tvm import te +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.schedule_rule import get_rules +from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.script import tir as T +from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group + + +def multi_level_tiling_tensor_core( + *, + write_reuse_scope="shared", + in_dtype="float16", + out_dtype="float32", + trans_b=False, + use_software_pipeline=False, +) -> ms.schedule_rule.ScheduleRule: + assert write_reuse_scope in ["shared", "global"] + if not isinstance(in_dtype, list): + in_dtype = [in_dtype] + if not isinstance(out_dtype, list): + out_dtype = [out_dtype] + if not isinstance(trans_b, list): + trans_b = [trans_b] + return ms.schedule_rule.MultiLevelTilingTensorCore( + intrin_groups=[ + get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b) + for _in_dtype in in_dtype + for _out_dtype in out_dtype + for _trans_b in trans_b + ], + structure="SSSRRSRS", + tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"], + max_innermost_factor=4, # 64 // tensor intrin size + vector_load_lens=[1, 2, 3, 4, 8, 16], + reuse_read=ms.schedule_rule.ReuseType( + req="must", + levels=[4], + scope="shared", + ), + reuse_write=ms.schedule_rule.ReuseType( + req="must" if write_reuse_scope == "shared" else "no", + levels=[2], + scope=write_reuse_scope, + ), + use_software_pipeline=use_software_pipeline, + ) + + +def test_matmul_relu(): + # fmt: off + @T.prim_func + def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(8, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_0_0 in T.serial(1): + for ax0_ax1_fused in T.serial(4096): + with T.block("A_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 128) + v1 = T.axis.spatial(128, ax0_ax1_fused % 128) + T.reads(A[v0, v1]) + T.writes(A_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + A_reindex_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in T.serial(4096): + with T.block("B_reindex_shared"): + v0 = 
T.axis.spatial(128, ax0_ax1_fused // 32) + v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + B_reindex_shared[v0, v1] = B[v0, v1] + for ax2_0_1 in T.serial(4): + for ax0_0, ax1_0 in T.grid(2, 2): + with T.block("A_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax2_0_1 * 2 + ax1_0) + T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(8, ax2_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused) + T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 2, 2, 1): + with T.block("C_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused + ax1_0_3) + v2_o = T.axis.reduce(8, ax2_0_0 * 8 + ax2_0_1 * 2 + ax2_0_2) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], 
A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused // 2 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused % 2 * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_2_ax1_0_2_fused) + T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1 in T.grid(32, 32): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0) + v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1) + T.reads(C_reindex_shared[v0, v1]) + T.writes(compute[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1, 1, 2]), + ("SamplePerfectTile", [2, 2, 2, 1, 1]), + ("SamplePerfectTile", [1, 4, 2]), + ("SampleCategorical", 3), + ("SampleCategorical", 3), + ("SampleCategorical", 0), + ] + + mod = te.create_prim_func( + te_workload.matmul_relu( + n=128, + m=128, + k=128, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = ms.TuneContext( + mod=mod, + target=tvm.target.Target("cuda"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[multi_level_tiling_tensor_core()] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[matmul_relu_0], + expected_decisions=[decision_0], + ) + + +def test_matmul_relu_with_fallback(): + # fmt: off + @T.prim_func + def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = 
T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(2, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(2, thread="threadIdx.y"): + for ax2_0_0 in T.serial(2): + for ax0_ax1_fused in T.serial(2048): + with T.block("A_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 64) + v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(A[v0, v1]) + T.writes(A_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":4}) + A_reindex_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in T.serial(8192): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) + v1 = T.axis.spatial(128, ax0_ax1_fused % 128) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2}) + B_reindex_shared[v0, v1] = B[v0, v1] + for ax2_0_1 in T.serial(1): + for ax0_0, ax1_0 in T.grid(2, 4): + with T.block("A_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax1_0) + T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(4, 4): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) + T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 4, 2, 4): + with T.block("C_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0_3 * 4 + ax1_0_4) + v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 4 + ax2_0_2) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + 
T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") + for ax0_0, ax1_0 in T.grid(2, 4): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(8, ax0_0_0_ax1_0_0_fused * 4 + ax0_0_1_ax1_0_1_fused * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused * 4 + ax1_0) + T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1 in T.grid(32, 128): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0) + v1 = T.axis.spatial(128, ax1) + T.reads(C_reindex_shared[v0, v1]) + T.writes(compute[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch":4}) + compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0)) + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [2, 2, 1, 1, 2]), + ("SamplePerfectTile", [1, 1, 2, 1, 4]), + ("SamplePerfectTile", [2, 1, 4]), + ("SampleCategorical", 3), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ] + + mod = te.create_prim_func( + te_workload.matmul_relu( + n=128, + m=128, + k=128, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = ms.TuneContext( + mod=mod, + target=tvm.target.Target("cuda"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + multi_level_tiling_tensor_core(), + ] + + get_rules( + "cuda", + ( + ms.schedule_rule.MultiLevelTiling, + ms.schedule_rule.AutoInline, + ), + ), + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[matmul_relu_fallback_0], + expected_decisions=[decision_0], + ) + + +def test_conv2d(): + # fmt: off + @T.prim_func + def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], 
weight: T.Buffer[(3, 3, 32, 32), "float16"], conv2d_nhwc: T.Buffer[(1, 16, 16, 32), "float32"]) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16") + conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator") + PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope="shared") + weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope="shared") + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b") + for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) + T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16") + for ax0_0_ax1_0_0_ax2_0_0_fused in T.thread_binding(2, thread="blockIdx.y"): + for ax0_1_ax1_0_1_ax2_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): + for ax0_2_ax1_0_2_ax2_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax3_0_0 in T.serial(1): + for ax0_ax1_fused in T.serial(4608): + with T.block("PadInput_reindex_shared"): + v0 = T.axis.spatial(256, ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0_ax1_fused // 288) + v1 = T.axis.spatial(288, ax0_ax1_fused % 288) + T.reads(PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32]) + T.writes(PadInput_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":2}) + PadInput_reindex_shared[v0, v1] = PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] + for ax0_ax1_fused in T.serial(4608): + with T.block("weight_reindex_shared"): + v0 = T.axis.spatial(288, ax0_ax1_fused // 16) + v1 = T.axis.spatial(32, ax0_0_ax1_0_0_ax2_0_0_fused * 16 + ax0_ax1_fused % 16) + T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) + T.writes(weight_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] + for ax3_0_1 in T.serial(18): + for ax0_0, ax1_0 in T.grid(1, 1): + with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): + v0_o, v1_o = T.axis.remap("SS", [ax0_1_ax1_0_1_ax2_0_1_fused, ax3_0_1]) + T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("PadInput_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 1): + with T.block("weight_reindex_shared_wmma.matrix_b_o"): + v0_o, v1_o = T.axis.remap("SS", 
[ax3_0_1, ax0_0_ax1_0_0_ax2_0_0_fused]) + T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("weight_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_3, ax1_0_3, ax2_0_3, ax3_0_2, ax0_4, ax1_0_4, ax2_0_4 in T.grid(1, 1, 1, 1, 1, 1, 1): + with T.block("conv2d_nhwc_o"): + v0 = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(16, ax1_0_4 + ax0_1_ax1_0_1_ax2_0_1_fused + ax1_0_3) + v2_o = T.axis.spatial(2, ax0_0_ax1_0_0_ax2_0_0_fused + ax2_0_3 + ax2_0_4) + v3_o = T.axis.reduce(18, ax3_0_0 * 18 + ax3_0_1 + ax3_0_2) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 : v1_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v3_o * 16 : v3_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + with T.init(): + for ax1_1, ax2_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v1_i_init, v2_i_init = T.axis.remap("SS", [ax1_1, ax2_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init] = T.float32(0) + for ax1_1, ax2_1, ax3_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v1_i, v2_i, v3_i = T.axis.remap("SSR", [ax1_1, ax2_1, ax3_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i], PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i], "float32") + for ax0_0, ax1_0 in T.grid(1, 1): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o, v1_o = T.axis.remap("SS", [ax0_1_ax1_0_1_ax2_0_1_fused, ax0_0_ax1_0_0_ax2_0_0_fused]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o 
* 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(256, ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0) + v1 = T.axis.spatial(32, ax0_0_ax1_0_0_ax2_0_0_fused * 16 + ax1) + T.reads(conv2d_nhwc_reindex_shared[v0, v1]) + T.writes(conv2d_nhwc[0, v0 // 16, v0 % 16, v1]) + T.block_attr({"meta_schedule.cooperative_fetch":3}) + conv2d_nhwc[0, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [1, 1, 1, 1, 1]), + ("SamplePerfectTile", [1, 16, 1, 1, 1]), + ("SamplePerfectTile", [2, 1, 1, 1, 1]), + ("SamplePerfectTile", [1, 18, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ("SampleCategorical", 3), + ] + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + N=1, + H=16, + W=16, + CI=32, + CO=32, + kernel_size=3, + stride=1, + padding=1, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = ms.TuneContext( + mod=mod, + target=tvm.target.Target("cuda"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[multi_level_tiling_tensor_core()], + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[conv2d_0], + expected_decisions=[decision_0], + ) + + +def test_conv2d_more_intrin(): + # test that adding inapplicable tensor intrinsics doesn't change the search space + # fmt: off + @T.prim_func + def conv2d_more_intrin_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3, 3, 32, 32), "float16"], conv2d_nhwc: T.Buffer[(1, 16, 16, 32), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="float16") + conv2d_nhwc_reindex_shared = T.alloc_buffer([256, 32], dtype="float32", scope="shared") + conv2d_nhwc_reindex_shared_wmma_accumulator = T.alloc_buffer([256, 32], dtype="float32", scope="wmma.accumulator") + PadInput_reindex_shared = T.alloc_buffer([256, 288], dtype="float16", scope="shared") + weight_reindex_shared = T.alloc_buffer([288, 32], dtype="float16", scope="shared") + PadInput_reindex_shared_wmma_matrix_a = T.alloc_buffer([256, 288], dtype="float16", scope="wmma.matrix_a") + weight_reindex_shared_wmma_matrix_b = T.alloc_buffer([288, 32], dtype="float16", scope="wmma.matrix_b") + for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) + T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 17 and 1 <= i2_1 and i2_1 < 17, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16") + for ax0_0_ax1_0_0_ax2_0_0_fused in T.thread_binding(4, thread="blockIdx.y"): + for ax0_1_ax1_0_1_ax2_0_1_fused in T.thread_binding(4, thread="blockIdx.x"): + for ax0_2_ax1_0_2_ax2_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax3_0_0 in T.serial(3): + for ax0_ax1_fused in T.serial(1536): + with T.block("PadInput_reindex_shared"): + v0 = T.axis.spatial(256, ax0_0_ax1_0_0_ax2_0_0_fused * 64 + ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0_ax1_fused // 96) + v1 = T.axis.spatial(288, ax3_0_0 * 96 + ax0_ax1_fused % 96) + T.reads(PadInput[0, v1 // 96 + v0 // 16, v1 % 96 
// 32 + v0 % 16, v1 % 32]) + T.writes(PadInput_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + PadInput_reindex_shared[v0, v1] = PadInput[0, v1 // 96 + v0 // 16, v1 % 96 // 32 + v0 % 16, v1 % 32] + for ax0_ax1_fused in T.serial(3072): + with T.block("weight_reindex_shared"): + v0 = T.axis.spatial(288, ax3_0_0 * 96 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(32, ax0_ax1_fused % 32) + T.reads(weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1]) + T.writes(weight_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":8}) + weight_reindex_shared[v0, v1] = weight[v0 // 96, v0 % 96 // 32, v0 % 32, v1] + for ax3_0_1 in T.serial(2): + for ax0_0, ax1_0 in T.grid(1, 3): + with T.block("PadInput_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(16, ax0_0_ax1_0_0_ax2_0_0_fused * 4 + ax0_1_ax1_0_1_ax2_0_1_fused) + v1_o = T.axis.spatial(18, ax3_0_0 * 6 + ax3_0_1 * 3 + ax1_0) + T.reads(PadInput_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("PadInput_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + PadInput_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = PadInput_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(3, 2): + with T.block("weight_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(18, ax3_0_0 * 6 + ax3_0_1 * 3 + ax0_0) + v1_o = T.axis.spatial(2, ax1_0) + T.reads(weight_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("weight_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + weight_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = weight_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_3, ax1_0_3, ax2_0_3, ax3_0_2, ax0_4, ax1_0_4, ax2_0_4 in T.grid(1, 1, 2, 3, 1, 1, 1): + with T.block("conv2d_nhwc_o"): + v0 = T.axis.spatial(1, 0) + v1_o = T.axis.spatial(16, ax1_0_4 + ax0_0_ax1_0_0_ax2_0_0_fused * 4 + ax0_1_ax1_0_1_ax2_0_1_fused + ax1_0_3) + v2_o = T.axis.spatial(2, ax2_0_4 + ax2_0_3) + v3_o = T.axis.reduce(18, ax3_0_0 * 6 + ax3_0_1 * 3 + ax3_0_2) + T.reads(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 : v1_o * 16 + 16, v3_o * 16 : v3_o * 16 + 16], weight_reindex_shared_wmma_matrix_b[v3_o * 16 : v3_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 : v1_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + with T.init(): + for ax1_1, ax2_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_init"): + v1_i_init, 
v2_i_init = T.axis.remap("SS", [ax1_1, ax2_1]) + T.reads() + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init]) + conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i_init, v2_o * 16 + v2_i_init] = T.float32(0) + for ax1_1, ax2_1, ax3_1 in T.grid(16, 16, 16): + with T.block("conv2d_nhwc"): + v1_i, v2_i, v3_i = T.axis.remap("SSR", [ax1_1, ax2_1, ax3_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i], PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i]) + T.writes(conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v1_o * 16 + v1_i, v2_o * 16 + v2_i] + T.cast(PadInput_reindex_shared_wmma_matrix_a[v1_o * 16 + v1_i, v3_o * 16 + v3_i], "float32") * T.cast(weight_reindex_shared_wmma_matrix_b[v3_o * 16 + v3_i, v2_o * 16 + v2_i], "float32") + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(16, ax0_0_ax1_0_0_ax2_0_0_fused * 4 + ax0_1_ax1_0_1_ax2_0_1_fused) + v1_o = T.axis.spatial(2, ax1_0) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("conv2d_nhwc_reindex_shared_wmma.accumulator"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1 in T.grid(16, 32): + with T.block("conv2d_nhwc_reindex_shared"): + v0 = T.axis.spatial(256, ax0_0_ax1_0_0_ax2_0_0_fused * 64 + ax0_1_ax1_0_1_ax2_0_1_fused * 16 + ax0) + v1 = T.axis.spatial(32, ax1) + T.reads(conv2d_nhwc_reindex_shared[v0, v1]) + T.writes(conv2d_nhwc[0, v0 // 16, v0 % 16, v1]) + T.block_attr({"meta_schedule.cooperative_fetch":3}) + conv2d_nhwc[0, v0 // 16, v0 % 16, v1] = conv2d_nhwc_reindex_shared[v0, v1] + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [1, 1, 1, 1, 1]), + ("SamplePerfectTile", [4, 4, 1, 1, 1]), + ("SamplePerfectTile", [1, 1, 1, 2, 1]), + ("SamplePerfectTile", [3, 2, 3]), + ("SampleCategorical", 2), + ("SampleCategorical", 3), + ("SampleCategorical", 3), + ] + + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + N=1, + H=16, + W=16, + CI=32, + CO=32, + kernel_size=3, + stride=1, + padding=1, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = ms.TuneContext( + mod=mod, + target=tvm.target.Target("cuda"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + multi_level_tiling_tensor_core( + in_dtype="float16", + out_dtype=["float16", "float32"], + ), + ], + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[conv2d_more_intrin_0], + expected_decisions=[decision_0], + ) + + +def test_matmul_relu_pipeline(): + # fmt: off + @T.prim_func + def matmul_relu_pipeline_0(A: 
T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C = T.alloc_buffer([128, 128], dtype="float32") + C_reindex_shared = T.alloc_buffer([128, 128], dtype="float32", scope="shared") + C_reindex_shared_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(16, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(1, thread="threadIdx.y"): + for ax2_0_0 in T.serial(4, annotations={"software_pipeline_order":[0, 3, 1, 4, 5, 2, 6], "software_pipeline_stage":[0, 0, 0, 0, 0, 1, 1]}): + for ax0_ax1_fused in T.serial(1024): + with T.block("A_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused % 32) + T.reads(A[v0, v1]) + T.writes(A_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":4, "tir.manifest_shared_memory_local_stage":1}) + A_reindex_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in T.serial(1024): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(128, ax2_0_0 * 32 + ax0_ax1_fused // 32) + v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "double_buffer_scope":0, "meta_schedule.cooperative_fetch":2, "tir.manifest_shared_memory_local_stage":1}) + B_reindex_shared[v0, v1] = B[v0, v1] + for ax2_0_1 in T.serial(2, annotations={"software_pipeline_order":[0, 1, 2], "software_pipeline_stage":[0, 0, 1]}): + for ax0_0, ax1_0 in T.grid(2, 1): + with T.block("A_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1) + T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(8, ax2_0_0 * 2 + ax2_0_1) + v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) + T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + 
T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 1, 1, 2, 2): + with T.block("C_o"): + v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0_3 * 2 + ax0_0_4) + v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0_3 * 2 + ax1_0_4) + v2_o = T.axis.reduce(8, ax2_0_0 * 2 + ax2_0_1 + ax2_0_2) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) + C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") + for ax0_0, ax1_0 in T.grid(2, 2): + with T.block("C_reindex_shared_wmma.accumulator_o"): + v0_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused // 4 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_1_ax1_0_1_fused % 4 * 2 + ax1_0) + T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(C_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_shared"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_reindex_shared_wmma.accumulator"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0, ax1 in T.grid(32, 32): + with T.block("C_reindex_shared"): + v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0) + v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax1) + T.reads(C_reindex_shared[v0, v1]) + 
T.writes(C[v0, v1]) + T.block_attr({"meta_schedule.cooperative_fetch":3}) + C[v0, v1] = C_reindex_shared[v0, v1] + for i0, i1 in T.grid(128, 128): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(C[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [1, 4, 1, 1, 2]), + ("SamplePerfectTile", [1, 4, 1, 1, 2]), + ("SamplePerfectTile", [4, 2, 1]), + ("SampleCategorical", 2), + ("SampleCategorical", 2), + ("SampleCategorical", 1), + ] + mod = te.create_prim_func( + te_workload.matmul_relu( + n=128, + m=128, + k=128, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = ms.TuneContext( + mod=mod, + target=tvm.target.Target("cuda"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ + multi_level_tiling_tensor_core( + use_software_pipeline=True, + ), + ], + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[matmul_relu_pipeline_0], + expected_decisions=[decision_0], + ) + + +def test_matmul_relu_global(): + # fmt: off + @T.prim_func + def matmul_relu_global_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "float16"], compute: T.Buffer[(128, 128), "float32"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + C = T.alloc_buffer([128, 128], dtype="float32") + C_reindex_wmma_accumulator = T.alloc_buffer([128, 128], dtype="float32", scope="wmma.accumulator") + A_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + B_reindex_shared = T.alloc_buffer([128, 128], dtype="float16", scope="shared") + A_reindex_shared_wmma_matrix_a = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_a") + B_reindex_shared_wmma_matrix_b = T.alloc_buffer([128, 128], dtype="float16", scope="wmma.matrix_b") + for ax0_0_0_ax1_0_0_fused in T.thread_binding(1, thread="blockIdx.y"): + for ax0_0_1_ax1_0_1_fused in T.thread_binding(1, thread="blockIdx.x"): + for ax0_0_2_ax1_0_2_fused in T.thread_binding(16, thread="threadIdx.y"): + for ax2_0_0 in T.serial(2): + for ax0_ax1_fused in T.serial(8192): + with T.block("A_reindex_shared"): + v0 = T.axis.spatial(128, ax0_ax1_fused // 64) + v1 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused % 64) + T.reads(A[v0, v1]) + T.writes(A_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + A_reindex_shared[v0, v1] = A[v0, v1] + for ax0_ax1_fused in T.serial(8192): + with T.block("B_reindex_shared"): + v0 = T.axis.spatial(128, ax2_0_0 * 64 + ax0_ax1_fused // 128) + v1 = T.axis.spatial(128, ax0_ax1_fused % 128) + T.reads(B[v0, v1]) + T.writes(B_reindex_shared[v0, v1]) + T.block_attr({"buffer_dim_align":[[0, 0, 32, 8]], "meta_schedule.cooperative_fetch":1}) + B_reindex_shared[v0, v1] = B[v0, v1] + for ax2_0_1 in T.serial(2): + for ax0_0, ax1_0 in T.grid(1, 2): + with T.block("A_reindex_shared_wmma.matrix_a_o"): + v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2) + v1_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax1_0) + T.reads(A_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_a"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("A_reindex_shared_wmma.matrix_a"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, 
ax1_1]) + T.reads(A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = A_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0, ax1_0 in T.grid(2, 4): + with T.block("B_reindex_shared_wmma.matrix_b_o"): + v0_o = T.axis.spatial(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax0_0) + v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) + T.reads(B_reindex_shared[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_load_16x16x16_f16_b"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("B_reindex_shared_wmma.matrix_b"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + B_reindex_shared_wmma_matrix_b[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = B_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for ax0_0_3, ax1_0_3, ax2_0_2, ax0_0_4, ax1_0_4 in T.grid(1, 4, 2, 1, 1): + with T.block("C_o"): + v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2 + ax0_0_3 + ax0_0_4) + v1_o = T.axis.spatial(8, ax1_0_4 + ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0_3) + v2_o = T.axis.reduce(8, ax2_0_0 * 4 + ax2_0_1 * 2 + ax2_0_2) + T.reads(A_reindex_shared_wmma_matrix_a[v0_o * 16 : v0_o * 16 + 16, v2_o * 16 : v2_o * 16 + 16], B_reindex_shared_wmma_matrix_b[v2_o * 16 : v2_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_sync_16x16x16_f16f16f32", "meta_schedule.auto_tensorize_init":"wmma_fill_16x16x16_f32", "warp_execution":1}) + with T.init(): + for ax0_1, ax1_1 in T.grid(16, 16): + with T.block("C_init"): + v0_i_init, v1_i_init = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads() + T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init]) + C_reindex_wmma_accumulator[v0_o * 16 + v0_i_init, v1_o * 16 + v1_i_init] = T.float32(0) + for ax0_1, ax1_1, ax2_1 in T.grid(16, 16, 16): + with T.block("C"): + v0_i, v1_i, v2_i = T.axis.remap("SSR", [ax0_1, ax1_1, ax2_1]) + T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i], A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i]) + T.writes(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.block_attr({"meta_schedule.tiling_structure":"SSSRRSRS"}) + C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + T.cast(A_reindex_shared_wmma_matrix_a[v0_o * 16 + v0_i, v2_o * 16 + v2_i], "float32") * T.cast(B_reindex_shared_wmma_matrix_b[v2_o * 16 + v2_i, v1_o * 16 + v1_i], "float32") + for ax0_0, ax1_0 in T.grid(1, 4): + with T.block("C_reindex_wmma.accumulator_o"): + v0_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused // 2) + v1_o = T.axis.spatial(8, ax0_0_2_ax1_0_2_fused % 2 * 4 + ax1_0) + T.reads(C_reindex_wmma_accumulator[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.writes(C[v0_o * 16 : v0_o * 16 + 16, v1_o * 16 : v1_o * 16 + 16]) + T.block_attr({"meta_schedule.auto_tensorize":"wmma_store_16x16x16_f32_global"}) + for ax0_1, ax1_1 in T.grid(16, 16): + with 
T.block("C_reindex_wmma.accumulator"): + v0_i, v1_i = T.axis.remap("SS", [ax0_1, ax1_1]) + T.reads(C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + T.writes(C[v0_o * 16 + v0_i, v1_o * 16 + v1_i]) + C[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i] + for i0, i1 in T.grid(128, 128): + with T.block("compute"): + i0_1, i1_1 = T.axis.remap("SS", [i0, i1]) + T.reads(C[i0_1, i1_1]) + T.writes(compute[i0_1, i1_1]) + compute[i0_1, i1_1] = T.max(C[i0_1, i1_1], T.float32(0)) + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [1, 1, 8, 1, 1]), + ("SamplePerfectTile", [1, 1, 2, 4, 1]), + ("SamplePerfectTile", [2, 2, 2]), + ("SampleCategorical", 0), + ("SampleCategorical", 0), + ] + mod = te.create_prim_func( + te_workload.matmul_relu( + n=128, + m=128, + k=128, + in_dtype="float16", + out_dtype="float32", + ) + ) + actual = ms.TuneContext( + mod=mod, + target=tvm.target.Target("cuda"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[matmul_relu_global_0], + expected_decisions=[decision_0], + ) + + +def test_matmul_relu_non_tensorizable(): + # expected to do nothing on non-tensorizable workloads + mod = te.create_prim_func( + te_workload.matmul_relu( # dtype doesn't match tensor intrin + n=128, + m=128, + k=128, + ) + ) + (sch,) = ms.TuneContext( + mod=mod, + target=tvm.target.Target("cuda"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[multi_level_tiling_tensor_core(write_reuse_scope="global")] + + get_rules("cuda", ms.schedule_rule.AutoInline), + ).generate_design_space() + tvm.ir.assert_structural_equal(mod, sch.mod["main"]) + + +if __name__ == "__main__": + test_matmul_relu() + test_matmul_relu_with_fallback() + test_conv2d() + test_conv2d_more_intrin() + test_matmul_relu_pipeline() + test_matmul_relu_global() + test_matmul_relu_non_tensorizable() diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py b/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py deleted file mode 100644 index fe1220c50925..000000000000 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_multi_level_tiling.py +++ /dev/null @@ -1,1205 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring -import tvm -import tvm.testing -from tvm import te -from tvm.meta_schedule import schedule_rule -from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply -from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.testing.schedule_rule import ( - auto_inline, - multi_level_tiling, - multi_level_tiling_tensor_core, -) -from tvm.meta_schedule.testing.space_generation import check_trace -from tvm.meta_schedule.tune_context import TuneContext -from tvm.script import tir as T -from tvm.target import Target -from tvm.te import create_prim_func -from tvm.tir.tensor_intrin.arm_cpu import DP4A_INTRIN -from tvm.tir.tensor_intrin.x86 import VNNI_DOT_16x4_INTRIN as VNNI_INTRIN - - -def _create_context(mod, target, rule) -> TuneContext: - if not isinstance(rule, (list, tuple)): - rule = [rule] - ctx = TuneContext( - mod=mod, - target=target, - space_generator=PostOrderApply(), - sch_rules=rule, - task_name="test", - ) - return ctx - - -def test_cpu_matmul(): - expected = [ - [ - 'b0 = sch.get_block(name="C", func_name="main")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', - "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)", - ], - [ - 'b0 = sch.get_block(name="C", func_name="main")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - 'b24 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="global")', - "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)", - ], - [ - 'b0 = sch.get_block(name="C", func_name="main")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, 
v13, v14, v15], preserve_unit_iters=True)", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - ], - ] - target = Target("llvm") - ctx = _create_context( - create_prim_func( - te_workload.matmul( - n=512, - m=512, - k=512, - ) - ), - target=target, - rule=multi_level_tiling(target=target), - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 3 - check_trace(spaces, expected) - - -def test_cpu_matmul_relu(): - # pylint: disable=line-too-long - expected = [ - [ - 'b0 = sch.get_block(name="C", func_name="main")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - "b24, = sch.get_consumers(block=b0)", - "sch.reverse_compute_at(block=b24, loop=l17, preserve_unit_loops=True, index=-1)", - ], - [ - 'b0 = sch.get_block(name="C", func_name="main")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - "b24, = sch.get_consumers(block=b0)", - "sch.reverse_compute_at(block=b24, loop=l16, preserve_unit_loops=True, index=-1)", - ], - [ - 'b0 = sch.get_block(name="C", func_name="main")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7 = sch.sample_perfect_tile(loop=l1, n=4, max_innermost_factor=64)", - "l8, l9, l10, l11 = sch.split(loop=l1, factors=[v4, v5, v6, v7], preserve_unit_iters=True)", - "v12, v13, v14, v15 = sch.sample_perfect_tile(loop=l2, n=4, max_innermost_factor=64)", - "l16, l17, l18, l19 = sch.split(loop=l2, factors=[v12, v13, v14, v15], preserve_unit_iters=True)", - "v20, v21 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l22, l23 = sch.split(loop=l3, factors=[v20, v21], preserve_unit_iters=True)", - "sch.reorder(l8, l16, l9, l17, l22, l10, l18, l23, l11, l19)", - ], - ] - # pylint: enable=line-too-long - target = Target("llvm") - ctx = _create_context( - create_prim_func( - te_workload.matmul_relu( - n=512, - m=512, - k=512, - ) - ), - target=target, - rule=multi_level_tiling(target=target), - ) - spaces = 
ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 3 - check_trace(spaces, expected) - - -def test_cuda_matmul(): - # pylint: disable=line-too-long - expected = [ - [ - 'b0 = sch.get_block(name="C", func_name="main")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)", - "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8], preserve_unit_iters=True)", - "v14, v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", - "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True)", - "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)", - "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26], preserve_unit_iters=True)", - "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)", - "l30 = sch.fuse(l9, l19, preserve_unit_iters=True)", - 'sch.bind(loop=l30, thread_axis="blockIdx.x")', - "l31 = sch.fuse(l10, l20, preserve_unit_iters=True)", - 'sch.bind(loop=l31, thread_axis="vthread.x")', - "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)", - 'sch.bind(loop=l32, thread_axis="threadIdx.x")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=1024)', - 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)", - 'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")', - "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)", - "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", - "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)", - "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", - 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', - 'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)", - "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", - "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)", - "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", - 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', - ] - ] - # pylint: enable=line-too-long - target = Target("cuda --max_threads_per_block=1024 --thread_warp_size=32", host="llvm") - ctx = _create_context( - create_prim_func( - te_workload.matmul( - n=512, - m=512, - k=512, - ) - ), - target=target, - rule=multi_level_tiling(target=target), - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - check_trace(spaces, expected) - - -def test_cuda_matmul_relu(): - # pylint: disable=line-too-long - expected = [ - [ - 'b0 = sch.get_block(name="C", func_name="main")', - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5, v6, v7, v8 = sch.sample_perfect_tile(loop=l1, n=5, max_innermost_factor=64)", - "l9, l10, l11, l12, l13 = sch.split(loop=l1, factors=[v4, v5, v6, v7, v8], preserve_unit_iters=True)", - "v14, 
v15, v16, v17, v18 = sch.sample_perfect_tile(loop=l2, n=5, max_innermost_factor=64)", - "l19, l20, l21, l22, l23 = sch.split(loop=l2, factors=[v14, v15, v16, v17, v18], preserve_unit_iters=True)", - "v24, v25, v26 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64)", - "l27, l28, l29 = sch.split(loop=l3, factors=[v24, v25, v26], preserve_unit_iters=True)", - "sch.reorder(l9, l19, l10, l20, l11, l21, l27, l28, l12, l22, l29, l13, l23)", - "l30 = sch.fuse(l9, l19, preserve_unit_iters=True)", - 'sch.bind(loop=l30, thread_axis="blockIdx.x")', - "l31 = sch.fuse(l10, l20, preserve_unit_iters=True)", - 'sch.bind(loop=l31, thread_axis="vthread.x")', - "l32 = sch.fuse(l11, l21, preserve_unit_iters=True)", - 'sch.bind(loop=l32, thread_axis="threadIdx.x")', - 'b33 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")', - "sch.reverse_compute_at(block=b33, loop=l32, preserve_unit_loops=True, index=-1)", - 'b34 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared")', - "sch.compute_at(block=b34, loop=l27, preserve_unit_loops=True, index=-1)", - "l35, l36, l37, l38, l39, l40 = sch.get_loops(block=b34)", - "l41 = sch.fuse(l39, l40, preserve_unit_iters=True)", - "v42 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", - 'sch.annotate(block_or_loop=b34, ann_key="meta_schedule.cooperative_fetch", ann_val=v42)', - 'b43 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared")', - "sch.compute_at(block=b43, loop=l27, preserve_unit_loops=True, index=-1)", - "l44, l45, l46, l47, l48, l49 = sch.get_loops(block=b43)", - "l50 = sch.fuse(l48, l49, preserve_unit_iters=True)", - "v51 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25])", - 'sch.annotate(block_or_loop=b43, ann_key="meta_schedule.cooperative_fetch", ann_val=v51)', - ] - ] - # pylint: enable=line-too-long - target = Target("cuda", host="llvm") - ctx = _create_context( - create_prim_func( - te_workload.matmul_relu( - n=512, - m=512, - k=512, - ) - ), - target=target, - rule=multi_level_tiling(target=target), - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - check_trace(spaces, expected) - - -def test_cuda_sum_with_trivial_block_iter(): - @T.prim_func - def sum_with_trivial_block_iter( - A: T.Buffer[(1, 64, 768), "float32"], B: T.Buffer[(1, 64, 1), "float32"] - ) -> None: - for i0, i1, i2, i3 in T.grid(1, 64, 1, 768): - with T.block("sum"): - ax0, ax1, ax2, k2 = T.axis.remap("SSSR", [i0, i1, i2, i3]) - T.reads(A[ax0, ax1, k2]) - T.writes(B[ax0, ax1, ax2]) - with T.init(): - B[ax0, ax1, ax2] = T.float32(0) - B[ax0, ax1, ax2] = B[ax0, ax1, ax2] + A[ax0, ax1, k2] - - # Expect nothing to happen - the rule is not supposed to be applied in this case - expected = [[]] - target = Target("cuda", host="llvm") - ctx = _create_context( - sum_with_trivial_block_iter, - target=target, - rule=multi_level_tiling(target=target), - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - check_trace(spaces, expected) - - -@tvm.script.ir_module -class Conv2dNCHWcVNNIModule: - @T.prim_func - def main( - placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], - placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], - conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], - ) -> None: - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): - with 
T.block("conv2d_NCHWc_int8"): - ( - n, - oc_chunk, - oh, - ow, - oc_block, - kh, - kw, - ic_outer, - ic_f_inner, - ic_s_inner, - ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) - T.reads( - placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], - placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], - ) - T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) - with T.init(): - conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 - conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ - n, oc_chunk, oh, ow, oc_block - ] + T.cast( - placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" - ) * T.cast( - placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], - "int32", - ) - - -def test_multi_level_tiling_conv2d_nchwc_vnni(): - target = "llvm -mcpu=cascadelake -num-cores 4" - ctx = _create_context( - Conv2dNCHWcVNNIModule, - target=tvm.target.Target(target), - rule=schedule_rule.MultiLevelTilingWithIntrin( - VNNI_INTRIN, - structure="SSRSRS", - tile_binds=None, - max_innermost_factor=64, - vector_load_lens=None, - reuse_read=None, - reuse_write=schedule_rule.ReuseType( - req="may", - levels=[1, 2], - scope="global", - ), - ), - ) - - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - - expected = [ - """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") -sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") -l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) -l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True) -l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) -l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) -sch.reorder(l21, l22, l23, l24, l25, l14, l12) -b27 = sch.blockize(loop=l14) -sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni") -l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) -v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64) -l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True) -v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64) -l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True) -v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64) -l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True) -v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64) -l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True) -v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64) -l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True) -v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64) -l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True) -v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64) -l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True) -v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64) -l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True) -v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64) -l92, l93 = 
sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True) -v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64) -l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True) -sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77) -b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global") -sch.reverse_compute_at(block=b98, loop=l75, preserve_unit_loops=True, index=-1)""".split( - "\n" - ), - """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") -sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") -l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) -l11, l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True) -l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) -l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) -sch.reorder(l21, l22, l23, l24, l25, l14, l12) -b27 = sch.blockize(loop=l14) -sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni") -l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) -v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64) -l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True) -v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64) -l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True) -v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64) -l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True) -v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64) -l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True) -v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64) -l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True) -v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64) -l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True) -v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64) -l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True) -v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64) -l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True) -v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64) -l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True) -v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64) -l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True) -sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77) -b98 = sch.cache_write(block=b27, write_buffer_index=0, storage_scope="global") -sch.reverse_compute_at(block=b98, loop=l74, preserve_unit_loops=True, index=-1)""".split( - "\n" - ), - """b0 = sch.get_block(name="conv2d_NCHWc_int8", func_name="main") -sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSRSRS") -l1, l2, l3, l4, l5, l6, l7, l8, l9, l10 = sch.get_loops(block=b0) -l11, 
l12 = sch.split(loop=l10, factors=[None, 4], preserve_unit_iters=True) -l13, l14 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) -l15, l16, l17, l18, l19, l20, l21, l22, l23, l24, l25, l26 = sch.get_loops(block=b0) -sch.reorder(l21, l22, l23, l24, l25, l14, l12) -b27 = sch.blockize(loop=l14) -sch.annotate(block_or_loop=b27, ann_key="meta_schedule.auto_tensorize", ann_val="dot_16x4_vnni") -l28, l29, l30, l31, l32, l33, l34, l35, l36, l37 = sch.get_loops(block=b27) -v38, v39, v40, v41 = sch.sample_perfect_tile(loop=l28, n=4, max_innermost_factor=64) -l42, l43, l44, l45 = sch.split(loop=l28, factors=[v38, v39, v40, v41], preserve_unit_iters=True) -v46, v47, v48, v49 = sch.sample_perfect_tile(loop=l29, n=4, max_innermost_factor=64) -l50, l51, l52, l53 = sch.split(loop=l29, factors=[v46, v47, v48, v49], preserve_unit_iters=True) -v54, v55, v56, v57 = sch.sample_perfect_tile(loop=l30, n=4, max_innermost_factor=64) -l58, l59, l60, l61 = sch.split(loop=l30, factors=[v54, v55, v56, v57], preserve_unit_iters=True) -v62, v63, v64, v65 = sch.sample_perfect_tile(loop=l31, n=4, max_innermost_factor=64) -l66, l67, l68, l69 = sch.split(loop=l31, factors=[v62, v63, v64, v65], preserve_unit_iters=True) -v70, v71, v72, v73 = sch.sample_perfect_tile(loop=l32, n=4, max_innermost_factor=64) -l74, l75, l76, l77 = sch.split(loop=l32, factors=[v70, v71, v72, v73], preserve_unit_iters=True) -v78, v79 = sch.sample_perfect_tile(loop=l33, n=2, max_innermost_factor=64) -l80, l81 = sch.split(loop=l33, factors=[v78, v79], preserve_unit_iters=True) -v82, v83 = sch.sample_perfect_tile(loop=l34, n=2, max_innermost_factor=64) -l84, l85 = sch.split(loop=l34, factors=[v82, v83], preserve_unit_iters=True) -v86, v87 = sch.sample_perfect_tile(loop=l35, n=2, max_innermost_factor=64) -l88, l89 = sch.split(loop=l35, factors=[v86, v87], preserve_unit_iters=True) -v90, v91 = sch.sample_perfect_tile(loop=l36, n=2, max_innermost_factor=64) -l92, l93 = sch.split(loop=l36, factors=[v90, v91], preserve_unit_iters=True) -v94, v95 = sch.sample_perfect_tile(loop=l37, n=2, max_innermost_factor=64) -l96, l97 = sch.split(loop=l37, factors=[v94, v95], preserve_unit_iters=True) -sch.reorder(l42, l50, l58, l66, l74, l43, l51, l59, l67, l75, l80, l84, l88, l92, l96, l44, l52, l60, l68, l76, l81, l85, l89, l93, l97, l45, l53, l61, l69, l77)""".split( - "\n" - ), - ] - - check_trace(spaces, expected) - - -def _test_multi_level_tiling_dense_dp4a(m, n, k, in_dtype, out_dtype, expected): - X = te.placeholder((m, k), name="X", dtype=in_dtype) - W = te.placeholder((n, k), name="W", dtype=in_dtype) - ak = te.reduce_axis((0, k), name="k") - - matmul = te.compute( - (m, n), - lambda i, j: te.sum( - X[i, ak].astype(out_dtype) * W[j, ak].astype(out_dtype), - axis=ak, - ), - name="compute", - ) - - func = te.create_prim_func([X, W, matmul]) - - ctx = _create_context( - func, - target=tvm.target.Target("cuda"), - rule=schedule_rule.MultiLevelTilingWithIntrin( - DP4A_INTRIN, - structure="SSSRRSRS", - tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"], - max_innermost_factor=64, - vector_load_lens=[1, 2, 3, 4], - reuse_read=schedule_rule.ReuseType( - req="must", - levels=[4], - scope="shared", - ), - reuse_write=schedule_rule.ReuseType( - req="must", - levels=[3], - scope="local", - ), - ), - ) - - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - check_trace(spaces, expected) - - -def test_multi_level_tiling_dense_dp4a(): - m, n, k = 128, 128, 128 - - expected = [ - """b0 = sch.get_block(name="compute", func_name="main") 
-sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") -l1, l2, l3 = sch.get_loops(block=b0) -l4, l5 = sch.split(loop=l3, factors=[None, 4], preserve_unit_iters=True) -sch.reorder(l5) -b6 = sch.blockize(loop=l5) -sch.annotate(block_or_loop=b6, ann_key="meta_schedule.auto_tensorize", ann_val="dp4a") -l7, l8, l9 = sch.get_loops(block=b6) -v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64) -l15, l16, l17, l18, l19 = sch.split(loop=l7, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True) -v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64) -l25, l26, l27, l28, l29 = sch.split(loop=l8, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True) -v30, v31, v32 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64) -l33, l34, l35 = sch.split(loop=l9, factors=[v30, v31, v32], preserve_unit_iters=True) -sch.reorder(l15, l25, l16, l26, l17, l27, l33, l34, l18, l28, l35, l19, l29) -l36 = sch.fuse(l15, l25, preserve_unit_iters=True) -sch.bind(loop=l36, thread_axis="blockIdx.x") -l37 = sch.fuse(l16, l26, preserve_unit_iters=True) -sch.bind(loop=l37, thread_axis="vthread.x") -l38 = sch.fuse(l17, l27, preserve_unit_iters=True) -sch.bind(loop=l38, thread_axis="threadIdx.x") -b39 = sch.cache_write(block=b6, write_buffer_index=0, storage_scope="local") -sch.reverse_compute_at(block=b39, loop=l38, preserve_unit_loops=True, index=-1) -b40 = sch.cache_read(block=b6, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b40, loop=l33, preserve_unit_loops=True, index=-1) -l41, l42, l43, l44, l45, l46 = sch.get_loops(block=b40) -l47 = sch.fuse(l45, l46, preserve_unit_iters=True) -v48 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b40, ann_key="meta_schedule.cooperative_fetch", ann_val=v48) -b49 = sch.cache_read(block=b6, read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b49, loop=l33, preserve_unit_loops=True, index=-1) -l50, l51, l52, l53, l54, l55 = sch.get_loops(block=b49) -l56 = sch.fuse(l54, l55, preserve_unit_iters=True) -v57 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b49, ann_key="meta_schedule.cooperative_fetch", ann_val=v57)""".split( - "\n" - ) - ] - - _test_multi_level_tiling_dense_dp4a(m, n, k, "int8", "int32", expected) - - -def test_multi_level_tiling_dense_dp4a_non_tensorizable(): - _test_multi_level_tiling_dense_dp4a(128, 128, 128, "float32", "float32", [""]) - _test_multi_level_tiling_dense_dp4a(127, 127, 127, "int8", "int32", [""]) - - -def test_cuda_tensor_core_matmul_relu(): - m = n = k = 128 - target = Target("cuda", host="llvm") - ctx = _create_context( - create_prim_func( - te_workload.matmul_relu( - n=n, - m=m, - k=k, - in_dtype="float16", - out_dtype="float32", - ) - ), - target=target, - rule=[ - multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"), - auto_inline(target), - ], - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - - expected = [ - """b0 = sch.get_block(name="C", func_name="main") -b1 = sch.get_block(name="compute", func_name="main") -sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") -b2 = sch.reindex(block=b0, buffer=("write", 0)) -b3 = sch.reindex(block=b0, buffer=("read", 0)) -b4 = sch.reindex(block=b0, buffer=("read", 1)) -sch.transform_layout(block=b0, buffer=("read", 0), 
index_map=lambda i, k: (i, k, )) -sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, )) -sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, )) -sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, )) -l5, l6, l7 = sch.get_loops(block=b0) -l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True) -l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True) -l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) -l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0) -sch.reorder(l16, l18, l13, l11, l9) -b20 = sch.blockize(loop=l13) -sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32") -sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32") -sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1) -l21, l22, l23 = sch.get_loops(block=b20) -v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4) -l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True) -v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4) -l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True) -v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4) -l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True) -sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43) -l50 = sch.fuse(l29, l39, preserve_unit_iters=True) -sch.bind(loop=l50, thread_axis="blockIdx.y") -l51 = sch.fuse(l30, l40, preserve_unit_iters=True) -sch.bind(loop=l51, thread_axis="blockIdx.x") -l52 = sch.fuse(l31, l41, preserve_unit_iters=True) -sch.bind(loop=l52, thread_axis="threadIdx.y") -b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared") -sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1) -b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1) -v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) -sch.reverse_compute_inline(block=b2) -l56, l57, l58, l59, l60 = sch.get_loops(block=b54) -l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True) -l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True) -l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54) -sch.reorder(l70, l64, l62) -b72 = sch.blockize(loop=l64) -sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared") -b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1) -l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73) -l80 = sch.fuse(l78, l79, preserve_unit_iters=True) -v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b73, 
ann_key="meta_schedule.cooperative_fetch", ann_val=v81) -b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1) -l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82) -l89 = sch.fuse(l87, l88, preserve_unit_iters=True) -v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90) -b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1) -l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91) -l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True) -l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True) -l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91) -sch.reorder(l110, l102, l100) -b112 = sch.blockize(loop=l102) -sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") -b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b") -sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1) -l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113) -l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True) -l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True) -l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113) -sch.reorder(l132, l124, l122) -b134 = sch.blockize(loop=l124) -sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b") -sch.compute_inline(block=b3) -sch.compute_inline(block=b4) -sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8) -sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8) -sch.reverse_compute_inline(block=b1)""".split( - "\n" - ) - ] - check_trace(spaces, expected) - - # test multi_level_tiling_tensor_core and multi_level_tiling can be used together in order - # to use multi_level_tiling as a fallback when the workload can't be tensorized - ctx = _create_context( - create_prim_func( - te_workload.matmul_relu( - n=n, - m=m, - k=k, - in_dtype="float16", - out_dtype="float32", - ) - ), - target=target, - rule=[ - multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"), - multi_level_tiling(target=target), - auto_inline(target), - ], - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - check_trace(spaces, expected) - - -def test_cuda_tensor_core_software_pipeline_matmul_relu(): - m = n = k = 128 - target = Target("cuda", host="llvm") - ctx = _create_context( - create_prim_func( - te_workload.matmul_relu( - n=n, - m=m, - k=k, - in_dtype="float16", - out_dtype="float32", - ) - ), - target=target, - rule=[ - multi_level_tiling_tensor_core( - target=target, write_reuse_scope="shared", use_software_pipeline=True - ), - auto_inline(target), - ], - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - - expected = [ - """b0 = sch.get_block(name="C", func_name="main") -b1 = sch.get_block(name="compute", func_name="main") -sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") -b2 = sch.reindex(block=b0, buffer=("write", 0)) -b3 = sch.reindex(block=b0, buffer=("read", 0)) -b4 = 
sch.reindex(block=b0, buffer=("read", 1)) -sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, )) -sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, )) -sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, )) -sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b4, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, )) -l5, l6, l7 = sch.get_loops(block=b0) -l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True) -l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True) -l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) -l14, l15, l16, l17, l18, l19 = sch.get_loops(block=b0) -sch.reorder(l16, l18, l13, l11, l9) -b20 = sch.blockize(loop=l13) -sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32") -sch.annotate(block_or_loop=b20, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32") -sch.annotate(block_or_loop=b20, ann_key="warp_execution", ann_val=1) -l21, l22, l23 = sch.get_loops(block=b20) -v24, v25, v26, v27, v28 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4) -l29, l30, l31, l32, l33 = sch.split(loop=l21, factors=[v24, v25, v26, v27, v28], preserve_unit_iters=True) -v34, v35, v36, v37, v38 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4) -l39, l40, l41, l42, l43 = sch.split(loop=l22, factors=[v34, v35, v36, v37, v38], preserve_unit_iters=True) -v44, v45, v46 = sch.sample_perfect_tile(loop=l23, n=3, max_innermost_factor=4) -l47, l48, l49 = sch.split(loop=l23, factors=[v44, v45, v46], preserve_unit_iters=True) -sch.reorder(l29, l39, l30, l40, l31, l41, l47, l48, l32, l42, l49, l33, l43) -l50 = sch.fuse(l29, l39, preserve_unit_iters=True) -sch.bind(loop=l50, thread_axis="blockIdx.y") -l51 = sch.fuse(l30, l40, preserve_unit_iters=True) -sch.bind(loop=l51, thread_axis="blockIdx.x") -l52 = sch.fuse(l31, l41, preserve_unit_iters=True) -sch.bind(loop=l52, thread_axis="threadIdx.y") -b53 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="shared") -sch.reverse_compute_at(block=b53, loop=l51, preserve_unit_loops=True, index=-1) -b54 = sch.cache_write(block=b20, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b54, loop=l52, preserve_unit_loops=True, index=-1) -v55 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b53, ann_key="meta_schedule.cooperative_fetch", ann_val=v55) -sch.reverse_compute_inline(block=b2) -l56, l57, l58, l59, l60 = sch.get_loops(block=b54) -l61, l62 = sch.split(loop=l60, factors=[None, 16], preserve_unit_iters=True) -l63, l64 = sch.split(loop=l59, factors=[None, 16], preserve_unit_iters=True) -l65, l66, l67, l68, l69, l70, l71 = sch.get_loops(block=b54) -sch.reorder(l70, l64, l62) -b72 = sch.blockize(loop=l64) -sch.annotate(block_or_loop=b72, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared") -b73 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b73, loop=l47, preserve_unit_loops=True, index=-1) -l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b73) -l80 = sch.fuse(l78, l79, preserve_unit_iters=True) -v81 = 
sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b73, ann_key="meta_schedule.cooperative_fetch", ann_val=v81) -b82 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b82, loop=l47, preserve_unit_loops=True, index=-1) -l83, l84, l85, l86, l87, l88 = sch.get_loops(block=b82) -l89 = sch.fuse(l87, l88, preserve_unit_iters=True) -v90 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b82, ann_key="meta_schedule.cooperative_fetch", ann_val=v90) -b91 = sch.cache_read(block=b20, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b91, loop=l48, preserve_unit_loops=True, index=-1) -l92, l93, l94, l95, l96, l97, l98 = sch.get_loops(block=b91) -l99, l100 = sch.split(loop=l98, factors=[None, 16], preserve_unit_iters=True) -l101, l102 = sch.split(loop=l97, factors=[None, 16], preserve_unit_iters=True) -l103, l104, l105, l106, l107, l108, l109, l110, l111 = sch.get_loops(block=b91) -sch.reorder(l110, l102, l100) -b112 = sch.blockize(loop=l102) -sch.annotate(block_or_loop=b112, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") -b113 = sch.cache_read(block=b20, read_buffer_index=1, storage_scope="wmma.matrix_b") -sch.compute_at(block=b113, loop=l48, preserve_unit_loops=True, index=-1) -l114, l115, l116, l117, l118, l119, l120 = sch.get_loops(block=b113) -l121, l122 = sch.split(loop=l120, factors=[None, 16], preserve_unit_iters=True) -l123, l124 = sch.split(loop=l119, factors=[None, 16], preserve_unit_iters=True) -l125, l126, l127, l128, l129, l130, l131, l132, l133 = sch.get_loops(block=b113) -sch.reorder(l132, l124, l122) -b134 = sch.blockize(loop=l124) -sch.annotate(block_or_loop=b134, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b") -sch.compute_inline(block=b3) -sch.compute_inline(block=b4) -sch.storage_align(block=b73, buffer_index=0, axis=-2, factor=32, offset=8) -sch.storage_align(block=b82, buffer_index=0, axis=-2, factor=32, offset=8) -sch.annotate(block_or_loop=b73, ann_key="tir.manifest_shared_memory_local_stage", ann_val=1) -sch.annotate(block_or_loop=b73, ann_key="double_buffer_scope", ann_val=0) -sch.annotate(block_or_loop=b82, ann_key="tir.manifest_shared_memory_local_stage", ann_val=1) -sch.annotate(block_or_loop=b82, ann_key="double_buffer_scope", ann_val=0) -sch.annotate(block_or_loop=l48, ann_key="software_pipeline_stage", ann_val=[0, 0, 1]) -sch.annotate(block_or_loop=l48, ann_key="software_pipeline_order", ann_val=[0, 1, 2]) -sch.annotate(block_or_loop=l47, ann_key="software_pipeline_stage", ann_val=[0, 0, 0, 0, 0, 1, 1]) -sch.annotate(block_or_loop=l47, ann_key="software_pipeline_order", ann_val=[0, 3, 1, 4, 5, 2, 6]) -sch.reverse_compute_inline(block=b1)""".split( - "\n" - ) - ] - check_trace(spaces, expected) - - -def test_cuda_tensor_core_matmul_relu_global(): - m = n = k = 128 - target = Target("cuda", host="llvm") - workload = create_prim_func( - te_workload.matmul_relu( - n=n, - m=m, - k=k, - in_dtype="float16", - out_dtype="float32", - ), - ) - ctx = _create_context( - workload, - target=target, - rule=[ - multi_level_tiling_tensor_core(target=target, write_reuse_scope="global"), - auto_inline(target), - ], - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - - expected = [ - """b0 = sch.get_block(name="C", func_name="main") -sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", 
ann_val="SSSRRSRS") -b1 = sch.reindex(block=b0, buffer=("write", 0)) -b2 = sch.reindex(block=b0, buffer=("read", 0)) -b3 = sch.reindex(block=b0, buffer=("read", 1)) -sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, )) -sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (k, j, )) -sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, )) -sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, )) -l4, l5, l6 = sch.get_loops(block=b0) -l7, l8 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True) -l9, l10 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) -l11, l12 = sch.split(loop=l4, factors=[None, 16], preserve_unit_iters=True) -l13, l14, l15, l16, l17, l18 = sch.get_loops(block=b0) -sch.reorder(l15, l17, l12, l10, l8) -b19 = sch.blockize(loop=l12) -sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32") -sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32") -sch.annotate(block_or_loop=b19, ann_key="warp_execution", ann_val=1) -l20, l21, l22 = sch.get_loops(block=b19) -v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=4) -l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True) -v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4) -l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True) -v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=4) -l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45], preserve_unit_iters=True) -sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42) -l49 = sch.fuse(l28, l38, preserve_unit_iters=True) -sch.bind(loop=l49, thread_axis="blockIdx.y") -l50 = sch.fuse(l29, l39, preserve_unit_iters=True) -sch.bind(loop=l50, thread_axis="blockIdx.x") -l51 = sch.fuse(l30, l40, preserve_unit_iters=True) -sch.bind(loop=l51, thread_axis="threadIdx.y") -b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1) -sch.reverse_compute_inline(block=b1) -l53, l54, l55, l56, l57 = sch.get_loops(block=b52) -l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True) -l60, l61 = sch.split(loop=l56, factors=[None, 16], preserve_unit_iters=True) -l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b52) -sch.reorder(l67, l61, l59) -b69 = sch.blockize(loop=l61) -sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global") -b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1) -l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70) -l77 = sch.fuse(l75, l76, preserve_unit_iters=True) -v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) -b79 = sch.cache_read(block=b19, read_buffer_index=1, 
storage_scope="shared") -sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1) -l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79) -l86 = sch.fuse(l84, l85, preserve_unit_iters=True) -v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87) -b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1) -l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88) -l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True) -l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True) -l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b88) -sch.reorder(l107, l99, l97) -b109 = sch.blockize(loop=l99) -sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") -b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b") -sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, index=-1) -l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110) -l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True) -l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True) -l122, l123, l124, l125, l126, l127, l128, l129, l130 = sch.get_loops(block=b110) -sch.reorder(l129, l121, l119) -b131 = sch.blockize(loop=l121) -sch.annotate(block_or_loop=b131, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b") -sch.compute_inline(block=b2) -sch.compute_inline(block=b3) -sch.storage_align(block=b70, buffer_index=0, axis=-2, factor=32, offset=8) -sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".split( - "\n" - ) - ] - check_trace(spaces, expected) - - ctx = _create_context( - workload, - target=target, - rule=[ - multi_level_tiling_tensor_core( - target=target, write_reuse_scope="global", trans_b=[False, True] - ), - auto_inline(target), - ], - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 2 - - expected = [ - expected[0], - """b0 = sch.get_block(name="C", func_name="main") -sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") -b1 = sch.reindex(block=b0, buffer=("write", 0)) -b2 = sch.reindex(block=b0, buffer=("read", 0)) -b3 = sch.reindex(block=b0, buffer=("read", 1)) -sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda i, k: (i, k, )) -sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda j, k: (j, k, )) -sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda i, j: (i, j, )) -sch.transform_block_layout(block=b1, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b2, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b3, index_map=lambda i, j, k: (i, j, k, )) -sch.transform_block_layout(block=b0, index_map=lambda i, j, k: (i, j, k, )) -l4, l5, l6 = sch.get_loops(block=b0) -l7, l8 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True) -l9, l10 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) -l11, l12 = sch.split(loop=l4, factors=[None, 16], preserve_unit_iters=True) -l13, l14, l15, l16, l17, l18 = sch.get_loops(block=b0) -sch.reorder(l15, l17, l12, l10, l8) -b19 = sch.blockize(loop=l12) -sch.annotate(block_or_loop=b19, 
ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32_trans") -sch.annotate(block_or_loop=b19, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32") -sch.annotate(block_or_loop=b19, ann_key="warp_execution", ann_val=1) -l20, l21, l22 = sch.get_loops(block=b19) -v23, v24, v25, v26, v27 = sch.sample_perfect_tile(loop=l20, n=5, max_innermost_factor=4) -l28, l29, l30, l31, l32 = sch.split(loop=l20, factors=[v23, v24, v25, v26, v27], preserve_unit_iters=True) -v33, v34, v35, v36, v37 = sch.sample_perfect_tile(loop=l21, n=5, max_innermost_factor=4) -l38, l39, l40, l41, l42 = sch.split(loop=l21, factors=[v33, v34, v35, v36, v37], preserve_unit_iters=True) -v43, v44, v45 = sch.sample_perfect_tile(loop=l22, n=3, max_innermost_factor=4) -l46, l47, l48 = sch.split(loop=l22, factors=[v43, v44, v45], preserve_unit_iters=True) -sch.reorder(l28, l38, l29, l39, l30, l40, l46, l47, l31, l41, l48, l32, l42) -l49 = sch.fuse(l28, l38, preserve_unit_iters=True) -sch.bind(loop=l49, thread_axis="blockIdx.y") -l50 = sch.fuse(l29, l39, preserve_unit_iters=True) -sch.bind(loop=l50, thread_axis="blockIdx.x") -l51 = sch.fuse(l30, l40, preserve_unit_iters=True) -sch.bind(loop=l51, thread_axis="threadIdx.y") -b52 = sch.cache_write(block=b19, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b52, loop=l51, preserve_unit_loops=True, index=-1) -sch.reverse_compute_inline(block=b1) -l53, l54, l55, l56, l57 = sch.get_loops(block=b52) -l58, l59 = sch.split(loop=l57, factors=[None, 16], preserve_unit_iters=True) -l60, l61 = sch.split(loop=l56, factors=[None, 16], preserve_unit_iters=True) -l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b52) -sch.reorder(l67, l61, l59) -b69 = sch.blockize(loop=l61) -sch.annotate(block_or_loop=b69, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_global") -b70 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b70, loop=l46, preserve_unit_loops=True, index=-1) -l71, l72, l73, l74, l75, l76 = sch.get_loops(block=b70) -l77 = sch.fuse(l75, l76, preserve_unit_iters=True) -v78 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b70, ann_key="meta_schedule.cooperative_fetch", ann_val=v78) -b79 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b79, loop=l46, preserve_unit_loops=True, index=-1) -l80, l81, l82, l83, l84, l85 = sch.get_loops(block=b79) -l86 = sch.fuse(l84, l85, preserve_unit_iters=True) -v87 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b79, ann_key="meta_schedule.cooperative_fetch", ann_val=v87) -b88 = sch.cache_read(block=b19, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b88, loop=l47, preserve_unit_loops=True, index=-1) -l89, l90, l91, l92, l93, l94, l95 = sch.get_loops(block=b88) -l96, l97 = sch.split(loop=l95, factors=[None, 16], preserve_unit_iters=True) -l98, l99 = sch.split(loop=l94, factors=[None, 16], preserve_unit_iters=True) -l100, l101, l102, l103, l104, l105, l106, l107, l108 = sch.get_loops(block=b88) -sch.reorder(l107, l99, l97) -b109 = sch.blockize(loop=l99) -sch.annotate(block_or_loop=b109, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a") -b110 = sch.cache_read(block=b19, read_buffer_index=1, storage_scope="wmma.matrix_b") -sch.compute_at(block=b110, loop=l47, preserve_unit_loops=True, 
index=-1) -l111, l112, l113, l114, l115, l116, l117 = sch.get_loops(block=b110) -l118, l119 = sch.split(loop=l117, factors=[None, 16], preserve_unit_iters=True) -l120, l121 = sch.split(loop=l116, factors=[None, 16], preserve_unit_iters=True) -l122, l123, l124, l125, l126, l127, l128, l129, l130 = sch.get_loops(block=b110) -sch.reorder(l129, l121, l119) -b131 = sch.blockize(loop=l121) -sch.annotate(block_or_loop=b131, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b_trans") -sch.compute_inline(block=b2) -sch.compute_inline(block=b3) -sch.storage_align(block=b70, buffer_index=0, axis=-2, factor=32, offset=8) -sch.storage_align(block=b79, buffer_index=0, axis=-2, factor=32, offset=8)""".split( - "\n" - ), - ] - check_trace(spaces, expected) - - -def test_multi_level_tiling_non_tensorizable(): - # expected to do nothing on non-tensorizable workloads - m = n = k = 128 - target = Target("cuda", host="llvm") - ctx = _create_context( - create_prim_func( - # dtype doesn't match tensor intrin - te_workload.matmul_relu( - n=n, - m=m, - k=k, - ) - ), - target=target, - rule=multi_level_tiling_tensor_core(target=target, write_reuse_scope="global"), - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - - expected = [ - "", # expected to do nothing when the workload can't be tensorized - ] - check_trace(spaces, expected) - - -def test_cuda_tensor_core_conv2d(): - target = Target("cuda", host="llvm") - workload = create_prim_func( - te_workload.conv2d_nhwc( - N=1, - H=16, - W=16, - CI=32, - CO=32, - kernel_size=3, - stride=1, - padding=1, - in_dtype="float16", - out_dtype="float32", - ) - ) - ctx = _create_context( - workload, - target=target, - rule=multi_level_tiling_tensor_core(target=target, write_reuse_scope="shared"), - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 - - expected = [ - """b0 = sch.get_block(name="conv2d_nhwc", func_name="main") -sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS") -b1 = sch.reindex(block=b0, buffer=("write", 0)) -b2 = sch.reindex(block=b0, buffer=("read", 0)) -b3 = sch.reindex(block=b0, buffer=("read", 1)) -sch.transform_layout(block=b0, buffer=("read", 0), index_map=lambda h, w, rh, rw, rc: (((h*16) + w), (((rh*96) + (rw*32)) + rc), )) -sch.transform_layout(block=b0, buffer=("read", 1), index_map=lambda co, rh, rw, rc: ((((rh*96) + (rw*32)) + rc), co, )) -sch.transform_layout(block=b0, buffer=("write", 0), index_map=lambda h, w, co: (((h*16) + w), co, )) -sch.transform_block_layout(block=b1, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), )) -sch.transform_block_layout(block=b2, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), )) -sch.transform_block_layout(block=b3, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), )) -sch.transform_block_layout(block=b0, index_map=lambda n, h, w, co, rh, rw, rc: (n, ((h*16) + w), co, (((rh*96) + (rw*32)) + rc), )) -l4, l5, l6, l7 = sch.get_loops(block=b0) -l8, l9 = sch.split(loop=l7, factors=[None, 16], preserve_unit_iters=True) -l10, l11 = sch.split(loop=l6, factors=[None, 16], preserve_unit_iters=True) -l12, l13 = sch.split(loop=l5, factors=[None, 16], preserve_unit_iters=True) -l14, l15, l16, l17, l18, l19, l20 = sch.get_loops(block=b0) -sch.reorder(l17, l19, l13, l11, l9) -b21 = sch.blockize(loop=l13) -sch.annotate(block_or_loop=b21, 
ann_key="meta_schedule.auto_tensorize", ann_val="wmma_sync_16x16x16_f16f16f32") -sch.annotate(block_or_loop=b21, ann_key="meta_schedule.auto_tensorize_init", ann_val="wmma_fill_16x16x16_f32") -sch.annotate(block_or_loop=b21, ann_key="warp_execution", ann_val=1) -l22, l23, l24, l25 = sch.get_loops(block=b21) -v26, v27, v28, v29, v30 = sch.sample_perfect_tile(loop=l22, n=5, max_innermost_factor=4) -l31, l32, l33, l34, l35 = sch.split(loop=l22, factors=[v26, v27, v28, v29, v30], preserve_unit_iters=True) -v36, v37, v38, v39, v40 = sch.sample_perfect_tile(loop=l23, n=5, max_innermost_factor=4) -l41, l42, l43, l44, l45 = sch.split(loop=l23, factors=[v36, v37, v38, v39, v40], preserve_unit_iters=True) -v46, v47, v48, v49, v50 = sch.sample_perfect_tile(loop=l24, n=5, max_innermost_factor=4) -l51, l52, l53, l54, l55 = sch.split(loop=l24, factors=[v46, v47, v48, v49, v50], preserve_unit_iters=True) -v56, v57, v58 = sch.sample_perfect_tile(loop=l25, n=3, max_innermost_factor=4) -l59, l60, l61 = sch.split(loop=l25, factors=[v56, v57, v58], preserve_unit_iters=True) -sch.reorder(l31, l41, l51, l32, l42, l52, l33, l43, l53, l59, l60, l34, l44, l54, l61, l35, l45, l55) -l62 = sch.fuse(l31, l41, l51, preserve_unit_iters=True) -sch.bind(loop=l62, thread_axis="blockIdx.y") -l63 = sch.fuse(l32, l42, l52, preserve_unit_iters=True) -sch.bind(loop=l63, thread_axis="blockIdx.x") -l64 = sch.fuse(l33, l43, l53, preserve_unit_iters=True) -sch.bind(loop=l64, thread_axis="threadIdx.y") -b65 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="shared") -sch.reverse_compute_at(block=b65, loop=l63, preserve_unit_loops=True, index=-1) -b66 = sch.cache_write(block=b21, write_buffer_index=0, storage_scope="wmma.accumulator") -sch.reverse_compute_at(block=b66, loop=l64, preserve_unit_loops=True, index=-1) -v67 = sch.sample_categorical(candidates=[1, 2, 3, 4], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b65, ann_key="meta_schedule.cooperative_fetch", ann_val=v67) -sch.reverse_compute_inline(block=b1) -l68, l69, l70, l71, l72 = sch.get_loops(block=b66) -l73, l74 = sch.split(loop=l72, factors=[None, 16], preserve_unit_iters=True) -l75, l76 = sch.split(loop=l71, factors=[None, 16], preserve_unit_iters=True) -l77, l78, l79, l80, l81, l82, l83 = sch.get_loops(block=b66) -sch.reorder(l82, l76, l74) -b84 = sch.blockize(loop=l76) -sch.annotate(block_or_loop=b84, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_store_16x16x16_f32_shared") -b85 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="shared") -sch.compute_at(block=b85, loop=l59, preserve_unit_loops=True, index=-1) -l86, l87, l88, l89, l90, l91 = sch.get_loops(block=b85) -l92 = sch.fuse(l90, l91, preserve_unit_iters=True) -v93 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b85, ann_key="meta_schedule.cooperative_fetch", ann_val=v93) -b94 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="shared") -sch.compute_at(block=b94, loop=l59, preserve_unit_loops=True, index=-1) -l95, l96, l97, l98, l99, l100 = sch.get_loops(block=b94) -l101 = sch.fuse(l99, l100, preserve_unit_iters=True) -v102 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25]) -sch.annotate(block_or_loop=b94, ann_key="meta_schedule.cooperative_fetch", ann_val=v102) -b103 = sch.cache_read(block=b21, read_buffer_index=0, storage_scope="wmma.matrix_a") -sch.compute_at(block=b103, loop=l60, preserve_unit_loops=True, index=-1) -l104, l105, l106, l107, l108, l109, l110 = 
sch.get_loops(block=b103)
-l111, l112 = sch.split(loop=l110, factors=[None, 16], preserve_unit_iters=True)
-l113, l114 = sch.split(loop=l109, factors=[None, 16], preserve_unit_iters=True)
-l115, l116, l117, l118, l119, l120, l121, l122, l123 = sch.get_loops(block=b103)
-sch.reorder(l122, l114, l112)
-b124 = sch.blockize(loop=l114)
-sch.annotate(block_or_loop=b124, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_a")
-b125 = sch.cache_read(block=b21, read_buffer_index=1, storage_scope="wmma.matrix_b")
-sch.compute_at(block=b125, loop=l60, preserve_unit_loops=True, index=-1)
-l126, l127, l128, l129, l130, l131, l132 = sch.get_loops(block=b125)
-l133, l134 = sch.split(loop=l132, factors=[None, 16], preserve_unit_iters=True)
-l135, l136 = sch.split(loop=l131, factors=[None, 16], preserve_unit_iters=True)
-l137, l138, l139, l140, l141, l142, l143, l144, l145 = sch.get_loops(block=b125)
-sch.reorder(l144, l136, l134)
-b146 = sch.blockize(loop=l136)
-sch.annotate(block_or_loop=b146, ann_key="meta_schedule.auto_tensorize", ann_val="wmma_load_16x16x16_f16_b")
-sch.compute_inline(block=b2)
-sch.compute_inline(block=b3)
-sch.storage_align(block=b85, buffer_index=0, axis=-2, factor=32, offset=8)
-sch.storage_align(block=b94, buffer_index=0, axis=-2, factor=32, offset=8)""".split(
-            "\n"
-        )
-    ]
-    check_trace(spaces, expected)
-
-    # Test that adding inapplicable tensor intrinsics doesn't change the search space
-    ctx = _create_context(
-        workload,
-        target,
-        multi_level_tiling_tensor_core(
-            target=target,
-            write_reuse_scope="shared",
-            in_dtype="float16",
-            out_dtype=["float16", "float32"],
-        ),
-    )
-    spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
-    assert len(spaces) == 1
-    check_trace(spaces, expected)
-
-
-if __name__ == "__main__":
-    tvm.testing.main()
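
The deleted tests above compare printed trace strings, which is exactly the brittleness this patch removes: any change in how a schedule prints breaks the test even when the resulting TIR is identical, whereas the replacement tests compare TIR modules structurally. A minimal sketch of the structural-equality primitive the new tests build on; the 128x128x128 matmul here is a hypothetical workload chosen only to keep the example self-contained:

    import tvm
    from tvm import te

    def make_matmul():
        # Build the same workload twice; structural equality is insensitive
        # to printing details and to the identities of bound variables.
        A = te.placeholder((128, 128), name="A")
        B = te.placeholder((128, 128), name="B")
        k = te.reduce_axis((0, 128), name="k")
        C = te.compute((128, 128), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
        return te.create_prim_func([A, B, C])

    tvm.ir.assert_structural_equal(make_matmul(), make_matmul())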
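In the VNNI traces above, the output-channel loop is split by 16 and the innermost input-channel loop by 4 before sch.blockize, because the dot_16x4_vnni intrinsic consumes exactly that shape: sixteen int32 lanes, each accumulating a 4-way uint8 x int8 dot product. A NumPy sketch of the intended reference semantics (array contents are arbitrary):

    import numpy as np

    data = np.random.randint(0, 256, (4,), dtype=np.uint8)         # 4 inner input channels
    weight = np.random.randint(-128, 128, (16, 4), dtype=np.int8)  # 16 output x 4 input channels
    acc = np.zeros((16,), dtype=np.int32)
    # Each of the 16 lanes accumulates a 4-element dot product, widened to int32.
    acc += np.sum(weight.astype(np.int32) * data.astype(np.int32), axis=1, dtype=np.int32)
    assert np.array_equal(acc, weight.astype(np.int32) @ data.astype(np.int32))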
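The DP4a tests likewise split the reduction loop by a factor of 4 before blockizing: the dp4a intrinsic performs one 4-way int8 dot product accumulated into int32. A sketch of that semantics with arbitrary example values:

    import numpy as np

    a = np.array([1, -2, 3, 4], dtype=np.int8)
    b = np.array([5, 6, -7, 8], dtype=np.int8)
    acc = np.int32(10)
    # dp4a: acc += dot(a, b), with the products widened to int32 before summing.
    acc = acc + np.sum(a.astype(np.int32) * b.astype(np.int32), dtype=np.int32)
    assert acc == 14  # 10 + (5 - 12 - 21 + 32)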
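The tensor-core traces split each of the three matmul loops by 16 and blockize the inner fragment for wmma_sync_16x16x16_f16f16f32, which is presumably why the sample_perfect_tile calls in those traces cap max_innermost_factor at 4 rather than 64. One such 16x16x16 fragment computes C += A @ B on float16 tiles with a float32 accumulator:

    import numpy as np

    A = np.random.uniform(-1, 1, (16, 16)).astype(np.float16)
    B = np.random.uniform(-1, 1, (16, 16)).astype(np.float16)
    C = np.zeros((16, 16), dtype=np.float32)
    # fp16 inputs, fp32 accumulation, as in the f16f16f32 intrinsic name.
    C += A.astype(np.float32) @ B.astype(np.float32)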