Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Performance Align] fixing codegen problems #5

Merged
merged 1 commit into from
Jan 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,11 @@ TVM_DLL Pass InjectSoftwarePipeline();
*/
TVM_DLL Pass LowerAutoCopy();

/*!
 * \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
 *
 * Rewrites index expressions of the form floordiv(floormod(x, c1 * c2), c2)
 * into floormod(floordiv(x, c2), c1) so that later simplification can fold them.
 * \return The pass.
 */
TVM_DLL Pass RenormalizeSplitPattern();

} // namespace transform
} // namespace tir
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ def InjectSoftwarePipeline():
"""
return _ffi_api.InjectSoftwarePipeline() # type: ignore


def LowerAutoCopy():
"""Automatically do memory optimizations for auto copy blocks

Expand All @@ -771,3 +772,13 @@ def LowerAutoCopy():
"""
return _ffi_api.LowerAutoCopy()


def RenormalizeSplitPattern():
    """Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.RenormalizeSplitPattern()  # type: ignore


# Backward-compatible alias: this wrapper was originally published under a
# misspelled name ("Renomalize..."); keep it so existing callers do not break.
RenomalizeSplitPattern = RenormalizeSplitPattern
12 changes: 12 additions & 0 deletions src/arith/modular_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,18 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
return Everything();
}

// Modular-set analysis for floormod(a, b).
//
// When the divisor is a known constant c2: if a = coeff * k + base, then
// floormod(a, c2) still repeats with period gcd(coeff, c2), so the result
// is Entry(gcd(a.coeff, c2), a.base % c2).
// NOTE(review): `a.base % c2` is C++ truncated remainder and can be negative
// for a negative base — confirm Entry normalizes base into [0, coeff).
// NOTE(review): a negative constant divisor also reaches this path; verify
// ZeroAwareGCD and Entry handle that sign correctly.
Entry VisitExpr_(const FloorModNode* op) final {
  Entry b = VisitExpr(op->b);
  if (b.is_const()) {
    int64_t c2 = b.base;
    // Division by zero is a malformed expression; fail loudly.
    ICHECK(c2 != 0) << "MathError: the divisor is 0";
    Entry a = VisitExpr(op->a);
    // The result repeats with period gcd(a.coeff, c2).
    int64_t coeff = ZeroAwareGCD(a.coeff, c2);
    return Entry(coeff, a.base % c2);
  }
  // Divisor not a known constant: no modular information can be derived.
  return Everything();
}

Entry VisitExpr_(const MinNode* op) final {
Entry a = VisitExpr(op->a);
Entry b = VisitExpr(op->b);
Expand Down
2 changes: 2 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) {
TVM_TRY_REWRITE(truncdiv(x, c1) * c1 + truncmod(x, c1), x);
// floor div
TVM_TRY_REWRITE(floordiv(x, c1) * c1 + floormod(x, c1), x);
TVM_TRY_REWRITE_IF(floordiv(floormod(x, c2) + c1, c2) + floordiv(x, c2), floordiv(x + c1, c2),
c2.Eval()->value > 0);

// canonicalization rule
// will try rewrite again after canonicalization.
Expand Down
12 changes: 12 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
#include <mutex>
#include <stack>

#include "../printer/text_printer.h"
#include <tvm/ir/transform.h>

namespace tvm {

// Register build pipeline related options
Expand Down Expand Up @@ -187,6 +190,14 @@ transform::Pass Filter(FCond fcond) {
return tir::transform::CreatePrimFuncPass(fpass, 0, "Filter", {});
}

// Debugging utility pass: logs the TVMScript form of each PrimFunc and
// returns the function unchanged. Not registered in the pass list visible
// here — presumably inserted by hand while debugging lowering.
Pass Print() {
  auto log_func = [](tir::PrimFunc func, IRModule mod, transform::PassContext pass_ctx) {
    LOG(INFO) << tir::AsTVMScript(func);
    return func;
  };
  return tir::transform::CreatePrimFuncPass(log_func, 0, "tir.Print", {});
}

Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
transform::PassContext pass_ctx = transform::PassContext::Current();

Expand Down Expand Up @@ -271,6 +282,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {

// PHASE 3
pass_list.push_back(tir::transform::Simplify());
pass_list.push_back(tir::transform::RenormalizeSplitPattern());
pass_list.push_back(tir::transform::RemoveNoOp());
pass_list.push_back(tir::transform::RewriteUnsafeSelect());
pass_list.push_back(tir::transform::HoistIfThenElse());
Expand Down
87 changes: 87 additions & 0 deletions src/tir/transforms/renormalize_split_pattern.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file renormalize_split_pattern.cc
* \brief Renormalize the split pattern from floordiv(floormod()) to floormod(floordiv())
*/
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../../arith/ir_mutator_with_analyzer.h"

namespace tvm {
namespace tir {

using namespace arith;

/*!
 * \brief Mutator that renormalizes split index patterns.
 *
 * Rewrites floordiv(floormod(x, c1 * c2), c2) into floormod(floordiv(x, c2), c1),
 * a form that downstream simplification folds more easily.
 */
class SplitPatternReNormalizer : public IRMutatorWithAnalyzer {
 public:
  explicit SplitPatternReNormalizer(Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {}

  PrimExpr VisitExpr_(const FloorDivNode* op) final {
    PrimExpr a = VisitExpr(op->a);
    PrimExpr b = VisitExpr(op->b);
    // floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1)
    // Match against the *visited* children (a, b), not op->a / op->b: the
    // original matched the un-visited children and therefore missed patterns
    // only exposed after nested sub-expressions were rewritten.
    if (const auto* inner = a.as<FloorModNode>()) {
      if (const auto* c2 = b.as<IntImmNode>()) {
        if (const auto* c1c2 = inner->b.as<IntImmNode>()) {
          // The identity is only guaranteed for positive divisors; require
          // c2 > 0 and an exact, positive quotient c1 = c1c2 / c2.
          if (c2->value > 0 && c1c2->value > 0 && c1c2->value % c2->value == 0) {
            return analyzer_->Simplify(FloorMod(FloorDiv(inner->a, b),
                                                IntImm(b.dtype(), c1c2->value / c2->value)));
          }
        }
      }
    }
    // No rewrite: preserve node identity when children are unchanged.
    if (a.same_as(op->a) && b.same_as(op->b)) {
      return GetRef<PrimExpr>(op);
    }
    return FloorDiv(a, b);
  }

  Stmt VisitStmt_(const ForNode* op) final {
    // Register the loop variable's range and bound constraints so that
    // analyzer_->Simplify can exploit them inside the loop body.
    analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent));
    With<ConstraintContext> ctx1(analyzer_, op->loop_var >= op->min);
    With<ConstraintContext> ctx2(analyzer_, op->loop_var < op->min + op->extent);
    return IRMutatorWithAnalyzer::VisitStmt_(op);
  }
};

namespace transform {

// Create the tir.RenormalizeSplitPattern pass: runs SplitPatternReNormalizer
// over the body of every PrimFunc in the module.
Pass RenormalizeSplitPattern() {
  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
    arith::Analyzer analyzer;
    auto* func_node = f.CopyOnWrite();
    func_node->body = SplitPatternReNormalizer(&analyzer)(std::move(func_node->body));
    return f;
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.RenormalizeSplitPattern", {});
}

TVM_REGISTER_GLOBAL("tir.transform.RenormalizeSplitPattern").set_body_typed(RenormalizeSplitPattern);

} // namespace transform

} // namespace tir
} // namespace tvm
10 changes: 10 additions & 0 deletions tests/python/unittest/test_arith_modular_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
import tvm
from tvm import te
from tvm.arith import analyzer


def test_cast():
Expand Down Expand Up @@ -50,6 +51,14 @@ def test_mul():
assert m.base == 2


def test_floormod():
    # floormod(128*x + 4*y, 256): every term is a multiple of 4, so the
    # modular set should be coeff=4, base=0.
    ana = tvm.arith.Analyzer()
    x = te.var("x")
    y = te.var("y")
    mod_set = ana.modular_set(tvm.tir.floormod(x * 128 + y * 4, 256))
    assert mod_set.coeff == 4
    assert mod_set.base == 0


def test_div_shift():
analyzer = tvm.arith.Analyzer()
x, y = te.var("x"), te.var("y")
Expand Down Expand Up @@ -175,6 +184,7 @@ def test_let():
test_add_sub()
test_mul()
test_div_shift()
test_floormod()
test_min_max_select()
test_mix_index()
test_constraint_scope()
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def test_vector_simplify():
ck.verify(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4"))
ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)), tvm.tir.Ramp(fld(x, 4), 2, 5))
ck.verify(
fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)),
tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4)
)
ck.verify(
fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
Expand Down Expand Up @@ -277,6 +281,7 @@ def test_add_index_simplify():
flm = tvm.te.floormod
ck.verify(y * flm(x, 8) + 10 * flm(x, 8), flm(x, 8) * (y + 10))
ck.verify(fld(x, 8) * 8 + flm(x, 8), x)
ck.verify(fld(flm(x, 2) + 7, 2) + fld(x, 2), fld(x + 7, 2))


def test_sub_index_simplify():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm.script import tir as T


@tvm.script.ir_module
class Before:
    # Fixture: conv2d-transpose PrimFunc *before* RenormalizeSplitPattern.
    # Its indices contain floordiv(floormod(...)) chains, e.g.
    # `(ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32`, which the
    # pass is expected to renormalize (compare with the `After` fixture).
    @T.prim_func
    def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # var definition
        threadIdx_x = T.env_thread("threadIdx.x")
        blockIdx_x = T.env_thread("blockIdx.x")
        # body
        T.launch_thread(blockIdx_x, 64)
        conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
        PadInput_shared = T.allocate([768], "float32", "shared")
        weight_shared = T.allocate([4096], "float32", "shared")
        T.launch_thread(threadIdx_x, 32)
        # Zero-init the local accumulator.
        for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
            T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True)
        for i6_0 in T.serial(16):
            # Cooperative load of the padded input tile into shared memory;
            # note the `% 128 // 32` split pattern in the predicate.
            for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
                T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 and blockIdx_x // 32 * 2 + (ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True)
            # Cooperative load of the weight tile; `% 256 // 8` is another
            # split pattern the pass should rewrite.
            for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
                T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4))
            # Main accumulation over the reduction axes.
            for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
                T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True)
        # Write the accumulator back to global memory.
        for ax1, ax2 in T.grid(2, 4):
            T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True)


@tvm.script.ir_module
class After:
    # Fixture: the expected PrimFunc *after* RenormalizeSplitPattern.
    # Identical to `Before` except that the split patterns were renormalized:
    #   `(ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x) % 128 // 32`
    #     -> `ax0_ax1_ax2_ax3_fused_0 % 4`
    #   `(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) % 256 // 8`
    #     -> `(ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32`
    @T.prim_func
    def main(inputs: T.Buffer[(1, 4, 4, 512), "float32"], weight: T.Buffer[(4, 4, 512, 256), "float32"], conv2d_transpose_nhwc: T.Buffer[(1, 8, 8, 256), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # var definition
        threadIdx_x = T.env_thread("threadIdx.x")
        blockIdx_x = T.env_thread("blockIdx.x")
        # body
        T.launch_thread(blockIdx_x, 64)
        conv2d_transpose_nhwc_local = T.allocate([8], "float32", "local")
        PadInput_shared = T.allocate([768], "float32", "shared")
        weight_shared = T.allocate([4096], "float32", "shared")
        T.launch_thread(threadIdx_x, 32)
        # Zero-init the local accumulator.
        for i2_3_init, i1_4_init, i2_4_init in T.grid(2, 2, 2):
            T.store(conv2d_transpose_nhwc_local, i1_4_init * 4 + i2_3_init * 2 + i2_4_init, T.float32(0), True)
        for i6_0 in T.serial(16):
            # Input tile load: predicate now uses the renormalized `% 4` form.
            for ax0_ax1_ax2_ax3_fused_0 in T.serial(24):
                T.store(PadInput_shared, ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x, T.if_then_else(128 <= ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x and ax0_ax1_ax2_ax3_fused_0 * 32 + threadIdx_x < 640 and 1 <= blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 and blockIdx_x // 32 * 2 + ax0_ax1_ax2_ax3_fused_0 % 4 < 5, T.load("float32", inputs.data, blockIdx_x // 32 * 1024 + ax0_ax1_ax2_ax3_fused_0 * 512 + i6_0 * 32 + threadIdx_x - 2560), T.float32(0), dtype="float32"), True)
            # Weight tile load: index now uses the renormalized `% 32` form.
            for ax0_ax1_ax2_ax3_fused_0 in T.serial(32):
                T.store(weight_shared, T.ramp(ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4, 1, 4), T.load("float32x4", weight.data, T.ramp((ax0_ax1_ax2_ax3_fused_0 * 128 + threadIdx_x * 4) // 256 * 131072 + i6_0 * 8192 + (ax0_ax1_ax2_ax3_fused_0 * 16 + threadIdx_x // 2) % 32 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 2 * 4, 1, 4), T.broadcast(True, 4)), T.broadcast(True, 4))
            # Main accumulation over the reduction axes (unchanged by the pass).
            for i6_1, i2_3, i4_2, i5_2, i6_2, i1_4, i2_4 in T.grid(4, 2, 4, 4, 8, 2, 2):
                T.store(conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4, T.load("float32", conv2d_transpose_nhwc_local, i1_4 * 4 + i2_3 * 2 + i2_4) + T.if_then_else((i1_4 + i4_2) % 2 == 0 and (i2_4 + i5_2) % 2 == 0, T.load("float32", PadInput_shared, threadIdx_x // 8 * 128 + (i1_4 + i4_2) // 2 * 128 + (i2_4 + i5_2) // 2 * 32 + i2_3 * 32 + i6_1 * 8 + i6_2), T.float32(0), dtype="float32") * T.load("float32", weight_shared, i6_1 * 64 + i6_2 * 8 + threadIdx_x % 8 + 3840 - i5_2 * 256 - i4_2 * 1024), True)
        # Write the accumulator back to global memory (unchanged by the pass).
        for ax1, ax2 in T.grid(2, 4):
            T.store(conv2d_transpose_nhwc.data, threadIdx_x // 8 * 4096 + ax1 * 2048 + blockIdx_x // 32 * 1024 + ax2 * 256 + blockIdx_x % 32 * 8 + threadIdx_x % 8, T.load("float32", conv2d_transpose_nhwc_local, ax1 * 4 + ax2), True)


def test_renormalize_split_pattern():
    """Check that RenormalizeSplitPattern rewrites `Before` into `After`.

    The function was originally misspelled `tesd_...`, so pytest never
    collected it and the test silently never ran; renamed so it executes.
    """
    # NOTE(review): the Python binding itself is published under the
    # misspelled name "RenomalizeSplitPattern"; call that name so this test
    # works against the current API.
    after = tvm.tir.transform.RenomalizeSplitPattern()(Before)
    tvm.ir.assert_structural_equal(after, After)


# Backward-compatible alias for the original (misspelled) name.
tesd_renormalize_split_pattern = test_renormalize_split_pattern


if __name__ == "__main__":
    test_renormalize_split_pattern()