Skip to content

Commit

Permalink
[TensorIR][M2a] Fuse, Split (apache#8467)
Browse files Browse the repository at this point in the history
* Fuse&split (apache#408)



Co-authored-by: jinhongyi <323195289@qq.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
  • Loading branch information
3 people authored and ylc committed Jan 13, 2022
1 parent 4957d0d commit 6403b80
Show file tree
Hide file tree
Showing 11 changed files with 1,170 additions and 14 deletions.
12 changes: 12 additions & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,18 @@ class IterSumExpr : public IterMapExpr {
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);
/*!
 * \brief Use IterVarMap detector to rewrite and simplify the indices
 *
 * The detection runs with an analyzer created inside the function, so no analyzer
 * needs to be supplied by the caller.
 *
 * \param indices The indices to detect pattern for.
 * \param input_iters Map from variable to iterator's range.
 * \param input_pred The predicate constraints on the input iterators
 * \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
 *
 * \return The indices after rewrite. If the detector yields an empty result, the input
 * indices are returned unchanged.
 */
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool require_bijective);

/*!
* \brief Apply the inverse of the affine transformation to the outputs.
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,25 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
/******** Schedule: loops manipulation ********/
/*!
 * \brief Fuse a list of consecutive loops into one. It requires:
 * 1) The loops can't have annotations or thread bindings.
 * 2) The (i+1)-th loop must be the only child of the i-th loop.
 * 3) All loops must start with 0.
 * \param loop_rvs The loops to be fused. ConcreteScheduleNode rejects an empty list,
 * so at least one loop must be given.
 * \return The new loop after fusion
 */
virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs) = 0;
/*!
 * \brief Split a loop into a list of consecutive loops. It requires:
 * 1) The loop can't have annotation or thread binding.
 * 2) The loop must start with 0.
 * \param loop_rv The loop to be split
 * \param factors The tiling factors. At most one of them may be left undefined (a null
 * Optional), which means that factor is inferred. In ConcreteScheduleNode the inferred
 * factor is the loop extent divided by the product of the other factors, rounded up; when
 * no factor is left undefined, the product of the factors must be provably greater than or
 * equal to the loop extent.
 * \return The new loops after split
 */
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
/******** Schedule: compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
Expand Down
138 changes: 136 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=unused-import
"""The TensorIR schedule class"""
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
Expand All @@ -43,7 +43,10 @@ class BlockRV(Object):
"""A random variable that refers to a block"""


ExprRV = PrimExpr # A random variable that evaluates to an integer
# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370
# This feature is not supported until python 3.10:
# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias
ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer

RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name

Expand Down Expand Up @@ -257,6 +260,137 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member

########## Schedule: loops manipulation ##########
def fuse(self, *loops: LoopRV) -> LoopRV:
    """Fuse a list of consecutive loops into one. It requires:
    1) The loops can't have annotations or thread bindings.
    2) The (i+1)-th loop must be the only child of the i-th loop.
    3) All loops must start with 0.

    Parameters
    ----------
    *loops : LoopRV
        The loops to be fused, each passed as its own positional argument,
        e.g. ``sch.fuse(i, j)``. At least one loop must be given.

    Returns
    -------
    fused_loop : LoopRV
        The new loop after fusion

    Examples
    --------

    Before applying fuse, in TensorIR, the IR is:

    .. code-block:: python

        @tvm.script.tir
        def before_fuse(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    B[vi, vj] = A[vi, vj] * 2.0

    Create the schedule and do fuse:

    .. code-block:: python

        sch = tir.Schedule(before_fuse)
        i, j = sch.get_loops(sch.get_block("B"))
        sch.fuse(i, j)
        print(tvm.script.asscript(sch.mod["main"]))

    After applying fuse, the IR becomes:

    .. code-block:: python

        @tvm.script.tir
        def after_fuse(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            # the 2 loops are fused into 1
            for i_j_fused in tir.serial(0, 16384):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, tir.floordiv(i_j_fused, 128))
                    tir.bind(vj, tir.floormod(i_j_fused, 128))
                    B[vi, vj] = A[vi, vj] * 2.0

    """
    # `loops` is the tuple of positional arguments; it is forwarded to the FFI as-is.
    return _ffi_api_schedule.ScheduleFuse(self, loops)  # type: ignore # pylint: disable=no-member

def split(
    self,
    loop: LoopRV,
    factors: List[Union[ExprRV, None]],
) -> List[LoopRV]:
    """Split a loop into a list of consecutive loops. It requires:
    1) The loop can't have annotation or thread binding.
    2) The loop must start with 0.

    Predicates may be added so that the total number of loop iterations stays unchanged.
    At most one entry of `factors` can be None; that factor is inferred automatically.

    Parameters
    ----------
    loop : LoopRV
        The loop to be split

    factors: List[Union[ExprRV, None]]
        The splitting factors
        Potential inputs are:
        - None
        - ExprRV
        - Nonnegative constant integers

    Returns
    -------
    split_loops : List[LoopRV]
        The new loops after split

    Examples
    --------

    Before split, in TensorIR, the IR is:

    .. code-block:: python

        @tvm.script.tir
        def before_split(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            for i, j in tir.grid(128, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    B[vi, vj] = A[vi, vj] * 2.0

    Create the schedule and do split:

    .. code-block:: python

        sch = tir.Schedule(before_split)
        i, j = sch.get_loops(sch.get_block("B"))
        sch.split(i, factors=[2, 64])
        print(tvm.script.asscript(sch.mod["main"]))

    After applying split, the IR becomes:

    .. code-block:: python

        @tvm.script.tir
        def after_split(a: ty.handle, b: ty.handle) -> None:
            A = tir.match_buffer(a, (128, 128))
            B = tir.match_buffer(b, (128, 128))
            # the original loop is split into 2 loops
            for i0, i1, j in tir.grid(2, 64, 128):
                with tir.block([128, 128], "B") as [vi, vj]:
                    tir.bind(vi, ((i0*64) + i1))
                    tir.bind(vj, j)
                    B[vi, vj] = A[vi, vj] * 2.0

    """
    # The "at most one None in `factors`" rule is not validated here;
    # the C++ implementation checks it.
    split_loops = _ffi_api_schedule.ScheduleSplit(  # type: ignore # pylint: disable=no-member
        self, loop, factors
    )
    return split_loops

########## Schedule: compute location ##########
def compute_inline(self, block: BlockRV) -> None:
"""Inline a block into its consumer(s). It requires:
Expand Down
15 changes: 15 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,21 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter
return NormalizeIterMapToExpr(expr);
});

/*!
 * \brief Rewrite the given indices through iter-map detection and convert the detected
 *        iterator sums back into plain expressions.
 * \param indices The indices to simplify.
 * \param input_iters Map from variable to iterator's range.
 * \param input_pred The predicate constraints on the input iterators.
 * \param require_bijective Whether the detected mapping is required to be bijective.
 * \return The converted indices, or the input indices when detection yields nothing.
 */
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
                                const PrimExpr& input_pred, bool require_bijective) {
  Analyzer analyzer;
  Array<IterSumExpr> detected =
      DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer);
  // Nothing detected: hand the caller's indices back untouched.
  if (detected.empty()) {
    return indices;
  }
  IterMapToExprNormalizer normalizer(&analyzer);
  Array<PrimExpr> simplified;
  simplified.reserve(detected.size());
  for (const IterSumExpr& sum_expr : detected) {
    simplified.push_back(normalizer.Convert(sum_expr));
  }
  return simplified;
}

/*!
* \brief Divider to divide the bindings into two sets of bindings(outer and inner)
* such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X.
Expand Down
4 changes: 4 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0);

TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));

TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));
Expand Down Expand Up @@ -882,6 +884,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0);

TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x));
TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y));

Expand Down
87 changes: 87 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,93 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
}

/******** Schedule: loops manipulation ********/

/*!
 * \brief Fuse the given loops into one and return a random variable for the fused loop.
 * \param loop_rvs The loops to be fused; must not be empty.
 * \return The random variable referring to the loop produced by tir::Fuse.
 */
LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
  CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
  // Resolve the random variables before entering the guarded region.
  Array<StmtSRef> srefs_to_fuse = GetSRefs(loop_rvs);
  // Declared outside the BEGIN/END pair so it stays visible after the guarded region.
  StmtSRef fused_sref{nullptr};
  TVM_TIR_SCHEDULE_BEGIN();
  fused_sref = tir::Fuse(state_, srefs_to_fuse);
  TVM_TIR_SCHEDULE_END("fuse", error_render_level_);
  state_->DebugVerify();
  return CreateRV<LoopRV>(fused_sref);
}

/*!
 * \brief Split a loop by the given factors and return random variables for the new loops.
 *
 * An undefined entry in `factor_rvs` marks a factor to be inferred; at most one entry may be
 * undefined. The inferred factor is the loop extent divided by the product of the defined
 * factors, rounded up. When every factor is defined, their product must be provably greater
 * than or equal to the loop extent.
 *
 * \param loop_rv The loop to be split.
 * \param factor_rvs The splitting factors.
 * \return The random variables referring to the loops produced by tir::Split.
 */
Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
                                          const Array<Optional<ExprRV>>& factor_rvs) {
  // Reported when more than one factor is left undefined.
  class NotSingleInferFactorError : public ScheduleError {
   public:
    explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}

    IRModule mod() const final { return mod_; }
    Array<ObjectRef> LocationsOfInterest() const final { return {}; }

    String FastErrorString() const final {
      return "ScheduleError: only one factor can be specified as -1 or none";
    }

    String DetailRenderTemplate() const final {
      return "Only one factor can be specified as -1 or none";
    }

    IRModule mod_;
  };

  // Reported when all factors are given but their product cannot be proven to cover the extent.
  class WrongFactorProductError : public ScheduleError {
   public:
    explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}

    IRModule mod() const final { return mod_; }
    Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

    String FastErrorString() const final {
      return "ScheduleError: The product of factors is not larger than or equal to the extent of "
             "loop";
    }

    String DetailRenderTemplate() const final {
      return "The product of factors is not larger than or equal to the extent of loop {0}";
    }

    IRModule mod_;
    For loop_;
  };

  // Look up the loop to split. The names `loop` and `loop_sref` are passed to the
  // TVM_SREF_TO_FOR macro and are therefore kept as they are.
  StmtSRef loop_sref = this->GetSRef(loop_rv);
  const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);

  Array<PrimExpr> resolved_factors;
  resolved_factors.reserve(factor_rvs.size());
  Array<StmtSRef> new_loop_srefs;
  PrimExpr known_product = 1;  // product of all explicitly given factors
  int inferred_pos = -1;       // position of the factor to infer, -1 if there is none
  TVM_TIR_SCHEDULE_BEGIN();
  // Evaluate the given factors and locate the (single) factor to infer, if any.
  for (size_t i = 0; i < factor_rvs.size(); i++) {
    if (factor_rvs[i].defined()) {
      PrimExpr factor = this->Get(factor_rvs[i].value());
      resolved_factors.push_back(factor);
      known_product *= factor;
      continue;
    }
    if (inferred_pos != -1) {
      // A second undefined factor is not allowed.
      throw NotSingleInferFactorError(state_->mod);
    }
    inferred_pos = static_cast<int>(i);
    resolved_factors.push_back(Integer(-1));  // placeholder, overwritten below
  }
  if (inferred_pos == -1) {
    // Every factor was given: their product has to cover the whole loop extent.
    if (!this->analyzer_->CanProve(known_product >= loop->extent)) {
      throw WrongFactorProductError(state_->mod, GetRef<For>(loop));
    }
  } else {
    // Inferred factor: the extent divided by the product of the given factors, rounded up.
    resolved_factors.Set(inferred_pos, this->analyzer_->Simplify(floordiv(
                                           loop->extent + known_product - 1, known_product)));
  }
  new_loop_srefs = tir::Split(state_, loop_sref, resolved_factors);
  TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
  this->state_->DebugVerify();
  return CreateRV<LoopRV>(new_loop_srefs);
}

/******** Schedule: compute location ********/

void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
Expand Down
Loading

0 comments on commit 6403b80

Please sign in to comment.