Skip to content

Commit

Permalink
introduce pass lower_init_block (apache#7806)
Browse files Browse the repository at this point in the history
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
  • Loading branch information
3 people authored and tmoreau89 committed Apr 11, 2021
1 parent b174927 commit 098957d
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ TVM_DLL Pass PointerValueTypeRewrite();
*/
TVM_DLL Pass HoistIfThenElse();

/*!
* \brief Lower block init stmt into IfThenElse stmts
* \return The pass.
*/
TVM_DLL Pass LowerInitBlock();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,3 +536,14 @@ def HoistIfThenElse(variant=None):
return _ffi_api.HoistIfThenElseBasic()
elif variant is None:
return _ffi_api.HoistIfThenElse()


def LowerInitBlock():
"""Lower block init stmt into IfThenElse stmts
Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerInitBlock()
85 changes: 85 additions & 0 deletions src/tir/transforms/lower_init_block.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Lower block init stmt into branch stmt
* \file lower_reduction.cc
*/
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {

class InitBlockLower : public StmtMutator {
private:
Stmt VisitStmt_(const BlockNode* block) final {
if (!block->init.defined()) {
return StmtMutator::VisitStmt_(block);
}
Stmt init = DoLowering(block->init.value(), block->iter_vars);
Stmt body = VisitStmt(block->body);
auto n = CopyOnWrite(block);
n->init = NullOpt;
n->body = SeqStmt::Flatten(init, body);
return Block(n);
}

static Stmt DoLowering(const Stmt& init, const Array<IterVar>& iter_vars) {
std::vector<PrimExpr> conditions;
for (const IterVar& var : iter_vars) {
if (var->iter_type == IterVarType::kCommReduce) {
conditions.push_back(equal(var->var, var->dom->min));
}
}
// Handle the case where there is no condition
if (conditions.empty()) {
return init;
}
// Concat the conditions with logical and (&&)
PrimExpr cond = conditions[0];
for (size_t i = 1; i < conditions.size(); ++i) {
cond = logical_and(cond, conditions[i]);
}
return IfThenElse(cond, init);
}
};

PrimFunc LowerInitBlock(PrimFunc func) {
auto fptr = func.CopyOnWrite();
fptr->body = InitBlockLower()(std::move(fptr->body));
return func;
}

namespace transform {

Pass LowerInitBlock() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
return LowerInitBlock(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tir.LowerReduction", {});
}

TVM_REGISTER_GLOBAL("tir.transform.LowerInitBlock").set_body_typed(LowerInitBlock);

} // namespace transform

} // namespace tir
} // namespace tvm
53 changes: 53 additions & 0 deletions tests/python/unittest/test_tir_transform_lower_init_block.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import tir
from tvm.script import ty


@tvm.script.tir
class WithInit:
def main(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, [64, 64, 64])
B = tir.match_buffer(b, [64])

with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]:
with tir.init():
B[i] = tir.float32(0)
B[i] += A[i, j, k]


@tvm.script.tir
class WithBranch:
def main(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, [64, 64, 64])
B = tir.match_buffer(b, [64])

with tir.block([64, tir.reduce_axis(0, 64), tir.reduce_axis(32, 64)]) as [i, j, k]:
if (j == 0) and (k == 32):
B[i] = tir.float32(0)
B[i] += A[i, j, k]


def test_lower_reduction():
origin_mod = WithInit()
mod = tvm.tir.transform.LowerInitBlock()(origin_mod)
tvm.ir.assert_structural_equal(mod, WithBranch(), True)


if __name__ == "__main__":
test_lower_reduction()

0 comments on commit 098957d

Please sign in to comment.