[MetaSchedule][Test] MLT uses SEqual tests
This PR finishes the migration from `check_trace` (a string-based equality
check on the TIR trace) to `check_sketch` (an SEqual-based equality check on
the TIR itself). We split the multi-level-tiling tests into 3 files:
- Plain multi-level tiling without any intrinsics
- Multi-level tiling with intrinsics like VNNI and DP4a
- Multi-level tiling with TensorCore, which requires different handling

In addition, we cleaned up the testing folder and removed several methods
that are no longer useful for unit tests.
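
For context (not part of the commit): the core of an SEqual-based check is
`tvm.ir.assert_structural_equal` on the scheduled TIR, instead of comparing
the printed trace against a golden string. A minimal sketch on a toy
workload; `before` and `expected` are illustrative names, not the actual
test fixtures:

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def before(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128,), "float32")
    B = T.match_buffer(b, (128,), "float32")
    for i in T.serial(128):
        with T.block("B"):
            vi = T.axis.spatial(128, i)
            B[vi] = A[vi] * 2.0


@T.prim_func
def expected(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (128,), "float32")
    B = T.match_buffer(b, (128,), "float32")
    for i_0, i_1 in T.grid(8, 16):
        with T.block("B"):
            vi = T.axis.spatial(128, i_0 * 16 + i_1)
            B[vi] = A[vi] * 2.0


sch = tvm.tir.Schedule(before)
(i,) = sch.get_loops(sch.get_block("B"))
sch.split(i, factors=[None, 16])

# SEqual-based check: compare the TIR itself, structurally.
tvm.ir.assert_structural_equal(sch.mod["main"], expected)
```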
junrushao committed Sep 15, 2022
1 parent 5b43c62 commit 225981f
Showing 11 changed files with 2,918 additions and 1,383 deletions.
138 changes: 19 additions & 119 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -15,122 +15,22 @@
 # specific language governing permissions and limitations
 # under the License.
 """Default schedule rules"""
-from typing import List, Union
-
-from tvm.meta_schedule.schedule_rule import (
-    AutoInline,
-    MultiLevelTiling,
-    MultiLevelTilingTensorCore,
-    ReuseType,
-    ScheduleRule,
-)
-from tvm.target import Target
-
-
-def auto_inline(target: Target) -> ScheduleRule:
-    """Default schedule rules for auto inline"""
-    if target.kind.name == "llvm":
-        return AutoInline(
-            into_producer=False,
-            into_consumer=True,
-            inline_const_tensor=True,
-            disallow_if_then_else=True,
-            require_injective=True,
-            require_ordered=True,
-            disallow_op=["tir.exp"],
-        )
-    if target.kind.name == "cuda":
-        return AutoInline(
-            into_producer=True,
-            into_consumer=True,
-            inline_const_tensor=True,
-            disallow_if_then_else=False,
-            require_injective=False,
-            require_ordered=False,
-            disallow_op=None,
-        )
-    raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
-def multi_level_tiling(target: Target) -> ScheduleRule:
-    """Default schedule rules for with multi-level tiling and reuse"""
-    if target.kind.name == "llvm":
-        return MultiLevelTiling(
-            structure="SSRSRS",
-            tile_binds=None,
-            max_innermost_factor=64,
-            vector_load_lens=None,
-            reuse_read=None,
-            reuse_write=ReuseType(
-                req="may",
-                levels=[1, 2],
-                scope="global",
-            ),
-        )
-    if target.kind.name == "cuda":
-        return MultiLevelTiling(
-            structure="SSSRRSRS",
-            tile_binds=["blockIdx.x", "vthread.x", "threadIdx.x"],
-            max_innermost_factor=64,
-            vector_load_lens=[1, 2, 3, 4, 8, 16],
-            reuse_read=ReuseType(
-                req="must",
-                levels=[4],
-                scope="shared",
-            ),
-            reuse_write=ReuseType(
-                req="must",
-                levels=[3],
-                scope="local",
-            ),
-        )
-    raise NotImplementedError(f"{target.kind.name} is not supported")
-
-
-def multi_level_tiling_tensor_core(
-    target: Target,
-    write_reuse_scope: str = "shared",
-    in_dtype: Union[str, List[str]] = "float16",
-    out_dtype: Union[str, List[str]] = "float32",
-    trans_b: Union[bool, List[bool]] = False,
-    use_software_pipeline: bool = False,
-) -> ScheduleRule:
-    """Default schedule rules for with multi-level tiling reuse for tensor core"""
-    assert write_reuse_scope in ["shared", "global"]
-    if not isinstance(in_dtype, list):
-        in_dtype = [in_dtype]
-    if not isinstance(out_dtype, list):
-        out_dtype = [out_dtype]
-    if not isinstance(trans_b, list):
-        trans_b = [trans_b]
-
-    if target.kind.name == "cuda":
-        from tvm.tir.tensor_intrin import (  # pylint: disable=import-outside-toplevel
-            cuda,
-        )
-
-        intrin_groups = [
-            cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
-            for _in_dtype in in_dtype
-            for _out_dtype in out_dtype
-            for _trans_b in trans_b
-        ]
-        return MultiLevelTilingTensorCore(
-            intrin_groups=intrin_groups,
-            structure="SSSRRSRS",
-            tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
-            max_innermost_factor=4,  # 64 // tensor intrin size
-            vector_load_lens=[1, 2, 3, 4, 8, 16],
-            reuse_read=ReuseType(
-                req="must",
-                levels=[4],
-                scope="shared",
-            ),
-            reuse_write=ReuseType(
-                req="must" if write_reuse_scope == "shared" else "no",
-                levels=[2],
-                scope=write_reuse_scope,
-            ),
-            use_software_pipeline=use_software_pipeline,
-        )
-    raise NotImplementedError(f"{target.kind.name} is not supported")
+from typing import List, Type, Union
+
+from tvm.meta_schedule import default_config
+from tvm.meta_schedule.schedule_rule import ScheduleRule
+
+
+def get_rules(kind: str, types: Union[Type, List[Type]]) -> List[ScheduleRule]:
+    """Get default schedule rules"""
+    # pylint: disable=protected-access
+    if kind == "llvm":
+        rules = default_config._DefaultLLVM.schedule_rules()
+    elif kind == "cuda":
+        rules = default_config._DefaultCUDA.schedule_rules()
+    elif kind == "tensor_core":
+        rules = default_config._DefaultCUDATensorCore.schedule_rules()
+    else:
+        raise NotImplementedError(f"{kind} is not supported")
+    # pylint: enable=protected-access
+    return [rule for rule in rules if isinstance(rule, types)]
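
As a usage illustration (hypothetical, not part of the diff), tests would
call the new helper roughly like this; note that `isinstance` takes a single
type or a tuple of types:

```python
from tvm.meta_schedule.schedule_rule import (
    MultiLevelTiling,
    MultiLevelTilingTensorCore,
)
from tvm.meta_schedule.testing.schedule_rule import get_rules

# Filter the default CUDA rule set down to the multi-level-tiling rules.
cuda_mlt_rules = get_rules("cuda", MultiLevelTiling)

# Fetch the tensor-core tiling rule from the TensorCore preset.
tc_rules = get_rules("tensor_core", MultiLevelTilingTensorCore)
```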
src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -321,7 +321,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddSoftwarePipeline(
   // Add local stage and double buffering
   for (int i = 0; i < 2; ++i) {
     const tir::BlockRV cache_read = state->read_reuse.at(i);
-    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Bool(true));
+    sch->Annotate(cache_read, tir::attr::manifest_shared_memory_local_stage, Integer(1));
     sch->Annotate(cache_read, tir::attr::double_buffer_scope, Integer(0));
   }
 
@@ -529,7 +529,7 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
                        state->intrin_group.compute_intrin);
   state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
                        state->intrin_group.init_intrin);
-  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true));
+  state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Integer(1));
   return {std::move(state)};
 }
 
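For illustration (not part of the diff): the same kind of annotation can be
written from the Python side with `tvm.tir.Schedule.annotate`, where the
plain integer 1 corresponds to the `Integer(1)` now used above. A minimal
sketch on a toy PrimFunc:

```python
import tvm
from tvm.script import tir as T


@T.prim_func
def func(a: T.handle, b: T.handle) -> None:
    A = T.match_buffer(a, (16,), "float32")
    B = T.match_buffer(b, (16,), "float32")
    for i in T.serial(16):
        with T.block("B"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi] + 1.0


sch = tvm.tir.Schedule(func)
block = sch.get_block("B")
# Python counterpart of the C++ sch->Annotate(block_rv, tir::attr::warp_execution, Integer(1)):
# the annotation value is the integer 1, which prints as `"warp_execution": 1` in TVMScript.
sch.annotate(block, ann_key="warp_execution", ann_val=1)
```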
35 changes: 18 additions & 17 deletions src/meta_schedule/utils.h
@@ -77,33 +77,34 @@ class PyLogMessage {
     // FATAL not included
   };
 
-  PyLogMessage(const std::string& file, int lineno, PackedFunc logging_func, Level logging_level) {
-    this->logging_func = logging_func;
-    this->logging_level = logging_level;
-  }
+  explicit PyLogMessage(const char* file, int lineno, PackedFunc logging_func, Level logging_level)
+      : file_(file), lineno_(lineno), logging_func_(logging_func), logging_level_(logging_level) {}
 
   TVM_NO_INLINE ~PyLogMessage() {
-    if (this->logging_func.defined()) {
-      logging_func(static_cast<int>(logging_level), stream_.str());
+    if (this->logging_func_.defined()) {
+      logging_func_(static_cast<int>(logging_level_), stream_.str());
     } else {
-      if (logging_level == Level::INFO) {
-        LOG(INFO) << stream_.str();
-      } else if (logging_level == Level::WARNING) {
-        LOG(WARNING) << stream_.str();
-      } else if (logging_level == Level::ERROR) {
-        LOG(ERROR) << stream_.str();
-      } else if (logging_level == Level::DEBUG) {
-        DLOG(INFO) << stream_.str();
+      if (logging_level_ == Level::INFO) {
+        runtime::detail::LogMessage(file_, lineno_).stream() << stream_.str();
+      } else if (logging_level_ == Level::WARNING) {
+        runtime::detail::LogMessage(file_, lineno_).stream() << "Warning: " << stream_.str();
+      } else if (logging_level_ == Level::ERROR) {
+        runtime::detail::LogMessage(file_, lineno_).stream() << "Error: " << stream_.str();
+      } else if (logging_level_ == Level::DEBUG) {
+        runtime::detail::LogMessage(file_, lineno_).stream() << "Debug: " << stream_.str();
       } else {
-        LOG(FATAL) << stream_.str();
+        runtime::detail::LogFatal(file_, lineno_).stream() << stream_.str();
       }
     }
   }
   std::ostringstream& stream() { return stream_; }
 
  private:
+  const char* file_;
+  int lineno_;
   std::ostringstream stream_;
-  PackedFunc logging_func;
-  Level logging_level;
+  PackedFunc logging_func_;
+  Level logging_level_;
 };
 
 /*! \brief The type of the random state */
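
For illustration (not part of the diff): the destructor invokes the callback
as `logging_func_(level, message)`, so a Python-side callback compatible
with this contract could look like the sketch below. The assumption that the
integer levels mirror Python's `logging` constants is ours; check the
`PyLogMessage::Level` enum before relying on that mapping.

```python
import logging

logger = logging.getLogger("tvm.meta_schedule")


def logging_func(level: int, msg: str) -> None:
    # Matches the C++ call: logging_func_(static_cast<int>(logging_level_), stream_.str()).
    # Assumes DEBUG/INFO/WARNING/ERROR use the standard `logging` values (10/20/30/40).
    logger.log(level, msg)
```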