Skip to content

Commit

Permalink
[MetaSchedule][Test] Migrate AddRFactor to SEqual (#12758)
Browse files Browse the repository at this point in the history
This PR migrates the usage of `check_trace` to `check_sketch`,
which prefers structural equality of TIRs insteda of string equalty
of traces.
  • Loading branch information
junrushao authored Sep 12, 2022
1 parent 4d27664 commit a23b71c
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 66 deletions.
16 changes: 6 additions & 10 deletions python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing import List, Union

from tvm.meta_schedule.schedule_rule import (
AddRFactor,
AutoBind,
AutoInline,
CrossThreadReduction,
Expand All @@ -28,7 +27,9 @@
ReuseType,
ScheduleRule,
)
from tvm.meta_schedule.schedule_rule.multi_level_tiling import MultiLevelTilingTensorCore
from tvm.meta_schedule.schedule_rule.multi_level_tiling import (
MultiLevelTilingTensorCore,
)
from tvm.target import Target


Expand Down Expand Up @@ -64,13 +65,6 @@ def auto_inline(target: Target) -> ScheduleRule:
raise NotImplementedError(f"{target.kind.name} is not supported")


def add_rfactor(target: Target) -> ScheduleRule:
"""Default schedule rules for with add_rfactor"""
if target.kind.name == "llvm":
return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64)
raise NotImplementedError(f"{target.kind.name} is not supported")


def cross_thread_reduction(target: Target) -> ScheduleRule:
"""Default schedule rules for with cross-thread reduction"""
if target.kind.name == "cuda":
Expand Down Expand Up @@ -131,7 +125,9 @@ def multi_level_tiling_tensor_core(
trans_b = [trans_b]

if target.kind.name == "cuda":
from tvm.tir.tensor_intrin import cuda # pylint: disable=import-outside-toplevel
from tvm.tir.tensor_intrin import ( # pylint: disable=import-outside-toplevel
cuda,
)

intrin_groups = [
cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/tir/schedule/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
# specific language governing permissions and limitations
# under the License.
"""Testing utilities for the TensorIR schedule API"""
from typing import Union, Sequence
from typing import Sequence, Union

import tvm
from tvm.ir import IRModule, structural_equal
from tvm.ir import IRModule, assert_structural_equal
from tvm.tir import PrimFunc
from tvm.tir.schedule import Trace, Schedule
from tvm.tir.schedule import Schedule, Trace


def verify_trace_roundtrip(
Expand Down Expand Up @@ -70,7 +70,7 @@ def verify_trace_roundtrip(
assert text_format in ("json", "python"), f"Unknown text format: {text_format}"

# Step 2. Verify that the round-trip produced the same scheduling
assert structural_equal(new_sch.mod, sch.mod)
assert_structural_equal(new_sch.mod, sch.mod)

# Step 3. Check the consistency of the text format between the old and new traces
py_repr = "\n".join(trace.as_python())
Expand Down
5 changes: 2 additions & 3 deletions src/meta_schedule/schedule_rule/add_rfactor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,7 @@ Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::

// Split the fused reduction loop.
Array<tir::ExprRV> factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor);
const Array<tir::LoopRV>& split_loops =
sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});
Array<tir::LoopRV> split_loops = sch->Split(fused_reduce_loop, {factors.begin(), factors.end()});

Array<tir::Schedule> res;
for (const tir::LoopRV& split_loop : split_loops) {
Expand All @@ -104,7 +103,7 @@ Array<tir::Schedule> AddRFactorNode::Apply(const tir::Schedule& sch, const tir::

// Annotate that the rfactor block, which is now the producer of the original block, needs to
// be considered by the rule Random-Compute-Location.
sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Bool(true));
sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Integer(1));
res.push_back(sch_tmp);
} catch (const tvm::runtime::Error& e) {
}
Expand Down
4 changes: 3 additions & 1 deletion src/tir/schedule/primitive/sampling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,9 @@ std::vector<int64_t> SamplePerfectTile(
} else {
// Case 3. Use fresh new sampling result
result = SamplePerfectTile(rand_state, *extent, n_splits, max_innermost_factor);
ICHECK_LE(result.back(), max_innermost_factor);
if (max_innermost_factor != -1) {
ICHECK_LE(result.back(), max_innermost_factor);
}
}
*decision = support::AsArray<int64_t, Integer>(result);
return result;
Expand Down
142 changes: 94 additions & 48 deletions tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,62 +15,108 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring

from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm import meta_schedule as ms
from tvm.meta_schedule.testing import te_workload
from tvm.meta_schedule.testing.schedule_rule import add_rfactor
from tvm.meta_schedule.testing.space_generation import check_trace
from tvm.meta_schedule.tune_context import TuneContext
from tvm.meta_schedule.testing.space_generation import check_sketches
from tvm.script import tir as T
from tvm.target import Target
from tvm.te.operation import create_prim_func
from tvm.te import create_prim_func


def _create_context(mod, target, rule) -> TuneContext:
ctx = TuneContext(
mod=mod,
target=target,
space_generator=PostOrderApply(),
sch_rules=[rule],
task_name="test",
)
return ctx
def test_cpu_matmul():
@T.prim_func
def cpu_matmul_0(
A: T.Buffer[(4, 512), "float32"],
B: T.Buffer[(512, 4), "float32"],
C: T.Buffer[(4, 4), "float32"],
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
for i0, i1, i2 in T.grid(4, 4, 512):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i0, i1, i2])
T.reads(A[i, k], B[k, j])
T.writes(C[i, j])
with T.init():
C[i, j] = T.float32(0)
C[i, j] = C[i, j] + A[i, k] * B[k, j]

@T.prim_func
def cpu_matmul_1(
A: T.Buffer[(4, 512), "float32"],
B: T.Buffer[(512, 4), "float32"],
C: T.Buffer[(4, 4), "float32"],
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
C_rf = T.alloc_buffer([4, 4, 128], dtype="float32")
for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128):
with T.block("C_rf"):
vi2_1, i, j, vi2_0 = T.axis.remap("SSSR", [i2_1, i0, i1, i2_0])
T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j])
T.writes(C_rf[i, j, vi2_1])
with T.init():
C_rf[i, j, vi2_1] = T.float32(0)
C_rf[i, j, vi2_1] = (
C_rf[i, j, vi2_1] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 * 128 + vi2_1, j]
)
for i0, i1, i2_1 in T.grid(4, 4, 128):
with T.block("C"):
vi2_1, i, j = T.axis.remap("RSS", [i2_1, i0, i1])
T.reads(C_rf[i, j, vi2_1])
T.writes(C[i, j])
T.block_attr({"meta_schedule.random_compute_producer": 1})
with T.init():
C[i, j] = T.float32(0)
C[i, j] = C[i, j] + C_rf[i, j, vi2_1]

def test_cpu_matmul():
expected = [
[],
[
'b0 = sch.get_block(name="C", func_name="main")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
"l6, l7 = sch.split(loop=l3, factors=[v4, v5], preserve_unit_iters=True)",
"b8 = sch.rfactor(loop=l7, factor_axis=2)",
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)',
],
[
'b0 = sch.get_block(name="C", func_name="main")',
"l1, l2, l3 = sch.get_loops(block=b0)",
"v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)",
"l6, l7 = sch.split(loop=l3, factors=[v4, v5], preserve_unit_iters=True)",
"b8 = sch.rfactor(loop=l6, factor_axis=2)",
'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)',
],
@T.prim_func
def cpu_matmul_2(
A: T.Buffer[(4, 512), "float32"],
B: T.Buffer[(512, 4), "float32"],
C: T.Buffer[(4, 4), "float32"],
) -> None:
T.func_attr({"global_symbol": "main", "tir.noalias": True})
C_rf = T.alloc_buffer([4, 4, 4], dtype="float32")
for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128):
with T.block("C_rf"):
vi2_0, i, j, vi2_1 = T.axis.remap("SSSR", [i2_0, i0, i1, i2_1])
T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j])
T.writes(C_rf[i, j, vi2_0])
with T.init():
C_rf[i, j, vi2_0] = T.float32(0)
C_rf[i, j, vi2_0] = (
C_rf[i, j, vi2_0] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 * 128 + vi2_1, j]
)
for i0, i1, i2_0 in T.grid(4, 4, 4):
with T.block("C"):
vi2_0, i, j = T.axis.remap("RSS", [i2_0, i0, i1])
T.reads(C_rf[i, j, vi2_0])
T.writes(C[i, j])
T.block_attr({"meta_schedule.random_compute_producer": 1})
with T.init():
C[i, j] = T.float32(0)
C[i, j] = C[i, j] + C_rf[i, j, vi2_0]

decision_0 = [] # type: ignore
decision_1 = [
("SamplePerfectTile", [4, 128]),
]
decision_2 = [
("SamplePerfectTile", [4, 128]),
]
target = Target("llvm --num-cores=32")
ctx = _create_context(
create_prim_func(
te_workload.matmul(
n=4,
m=4,
k=512,
)
),
target=target,
rule=add_rfactor(target=target),
mod = create_prim_func(te_workload.matmul(n=4, m=4, k=512))
actual = ms.TuneContext(
mod=mod,
target=Target("llvm --num-cores=32"),
space_generator=ms.space_generator.PostOrderApply(),
sch_rules=[ms.schedule_rule.AddRFactor()],
task_name="test",
).generate_design_space()
check_sketches(
mod,
sketches=actual,
expected_mods=[cpu_matmul_0, cpu_matmul_1, cpu_matmul_2],
expected_decisions=[decision_0, decision_1, decision_2],
)
spaces = ctx.space_generator.generate_design_space(mod=ctx.mod)
assert len(spaces) == 3
check_trace(spaces, expected)


if __name__ == "__main__":
Expand Down

0 comments on commit a23b71c

Please sign in to comment.