[MetaSchedule][Test] MLT uses SEqual tests #12805

Merged
merged 1 commit on Sep 16, 2022
138 changes: 19 additions & 119 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -15,122 +15,22 @@
# specific language governing permissions and limitations
# under the License.
"""Default schedule rules"""
from typing import List, Union

from tvm.meta_schedule.schedule_rule import (
    AutoInline,
    MultiLevelTiling,
    MultiLevelTilingTensorCore,
    ReuseType,
    ScheduleRule,
)
from tvm.target import Target


def auto_inline(target: Target) -> ScheduleRule:
    """Default schedule rules for auto inline"""
    if target.kind.name == "llvm":
        return AutoInline(
            into_producer=False,
            into_consumer=True,
            inline_const_tensor=True,
            disallow_if_then_else=True,
            require_injective=True,
            require_ordered=True,
            disallow_op=["tir.exp"],
        )
    if target.kind.name == "cuda":
        return AutoInline(
            into_producer=True,
            into_consumer=True,
            inline_const_tensor=True,
            disallow_if_then_else=False,
            require_injective=False,
            require_ordered=False,
            disallow_op=None,
        )
    raise NotImplementedError(f"{target.kind.name} is not supported")


def multi_level_tiling(target: Target) -> ScheduleRule:
    """Default schedule rules with multi-level tiling and reuse"""
    if target.kind.name == "llvm":
        return MultiLevelTiling(
            structure="SSRSRS",
            tile_binds=None,
            max_innermost_factor=64,
            vector_load_lens=None,
            reuse_read=None,
            reuse_write=ReuseType(
                req="may",
                levels=[1, 2],
                scope="global",
            ),
        )
    if target.kind.name == "cuda":
        return MultiLevelTiling(
            structure="SSSRRSRS",
            tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
            max_innermost_factor=64,
            vector_load_lens=[1, 2, 3, 4, 8, 16],
            reuse_read=ReuseType(
                req="must",
                levels=[4],
                scope="shared",
            ),
            reuse_write=ReuseType(
                req="must",
                levels=[3],
                scope="local",
            ),
        )
    raise NotImplementedError(f"{target.kind.name} is not supported")


def multi_level_tiling_tensor_core(
    target: Target,
    write_reuse_scope: str = "shared",
    in_dtype: Union[str, List[str]] = "float16",
    out_dtype: Union[str, List[str]] = "float32",
    trans_b: Union[bool, List[bool]] = False,
    use_software_pipeline: bool = False,
) -> ScheduleRule:
    """Default schedule rules with multi-level tiling and reuse for tensor core"""
    assert write_reuse_scope in ["shared", "global"]
    if not isinstance(in_dtype, list):
        in_dtype = [in_dtype]
    if not isinstance(out_dtype, list):
        out_dtype = [out_dtype]
    if not isinstance(trans_b, list):
        trans_b = [trans_b]

    if target.kind.name == "cuda":
        from tvm.tir.tensor_intrin import (  # pylint: disable=import-outside-toplevel
            cuda,
        )

        intrin_groups = [
            cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
            for _in_dtype in in_dtype
            for _out_dtype in out_dtype
            for _trans_b in trans_b
        ]
        return MultiLevelTilingTensorCore(
            intrin_groups=intrin_groups,
            structure="SSSRRSRS",
            tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
            max_innermost_factor=4,  # 64 // tensor intrin size
            vector_load_lens=[1, 2, 3, 4, 8, 16],
            reuse_read=ReuseType(
                req="must",
                levels=[4],
                scope="shared",
            ),
            reuse_write=ReuseType(
                req="must" if write_reuse_scope == "shared" else "no",
                levels=[2],
                scope=write_reuse_scope,
            ),
            use_software_pipeline=use_software_pipeline,
        )
    raise NotImplementedError(f"{target.kind.name} is not supported")
from typing import List, Tuple, Union

from tvm.meta_schedule import default_config
from tvm.meta_schedule.schedule_rule import ScheduleRule


def get_rules(kind: str, types: Union[type, Tuple[type, ...]]) -> List[ScheduleRule]:
    """Get default schedule rules"""
    # pylint: disable=protected-access
    if kind == "llvm":
        rules = default_config._DefaultLLVM.schedule_rules()
    elif kind == "cuda":
        rules = default_config._DefaultCUDA.schedule_rules()
    elif kind == "tensor_core":
        rules = default_config._DefaultCUDATensorCore.schedule_rules()
    else:
        raise NotImplementedError(f"{kind} is not supported")
    # pylint: enable=protected-access
    return [rule for rule in rules if isinstance(rule, types)]
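
A minimal usage sketch of the new helper (variable names are illustrative; it assumes the default rule sets contain instances of the requested types):

from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.schedule_rule import get_rules

# Take only the AutoBind rule out of the default CUDA rule set.
auto_bind_rules = get_rules("cuda", ms.schedule_rule.AutoBind)

# A tuple of types selects several rule kinds at once.
tiling_or_inline = get_rules(
    "cuda", (ms.schedule_rule.MultiLevelTiling, ms.schedule_rule.AutoInline)
)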
@@ -321,7 +321,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
  // Add local stage and double buffering
  for (int i = 0; i < 2; ++i) {
    const tir::BlockRV cache_read = state->read_reuse.at(i);
    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Bool(true));
    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Integer(1));
    sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
  }

@@ -529,7 +529,7 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorization(
                       state->intrin_group.compute_intrin);
  state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
                       state->intrin_group.init_intrin);
  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true));
  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Integer(1));
  return {std::move(state)};
}

35 changes: 18 additions & 17 deletions src/meta_schedule/utils.h
@@ -77,33 +77,34 @@ class PyLogMessage {
    // FATAL not included
  };

  PyLogMessage(const std::string& file, int lineno, PackedFunc logging_func, Level logging_level) {
    this->logging_func = logging_func;
    this->logging_level = logging_level;
  }
  explicit PyLogMessage(const char* file, int lineno, PackedFunc logging_func, Level logging_level)
      : file_(file), lineno_(lineno), logging_func_(logging_func), logging_level_(logging_level) {}

  TVM_NO_INLINE ~PyLogMessage() {
    if (this->logging_func.defined()) {
      logging_func(static_cast<int>(logging_level), stream_.str());
    if (this->logging_func_.defined()) {
      logging_func_(static_cast<int>(logging_level_), stream_.str());
    } else {
      if (logging_level == Level::INFO) {
        LOG(INFO) << stream_.str();
      } else if (logging_level == Level::WARNING) {
        LOG(WARNING) << stream_.str();
      } else if (logging_level == Level::ERROR) {
        LOG(ERROR) << stream_.str();
      } else if (logging_level == Level::DEBUG) {
        DLOG(INFO) << stream_.str();
      if (logging_level_ == Level::INFO) {
        runtime::detail::LogMessage(file_, lineno_).stream() << stream_.str();
      } else if (logging_level_ == Level::WARNING) {
        runtime::detail::LogMessage(file_, lineno_).stream() << "Warning: " << stream_.str();
      } else if (logging_level_ == Level::ERROR) {
        runtime::detail::LogMessage(file_, lineno_).stream() << "Error: " << stream_.str();
      } else if (logging_level_ == Level::DEBUG) {
        runtime::detail::LogMessage(file_, lineno_).stream() << "Debug: " << stream_.str();
      } else {
        LOG(FATAL) << stream_.str();
        runtime::detail::LogFatal(file_, lineno_).stream() << stream_.str();
      }
    }
  }
  std::ostringstream& stream() { return stream_; }

 private:
  const char* file_;
  int lineno_;
  std::ostringstream stream_;
  PackedFunc logging_func;
  Level logging_level;
  PackedFunc logging_func_;
  Level logging_level_;
};
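
For reference, when a Python-side callback is supplied, the destructor above invokes it as logging_func(<int level>, <accumulated message>). A minimal sketch of a compatible callback (purely illustrative, not part of this diff):

def example_logging_func(level: int, msg: str) -> None:
    # Receives what the C++ side sends via logging_func_(static_cast<int>(logging_level_), stream_.str()).
    print(f"[meta_schedule level={level}] {msg}")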

/*! \brief The type of the random state */
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.schedule_rule import get_rules
from tvm.meta_schedule.testing.space_generation import check_sketches
from tvm.script import tir as T
from tvm.target import Target
@@ -83,12 +84,7 @@ def elementwise_0(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3080", host="llvm"),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[
            ms.schedule_rule.AutoBind(
                max_threadblocks=256,
                thread_extents=[32, 64, 128, 256, 512, 1024],
            )
        ],
        sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind),
        task_name="test",
    ).generate_design_space()
    check_sketches(
@@ -122,12 +118,7 @@ def reduction_loop_only_0(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3080", host="llvm"),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[
            ms.schedule_rule.AutoBind(
                max_threadblocks=256,
                thread_extents=[32, 64, 128, 256, 512, 1024],
            )
        ],
        sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind),
        task_name="test",
    ).generate_design_space()
    check_sketches(
@@ -158,12 +149,7 @@ def zero_dim_add_0(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3080", host="llvm"),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[
            ms.schedule_rule.AutoBind(
                max_threadblocks=256,
                thread_extents=[32, 64, 128, 256, 512, 1024],
            )
        ],
        sch_rules=get_rules("cuda", ms.schedule_rule.AutoBind),
        task_name="test",
    ).generate_design_space()
    check_sketches(
@@ -16,9 +16,8 @@
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import tvm
from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.meta_schedule.testing.schedule_rule import auto_inline
from tvm.meta_schedule.tune_context import TuneContext
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.schedule_rule import get_rules
from tvm.script import tir as T
from tvm.target import Target

@@ -340,10 +339,10 @@ def main(T_full: T.Buffer[(1, 12, 4096), "int64"]) -> None:


def _create_context(mod, target, rule):
    ctx = TuneContext(
    ctx = ms.TuneContext(
        mod=mod,
        target=target,
        space_generator=PostOrderApply(),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[rule],
        task_name="test",
    )
@@ -356,7 +355,7 @@ def test_inline_consumer_chain():
    ctx = _create_context(
        mod=mod,
        target=target,
        rule=auto_inline(target=target),
        rule=get_rules("llvm", ms.schedule_rule.AutoInline)[0],
    )
    (space,) = ctx.space_generator.generate_design_space(mod=mod)
    tvm.ir.assert_structural_equal(lhs=space.mod, rhs=Conv2DBiasBnReLUInlined)
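
Condensed, the SEqual-style check that these tests now share looks like the sketch below (mod, expected_mod, target, and rule are placeholders, and exactly one design space is assumed):

import tvm
from tvm import meta_schedule as ms

def _assert_design_space(mod, expected_mod, target, rule):
    ctx = ms.TuneContext(
        mod=mod,
        target=target,
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[rule],
        task_name="test",
    )
    # The unpacking asserts that the rule produces exactly one candidate space.
    (space,) = ctx.space_generator.generate_design_space(mod=mod)
    # Structural equality (SEqual) instead of comparing printed TIR text.
    tvm.ir.assert_structural_equal(lhs=space.mod, rhs=expected_mod)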
@@ -368,7 +367,7 @@ def test_inline_into_cache():
    ctx = _create_context(
        mod=mod,
        target=target,
        rule=auto_inline(target=target),
        rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0],
    )
    (space,) = ctx.space_generator.generate_design_space(mod=mod)
    tvm.ir.assert_structural_equal(lhs=space.mod, rhs=MultiLevelTiledConv2DAfterInline)
@@ -380,7 +379,7 @@ def test_inline_into_multiple_consumers():
    ctx = _create_context(
        mod=mod,
        target=target,
        rule=auto_inline(target=target),
        rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0],
    )
    (space,) = ctx.space_generator.generate_design_space(mod=mod)
    tvm.ir.assert_structural_equal(lhs=space.mod, rhs=SoftmaxAfterInline)
@@ -392,7 +391,7 @@ def test_inline_pure_spatial():
    ctx = _create_context(
        mod=mod,
        target=target,
        rule=auto_inline(target=target),
        rule=get_rules("llvm", ms.schedule_rule.AutoInline)[0],
    )
    (space,) = ctx.space_generator.generate_design_space(mod=mod)
    tvm.ir.assert_structural_equal(lhs=space.mod, rhs=AfterPureSpatial)
@@ -404,7 +403,7 @@ def test_inline_constant_tensor():
    ctx = _create_context(
        mod=mod,
        target=target,
        rule=auto_inline(target=target),
        rule=get_rules("cuda", ms.schedule_rule.AutoInline)[0],
    )
    (space,) = ctx.space_generator.generate_design_space(mod=mod)
    tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer)
@@ -19,6 +19,7 @@
import tvm
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.testing.schedule_rule import get_rules
from tvm.meta_schedule.testing.space_generation import check_sketches
from tvm.script import tir as T
from tvm.target import Target
@@ -283,9 +284,7 @@ def softmax_mn_3(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[
            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
        ],
        sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
        task_name="test",
    ).generate_design_space()
    check_sketches(
@@ -481,9 +480,7 @@ def softmax_mn_after_inline_3(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[
            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
        ],
        sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
        task_name="test",
    ).generate_design_space()
    check_sketches(
@@ -559,9 +556,7 @@ def batch_norm_bmn_1(A: T.Buffer[(1, 512, 512), "float32"], D: T.Buffer[1, "floa
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[
            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
        ],
        sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
        task_name="test",
    ).generate_design_space()
    check_sketches(
@@ -657,9 +652,7 @@ def argmax_1(
        mod=mod,
        target=Target("nvidia/geforce-rtx-3090", host="llvm"),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[
            ms.schedule_rule.CrossThreadReduction(thread_extents=[4, 8, 16, 32, 64, 128, 256, 512])
        ],
        sch_rules=get_rules("cuda", ms.schedule_rule.CrossThreadReduction),
        task_name="test",
    ).generate_design_space()
    check_sketches(