Skip to content

Commit

Permalink
reorg
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Nov 9, 2022
1 parent 0ba4a29 commit 337c1c1
Show file tree
Hide file tree
Showing 9 changed files with 165 additions and 221 deletions.
3 changes: 1 addition & 2 deletions python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .add_rfactor import AddRFactor
from .apply_custom_rule import ApplyCustomRule
from .auto_bind import AutoBind
from .auto_inline import AutoInline
from .auto_inline import AutoInline, InlineConstantScalars
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import (
MultiLevelTiling,
Expand All @@ -34,4 +34,3 @@
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
from .schedule_rule import PyScheduleRule, ScheduleRule
from .inline_const_scalars import InlineConstantScalars
17 changes: 17 additions & 0 deletions python/tvm/meta_schedule/schedule_rule/auto_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,20 @@ def __init__(
require_ordered,
disallow_op,
)


@register_object("meta_schedule.InlineConstantScalars")
class InlineConstantScalars(ScheduleRule):
    """Inline blocks that produce a constant scalar.

    A constant-scalar producer still counts as a producer block during
    AutoInline, which blocks ReverseComputeInline on its consumers.  Running
    this rule before AutoInline inlines such blocks first, so the later
    transformations can proceed.
    """

    def __init__(self) -> None:
        # Construct the backing C++ node through the FFI registry.
        self.__init_handle_by_constructor__(
            _ffi_api.ScheduleRuleInlineConstantScalars,  # type: ignore # pylint: disable=no-member
        )
38 changes: 0 additions & 38 deletions python/tvm/meta_schedule/schedule_rule/inline_const_scalars.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class VerifyGPUCodeNode : public PostprocNode {
transform::PassContext pass_ctx = transform::PassContext::Current();
tir::PrimFunc f = WithAttr(GetRef<tir::PrimFunc>(prim_func), "global_symbol",
runtime::String(g_var->name_hint));
f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required for LowerIntrin
f = WithAttr(f, tvm::attr::kTarget, Target("cuda")); // Required for LowerIntrin
bool noalias = pass_ctx->GetConfig<Bool>("tir.noalias", Bool(true)).value();
if (noalias) {
f = WithAttr(std::move(f), "tir.noalias", Bool(true));
Expand Down
30 changes: 30 additions & 0 deletions src/meta_schedule/schedule_rule/auto_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,35 @@ TVM_REGISTER_NODE_TYPE(AutoInlineNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoInline")
.set_body_typed(ScheduleRule::AutoInline);

/*! \brief Inline blocks that produce a constant scalar. */
class InlineConstantScalarsNode : public ScheduleRuleNode {
 public:
  // This rule keeps no tuning state, so there is nothing to initialize.
  void InitializeWithTuneContext(const TuneContext& context) final {}

  // Inline the given block into its consumer when its name marks it as a
  // constant-scalar producer.  Matching is purely by name substring:
  // presumably "compile_engine_const" is the naming convention used for
  // lowered constants — TODO(review): confirm against the lowering pass.
  // Always returns the (possibly mutated) schedule as the single candidate.
  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
    const std::string block_name = sch->Get(block_rv)->name_hint;
    if (block_name.find("compile_engine_const") != std::string::npos) {
      sch->ComputeInline(block_rv);
    }
    return {sch};
  }

  // Deep-copy the rule node (it is stateless, so a member-wise copy suffices).
  ScheduleRule Clone() const final {
    ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>(*this);
    return ScheduleRule(n);
  }

  static constexpr const char* _type_key = "meta_schedule.InlineConstantScalars";
  TVM_DECLARE_FINAL_OBJECT_INFO(InlineConstantScalarsNode, ScheduleRuleNode);
};

// Factory exposed to Python as meta_schedule.ScheduleRuleInlineConstantScalars.
ScheduleRule ScheduleRule::InlineConstantScalars() {
  ObjectPtr<InlineConstantScalarsNode> n = make_object<InlineConstantScalarsNode>();
  return ScheduleRule(n);
}

TVM_REGISTER_NODE_TYPE(InlineConstantScalarsNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleInlineConstantScalars")
    .set_body_typed(ScheduleRule::InlineConstantScalars);
} // namespace meta_schedule
} // namespace tvm
56 changes: 0 additions & 56 deletions src/meta_schedule/schedule_rule/inline_const_scalar.cc

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def _schedule_packed_8x8x32_conv2d():

def schedule_fn(sch, conv2d_block: Optional[BlockRV] = None) -> bool:
if conv2d_block is None:
if has_block("conv2d_NCHWc_int8"):
if has_block(sch, "conv2d_NCHWc_int8"):
conv2d_block = sch.get_block("conv2d_NCHWc_int8")
else:
return False
Expand Down
115 changes: 115 additions & 0 deletions tests/python/unittest/test_meta_schedule_schedule_rule_auto_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
import pytest

import tvm
from tvm.tir import Schedule
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing.space_generation import generate_design_space
from tvm.script import tir as T
Expand Down Expand Up @@ -334,6 +337,101 @@ def main(T_full: T.Buffer[(1, 12, 4096), "int64"]) -> None:
T.writes(T_full[ax0, ax1, ax2])
T_full[ax0, ax1, ax2] = T.int64(0)


# Test fixture: an int8 NHWC 1x1 conv2d followed by a requantization chain
# (subtract/add/quantized-multiply/clip).  The "compile_engine_const" block
# writes a 0-d constant scalar buffer that "T_add_1" reads in addition to
# "compute_1" — exactly the two-producer pattern that InlineConstantScalars
# is meant to remove.  NOTE(review): the structure mirrors what the Relay
# compile engine emits for quantized conv2d — confirm against the lowering.
@tvm.script.ir_module
class Conv2dInt8:
    @T.prim_func
    def main(p0: T.Buffer[(16, 14, 14, 256), "int8"], p1: T.Buffer[(1024, 1, 1, 256), "int8"], p2: T.Buffer[(1, 1, 1, 1024), "int32"], p3: T.Buffer[(1, 1, 1, 1024), "int32"], p4: T.Buffer[1024, "int32"], p5: T.Buffer[1024, "int32"], p6: T.Buffer[1024, "int32"], p7: T.Buffer[1, "int32"], p8: T.Buffer[(16, 14, 14, 1024), "int32"], compute: T.Buffer[(16, 14, 14, 1024), "int32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        compile_engine_const = T.alloc_buffer([], dtype="int32")
        pad_temp = T.alloc_buffer([16, 14, 14, 256], dtype="int8")
        conv2d_nhwc = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
        T_subtract = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
        T_add = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
        compute_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
        T_add_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
        compute_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
        T_subtract_1 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
        compute_3 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
        T_add_2 = T.alloc_buffer([16, 14, 14, 1024], dtype="int32")
        # The constant-scalar producer block targeted by InlineConstantScalars.
        with T.block("compile_engine_const"):
            vi = T.axis.spatial(1, 0)
            T.reads()
            T.writes(compile_engine_const[()])
            compile_engine_const[()] = 59
        for i0, i1, i2, i3 in T.grid(16, 14, 14, 256):
            with T.block("pad_temp"):
                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(p0[i0_1, i1_1, i2_1, i3_1])
                T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
                pad_temp[i0_1, i1_1, i2_1, i3_1] = p0[i0_1, i1_1, i2_1, i3_1]
        for i0, i1, i2, i3, i4, i5, i6 in T.grid(16, 14, 14, 1024, 1, 1, 256):
            with T.block("conv2d_nhwc"):
                nn, yy, xx, ff, ry, rx, rc = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                T.reads(pad_temp[nn, yy + ry, xx + rx, rc], p1[ff, ry, rx, rc])
                T.writes(conv2d_nhwc[nn, yy, xx, ff])
                with T.init():
                    conv2d_nhwc[nn, yy, xx, ff] = 0
                conv2d_nhwc[nn, yy, xx, ff] = conv2d_nhwc[nn, yy, xx, ff] + T.cast(pad_temp[nn, yy + ry, xx + rx, rc], "int32") * T.cast(p1[ff, ry, rx, rc], "int32")
        for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
            with T.block("T_subtract"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(conv2d_nhwc[ax0, ax1, ax2, ax3], p2[0, 0, 0, ax3])
                T.writes(T_subtract[ax0, ax1, ax2, ax3])
                T_subtract[ax0, ax1, ax2, ax3] = conv2d_nhwc[ax0, ax1, ax2, ax3] - p2[0, 0, 0, ax3]
        for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
            with T.block("T_add"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(T_subtract[ax0, ax1, ax2, ax3], p3[0, 0, 0, ax3])
                T.writes(T_add[ax0, ax1, ax2, ax3])
                T_add[ax0, ax1, ax2, ax3] = T_subtract[ax0, ax1, ax2, ax3] + p3[0, 0, 0, ax3]
        for i0, i1, i2, i3 in T.grid(16, 14, 14, 1024):
            with T.block("compute"):
                i0_2, i1_2, i2_2, i3_2 = T.axis.remap("SSSS", [i0, i1, i2, i3])
                T.reads(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2])
                T.writes(compute_1[i0_2, i1_2, i2_2, i3_2])
                compute_1[i0_2, i1_2, i2_2, i3_2] = T.q_multiply_shift_per_axis(T_add[i0_2, i1_2, i2_2, i3_2], p4[i3_2], p5[i3_2], p6[i3_2], 31, False, True, dtype="int32")
        # "T_add_1" reads TWO buffers: the constant scalar and "compute_1".
        for i0_3, i1_3, i2_3, i3_3 in T.grid(16, 14, 14, 1024):
            with T.block("T_add_1"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_3, i1_3, i2_3, i3_3])
                T.reads(compile_engine_const[()], compute_1[ax0, ax1, ax2, ax3])
                T.writes(T_add_1[ax0, ax1, ax2, ax3])
                T_add_1[ax0, ax1, ax2, ax3] = compile_engine_const[()] + compute_1[ax0, ax1, ax2, ax3]
        for i0_4, i1_4, i2_4, i3_4 in T.grid(16, 14, 14, 1024):
            with T.block("compute_1"):
                i0_5, i1_5, i2_5, i3_5 = T.axis.remap("SSSS", [i0_4, i1_4, i2_4, i3_4])
                T.reads(T_add_1[i0_5, i1_5, i2_5, i3_5])
                T.writes(compute_2[i0_5, i1_5, i2_5, i3_5])
                compute_2[i0_5, i1_5, i2_5, i3_5] = T.max(T.min(T_add_1[i0_5, i1_5, i2_5, i3_5], 255), 0)
        for i0_6, i1_6, i2_6, i3_6 in T.grid(16, 14, 14, 1024):
            with T.block("T_subtract_1"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_6, i1_6, i2_6, i3_6])
                T.reads(compute_2[ax0, ax1, ax2, ax3], p7[0])
                T.writes(T_subtract_1[ax0, ax1, ax2, ax3])
                T_subtract_1[ax0, ax1, ax2, ax3] = compute_2[ax0, ax1, ax2, ax3] - p7[0]
        for i0_7, i1_7, i2_7, i3_7 in T.grid(16, 14, 14, 1024):
            with T.block("compute_2"):
                i0_8, i1_8, i2_8, i3_8 = T.axis.remap("SSSS", [i0_7, i1_7, i2_7, i3_7])
                T.reads(T_subtract_1[i0_8, i1_8, i2_8, i3_8])
                T.writes(compute_3[i0_8, i1_8, i2_8, i3_8])
                compute_3[i0_8, i1_8, i2_8, i3_8] = T.q_multiply_shift(T_subtract_1[i0_8, i1_8, i2_8, i3_8], 1408572815, 31, 1, dtype="int32")
        for i0_9, i1_9, i2_9, i3_9 in T.grid(16, 14, 14, 1024):
            with T.block("T_add_2"):
                ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0_9, i1_9, i2_9, i3_9])
                T.reads(compute_3[ax0, ax1, ax2, ax3], p8[ax0, ax1, ax2, ax3])
                T.writes(T_add_2[ax0, ax1, ax2, ax3])
                T_add_2[ax0, ax1, ax2, ax3] = compute_3[ax0, ax1, ax2, ax3] + p8[ax0, ax1, ax2, ax3]
        for i0_10, i1_10, i2_10, i3_10 in T.grid(16, 14, 14, 1024):
            with T.block("compute_3"):
                i0_11, i1_11, i2_11, i3_11 = T.axis.remap("SSSS", [i0_10, i1_10, i2_10, i3_10])
                T.reads(T_add_2[i0_11, i1_11, i2_11, i3_11])
                T.writes(compute[i0_11, i1_11, i2_11, i3_11])
                compute[i0_11, i1_11, i2_11, i3_11] = T.max(T.min(T_add_2[i0_11, i1_11, i2_11, i3_11], 255), 0)


# pylint: enable=no-member,invalid-name,unused-variable,no-self-argument,line-too-long,chained-comparison,not-callable,too-many-nested-blocks
# fmt: on

Expand Down Expand Up @@ -398,9 +496,26 @@ def test_inline_constant_tensor():
tvm.ir.assert_structural_equal(lhs=space.mod, rhs=ConstConsumer)


def test_conv2d_int8_inline_constant_scalars():
    """ReverseComputeInline on "T_add_1" fails while the constant-scalar
    producer block remains, and succeeds after InlineConstantScalars runs."""
    sch = Schedule(Conv2dInt8)

    conv2d = sch.get_block("conv2d_nhwc")
    sch.cache_write(conv2d, 0, "shared")

    # "T_add_1" reads both compile_engine_const and compute_1, so it has two
    # producer regions and reverse compute-inline must reject it.
    with pytest.raises(tvm.tir.ScheduleError) as e:
        sch.reverse_compute_inline(sch.get_block("T_add_1"))

    err_msg = "The block is only allowed to read a single buffer region, but it reads 2 region(s)"
    # str(e) stringifies pytest's ExceptionInfo wrapper (a repr-like, possibly
    # truncated form since pytest 5.0); the raised exception is e.value.
    assert err_msg in str(e.value)

    # After inlining the constant scalar, only one producer remains and the
    # same reverse compute-inline now succeeds.
    ms.schedule_rule.InlineConstantScalars().apply(sch, sch.get_block("compile_engine_const"))
    sch.reverse_compute_inline(sch.get_block("T_add_1"))


if __name__ == "__main__":
    # Run each test in order when executed directly (outside pytest).
    for _test_fn in (
        test_inline_consumer_chain,
        test_inline_into_cache,
        test_inline_into_multiple_consumers,
        test_inline_pure_spatial,
        test_inline_constant_tensor,
        test_conv2d_int8_inline_constant_scalars,
    ):
        _test_fn()
Loading

0 comments on commit 337c1c1

Please sign in to comment.