
Commit 396e31b

[MetaSchedule] Fuse loops around shared to global store block in `MultiLevelTilingTensorCore` (apache#13357)

* Fuse shared to global store loops in MultiLevelTilingTensorCore

* update test
masahi authored and xinetzone committed Nov 25, 2022
1 parent 1c514a2 commit 396e31b
Showing 4 changed files with 57 additions and 21 deletions.
30 changes: 30 additions & 0 deletions src/meta_schedule/postproc/rewrite_cooperative_fetch.cc
@@ -82,6 +82,29 @@ bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) {
return ann_key == attr::warp_execution;
}

size_t GetMaxUsedDtypeBytes(Block block) {
size_t max_bytes = 1;
static auto q_multiply_shift_per_axis = Op::Get("tir.q_multiply_shift_per_axis");
static auto q_multiply_shift = Op::Get("tir.q_multiply_shift");

tir::PostOrderVisit(block->body, [&](const ObjectRef& obj) {
if (const auto* store = obj.as<tir::BufferStoreNode>()) {
max_bytes = std::max(max_bytes, static_cast<size_t>(store->value->dtype.bytes()));
} else if (const auto* load = obj.as<tir::BufferLoadNode>()) {
max_bytes = std::max(max_bytes, static_cast<size_t>(load->dtype.bytes()));
} else if (const auto* call = obj.as<tir::CallNode>()) {
if (call->op.same_as(q_multiply_shift_per_axis) || call->op.same_as(q_multiply_shift)) {
// q_multiply_shift uses 64 bit multiply
max_bytes = std::max<size_t>(max_bytes, 8);
}
} else if (const auto* cast = obj.as<tir::CastNode>()) {
max_bytes = std::max<size_t>(max_bytes, cast->dtype.bytes());
}
});

return max_bytes;
}

} // namespace tir

namespace meta_schedule {
@@ -154,6 +177,13 @@ bool RewriteCooperativeFetchNode::Apply(const tir::Schedule& sch) {
if (fused_extent % vector_lane != 0) {
vector_lane = 1;
}
// If the block involves 64 bit values, disable vectorization for now since
// vectorization of 64 bit values does not work well on CUDA.
// TODO(masahi, vinx13): Decouple epilogue fusion computation and shared to global store, so
// that we can always vectorize the latter.
if (tir::GetMaxUsedDtypeBytes(sch->Get(block)) > 4) {
vector_lane = 1;
}
if (thread_extent_y != -1) {
if (vector_lane > 1) {
Array<tir::LoopRV> split = sch->Split(fused, {NullOpt, //
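For reference only (not part of this commit), the same scan can be sketched at the Python level with tvm.tir.stmt_functor.post_order_visit; the helper name and structure below are assumptions made for illustration, mirroring GetMaxUsedDtypeBytes above.

import tvm
from tvm import tir


# Python-level sketch (assumed helper, not from this commit) of the dtype scan
# that GetMaxUsedDtypeBytes performs in C++: report the widest dtype touched by
# loads, stores, casts, or q_multiply_shift calls inside a block body.
def get_max_used_dtype_bytes(block: tir.Block) -> int:
    max_bytes = [1]

    def visit(obj):
        if isinstance(obj, (tir.BufferLoad, tir.Cast)):
            max_bytes[0] = max(max_bytes[0], tvm.DataType(obj.dtype).bits // 8)
        elif isinstance(obj, tir.BufferStore):
            max_bytes[0] = max(max_bytes[0], tvm.DataType(obj.value.dtype).bits // 8)
        elif isinstance(obj, tir.Call) and isinstance(obj.op, tvm.ir.Op):
            if obj.op.name in ("tir.q_multiply_shift", "tir.q_multiply_shift_per_axis"):
                # q_multiply_shift performs a 64-bit multiply internally.
                max_bytes[0] = max(max_bytes[0], 8)

    tir.stmt_functor.post_order_visit(block.body, visit)
    return max_bytes[0]

With that value in hand, the postproc above simply resets vector_lane to 1 whenever the result exceeds 4 bytes.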
4 changes: 2 additions & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -269,8 +269,8 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
sch->ComputeAt(cache_read_block, loop_rv, true);
// Fuse the iterators of the cache_read
Array<LoopRV> buffer_loops = sch->GetLoops(cache_read_block);
-LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim, //
-                         buffer_loops.end()});
+sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim, //
+          buffer_loops.end()});
AnnotateCooperativeFetching(&sch, cache_read_block);
new_state->read_reuse.emplace(i, cache_read_block);
}
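As a rough Python-level analogue of the read-reuse pattern this method builds (illustrative only; the buffer index, the fixed ann_val, and the function name are assumptions, since the real rule samples the vector length), the same steps can be written against tir.Schedule:

from tvm import tir


# Illustrative tir.Schedule sketch (not from this commit): cache a read into
# shared memory, attach it under a tiling loop, fuse the loops that index the
# cached buffer, and annotate the block for the RewriteCooperativeFetch postproc.
def add_read_reuse(sch: tir.Schedule, block, loop, buffer_ndim: int = 2):
    cache_read_block = sch.cache_read(block, read_buffer_index=0, storage_scope="shared")
    sch.compute_at(cache_read_block, loop, preserve_unit_loops=True)
    buffer_loops = sch.get_loops(cache_read_block)
    # As in the change above, the fused loop handle itself is not needed afterwards.
    sch.fuse(*buffer_loops[-buffer_ndim:])
    # Placeholder value; the actual rule samples the cooperative-fetch factor.
    sch.annotate(cache_read_block, ann_key="meta_schedule.cooperative_fetch", ann_val=4)
    return cache_read_block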
5 changes: 5 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -258,6 +258,11 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
sch->ReverseComputeAt(cache_write, loop, true);

if (state->write_reuse.count(0)) {
// Fuse the iterators of the cache_write
Array<LoopRV> buffer_loops = sch->GetLoops(state->write_reuse[0]);
ICHECK_GT(buffer_loops.size(), 2);
sch->Fuse(Array<LoopRV>{buffer_loops.end() - 2, // The src shmem is always 2D
buffer_loops.end()});
AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
}
sch->ReverseComputeInline(state->tensor_core_reindex_store);
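The write side added here can be sketched the same way (again illustrative, with assumed handles): once ReverseComputeAt has placed the accumulator-to-shared cache write, the shared-to-global store block iterates a 2-D shared buffer, so its last two loops are fused before cooperative fetching is annotated.

from tvm import tir


# Illustrative tir.Schedule sketch (not from this commit) of the fused
# shared-to-global store that AddWriteReuseTensorCore now produces.
def fuse_shared_to_global_store(sch: tir.Schedule, store_block):
    buffer_loops = sch.get_loops(store_block)
    # The source shared-memory buffer is always 2-D, so only the last two loops
    # index it; the outer loops are the surrounding tile loops.
    assert len(buffer_loops) > 2
    sch.fuse(*buffer_loops[-2:])
    # Placeholder value; the actual rule samples the cooperative-fetch factor.
    sch.annotate(store_block, ann_key="meta_schedule.cooperative_fetch", ann_val=4)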
39 changes: 20 additions & 19 deletions tests/python/unittest/test_meta_schedule_schedule_rule_mlt_tc.py
@@ -162,14 +162,15 @@ def matmul_relu_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128, 128), "f
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(32, 32):
+for ax0_ax1_fused in T.serial(1024):
with T.block("C_reindex_shared"):
-v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0)
-v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1)
+v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32)
+v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
T.reads(C_reindex_shared[v0, v1])
T.writes(compute[v0, v1])
T.block_attr({"meta_schedule.cooperative_fetch":4})
compute[v0, v1] = T.max(C_reindex_shared[v0, v1], T.float32(0))

# fmt: on
decision_0 = [
("SamplePerfectTile", [4, 1, 1, 1, 2]),
@@ -303,10 +304,10 @@ def matmul_relu_fallback_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128,
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(32, 128):
+for ax0_ax1_fused in T.serial(4096):
with T.block("C_reindex_shared"):
-v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0)
-v1 = T.axis.spatial(128, ax1)
+v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused // 128)
+v1 = T.axis.spatial(128, ax0_ax1_fused % 128)
T.reads(C_reindex_shared[v0, v1])
T.writes(compute[v0, v1])
T.block_attr({"meta_schedule.cooperative_fetch":4})
@@ -451,10 +452,10 @@ def conv2d_0(inputs: T.Buffer[(1, 16, 16, 32), "float16"], weight: T.Buffer[(3,
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(16, 16):
+for ax0_ax1_fused in T.serial(256):
with T.block("conv2d_nhwc_reindex_shared"):
-v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0)
-v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax1)
+v0 = T.axis.spatial(256, ax0_0_1_ax1_0_1_fused * 16 + ax0_ax1_fused // 16)
+v1 = T.axis.spatial(32, ax0_0_0_ax1_0_0_fused * 16 + ax0_ax1_fused % 16)
T.reads(conv2d_nhwc_reindex_shared[v0, v1])
T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1])
T.block_attr({"meta_schedule.cooperative_fetch":3})
@@ -617,10 +618,10 @@ def matmul_relu_pipeline_0(A: T.Buffer[(128, 128), "float16"], B: T.Buffer[(128,
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(32, 32):
+for ax0_ax1_fused in T.grid(1024):
with T.block("C_reindex_shared"):
-v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0)
-v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax1)
+v0 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused // 4 * 32 + ax0_ax1_fused // 32)
+v1 = T.axis.spatial(128, ax0_0_1_ax1_0_1_fused % 4 * 32 + ax0_ax1_fused % 32)
T.reads(C_reindex_shared[v0, v1])
T.writes(C[v0, v1])
T.block_attr({"meta_schedule.cooperative_fetch":3})
@@ -919,11 +920,11 @@ def padded_matmul_relu_0(A: T.Buffer[(127, 127), "float16"], B: T.Buffer[(127, 1
T.reads(C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
T.writes(C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
C_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = C_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(32, 32):
+for ax0_ax1_fused in T.serial(1024):
with T.block("C_reindex_shared"):
-T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1 < 127)
-v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0)
-v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax1)
+T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32 < 127)
+v0 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused // 2 * 32 + ax0_ax1_fused // 32)
+v1 = T.axis.spatial(128, ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_fused % 32)
T.reads(C_reindex_shared[v0, v1])
T.writes(compute[v0, v1])
T.block_attr({"meta_schedule.cooperative_fetch":4})
@@ -1063,10 +1064,10 @@ def conv2d_1x1_0(inputs: T.Buffer[(1, 16, 16, 64), "float16"], weight: T.Buffer[
T.reads(conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
T.writes(conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i])
conv2d_nhwc_reindex_shared[v0_o * 16 + v0_i, v1_o * 16 + v1_i] = conv2d_nhwc_reindex_shared_wmma_accumulator[v0_o * 16 + v0_i, v1_o * 16 + v1_i]
-for ax0, ax1 in T.grid(16, 32):
+for ax0_ax1_fused in T.serial(512):
with T.block("conv2d_nhwc_reindex_shared"):
-v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0)
-v1 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax1)
+v0 = T.axis.spatial(256, ax2_0_0_ax3_0_0_fused // 2 * 32 + ax2_0_1_ax3_0_1_fused * 16 + ax0_ax1_fused // 32)
+v1 = T.axis.spatial(64, ax2_0_0_ax3_0_0_fused % 2 * 32 + ax0_ax1_fused % 32)
T.reads(conv2d_nhwc_reindex_shared[v0, v1])
T.writes(conv2d_nhwc[v0 // 256, v0 // 16, v0 % 16, v1])
T.block_attr({"meta_schedule.cooperative_fetch":2})
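All of the updated test expectations follow the same index arithmetic: fusing an m-by-n grid into a single loop of extent m * n is exact, because the original indices come back as fused // n and fused % n in the original visit order. A quick standalone check (not part of the test suite):

# Fusing a 32x32 grid into one loop of extent 1024 visits the same (ax0, ax1)
# pairs in the same order when the indices are rebuilt as fused // 32 and
# fused % 32, which is exactly the rewrite applied to the blocks above.
original = [(ax0, ax1) for ax0 in range(32) for ax1 in range(32)]
fused = [(i // 32, i % 32) for i in range(1024)]
assert fused == original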
