Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][M2a] Fuse, Split #8467

Merged
merged 16 commits into from
Jul 21, 2021
12 changes: 12 additions & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,18 @@ class IterSumExpr : public IterMapExpr {
Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);
/*!
* \brief Use IterVarMap detector to rewrite and simplify the indices
*
* \param indices The indices to detect pattern for.
* \param input_iters Map from variable to iterator's range.
* \param input_pred The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
*
* \return The indices after rewrite
*/
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool require_bijective);

/*!
* \brief Apply the inverse of the affine transformation to the outputs.
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,25 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
/******** Schedule: loops manipulation ********/
/*!
* \brief Fuse a list of consecutive loops into one. It requires:
* 1) The loops can't have annotations or thread bindings.
* 2) The (i+1)-th loop must be the only child of the i-th loop.
* 3) All loops must start with 0.
* \param loop_rvs The loops to be fused
* \return The new loop after fusion
*/
virtual LoopRV Fuse(const Array<LoopRV>& loop_rvs) = 0;
/*!
* \brief Split a loop into a list of consecutive loops. It requires:
* 1) The loop can't have annotation or thread binding.
* 2) The loop must start with 0.
* \param loop_rv The loop to be split
* \param factors The tiling factors, and at most one of which is -1, which means that
* factor is inferred.
* \return The new loops after split
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
/******** Schedule: compute location ********/
/*!
* \brief Inline a block into its consumer(s). It requires:
Expand Down
138 changes: 136 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
# pylint: disable=unused-import
"""The TensorIR schedule class"""
from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple
jinhongyii marked this conversation as resolved.
Show resolved Hide resolved

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
Expand All @@ -43,7 +43,10 @@ class BlockRV(Object):
"""A random variable that refers to a block"""


ExprRV = PrimExpr # A random variable that evaluates to an integer
# It is a workaround for mypy: https://github.com/python/mypy/issues/7866#issuecomment-549454370
# This feature is not supported until python 3.10:
# https://docs.python.org/3.10/whatsnew/3.10.html#pep-613-typealias
ExprRV = Union[PrimExpr] # A random variable that evaluates to an integer

RAND_VAR_TYPE = Union[ExprRV, BlockRV, LoopRV] # type: ignore # pylint: disable=invalid-name

Expand Down Expand Up @@ -257,6 +260,137 @@ def get_loops(self, block: BlockRV) -> List[LoopRV]:
return _ffi_api_schedule.ScheduleGetLoops(self, block) # type: ignore # pylint: disable=no-member

########## Schedule: loops manipulation ##########
def fuse(self, *loops: List[LoopRV]) -> LoopRV:
"""Fuse a list of consecutive loops into one. It requires:
1) The loops can't have annotations or thread bindings.
2) The (i+1)-th loop must be the only child of the i-th loop.
3) All loops must start with 0.
Parameters
----------
*loops : List[LoopRV]
The loops to be fused
Returns
----------
fused_loop : LoopRV
The new loop after fusion
Examples
--------
Before applying fuse, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_fuse(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do fuse:
.. code-block:: python
sch = tir.Schedule(before_fuse)
i, j = sch.get_loops(sch.get_block("B"))
sch.fuse(i, j)
print(tvm.script.asscript(sch.mod["main"]))
After applying fuse, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_fuse(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
# the 2 loops are fused into 1
for i_j_fused in tir.serial(0, 16384):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, tir.floordiv(i_j_fused, 128))
tir.bind(vj, tir.floormod(i_j_fused, 128))
B[vi, vj] = A[vi, vj] * 2.0
"""
return _ffi_api_schedule.ScheduleFuse(self, loops) # type: ignore # pylint: disable=no-member

def split(
self,
loop: LoopRV,
factors: List[Union[ExprRV, None]],
) -> List[LoopRV]:
"""Split a loop into a list of consecutive loops. It requires:
1) The loop can't have annotation or thread binding.
2) The loop must start with 0.
Predicates may be added to ensure the total loop numbers keeps unchanged.
In `factors`, at most one of the factors can be None,
which will be automatically inferred.
jinhongyii marked this conversation as resolved.
Show resolved Hide resolved
Parameters
----------
loop : LoopRV
The loop to be split
factors: List[Union[ExprRV, None]]
The splitting factors
Potential inputs are:
- None
- ExprRV
- Nonnegative constant integers
Returns
----------
split_loops : List[LoopRV]
The new loops after split
Examples
--------
Before split, in TensorIR, the IR is:
.. code-block:: python
@tvm.script.tir
def before_split(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0
Create the schedule and do fuse:
.. code-block:: python
sch = tir.Schedule(before_split)
i, j = sch.get_loops(sch.get_block("B"))
sch.split(i, factors=[2, 64])
print(tvm.script.asscript(sch.mod["main"]))
After applying split, the IR becomes:
.. code-block:: python
@tvm.script.tir
def after_split(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
# the original loop is split into 2 loops
for i0, i1, j in tir.grid(2, 64, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, ((i0*64) + i1))
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0
"""
# it will be checked later in C++ implementation
# that there is at most one None in `factors`
return _ffi_api_schedule.ScheduleSplit(self, loop, factors) # type: ignore # pylint: disable=no-member

########## Schedule: compute location ##########
def compute_inline(self, block: BlockRV) -> None:
"""Inline a block into its consumer(s). It requires:
Expand Down
15 changes: 15 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1085,6 +1085,21 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter
return NormalizeIterMapToExpr(expr);
});

Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
const PrimExpr& input_pred, bool require_bijective) {
Analyzer analyzer;
Array<IterSumExpr> rewrite =
DetectIterMap(indices, input_iters, input_pred, require_bijective, &analyzer);
if (rewrite.empty()) {
return indices;
}
Array<PrimExpr> res;
res.reserve(rewrite.size());
IterMapToExprNormalizer converter(&analyzer);
for (const auto& expr : rewrite) res.push_back(converter.Convert(expr));
return res;
}

/*!
* \brief Divider to divide the bindings into two sets of bindings(outer and inner)
* such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X.
Expand Down
4 changes: 4 additions & 0 deletions src/arith/rewrite_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
TVM_TRY_REWRITE_IF(floordiv(x + c1, c2), floordiv(x, c2) + floordiv(c1, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floordiv(x * c1, x * c2), floordiv(c1, c2), c2.Eval()->value > 0);

TVM_TRY_REWRITE_IF(floordiv(x + y, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));

TVM_TRY_REWRITE_IF(floordiv(y + x, x), floordiv(y, x) + 1, CanProveGreaterEqual(x.Eval(), 0));
Expand Down Expand Up @@ -882,6 +884,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x, c2),
c2.Eval()->value > 0 && c1.Eval()->value % c2.Eval()->value == 0);

TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), c2.Eval()->value != 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please send this as separate PR with regression testcases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines are added to support symbolic split/fuse in some simple cases, but it still can't handle situations where there are 2 or more symbolic vars. Do you think we should support this simple situation for now by merging it in a new PR or we put it aside and support the complete symbolic usage after we improve symbolic divisible check in iter_affine_map?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tqchen @junrushao1994 what's your opinion on this?


TVM_TRY_REWRITE(floormod(x * y, y), ZeroWithTypeLike(x));
TVM_TRY_REWRITE(floormod(y * x, y), ZeroWithTypeLike(y));

Expand Down
87 changes: 87 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,93 @@ Array<LoopRV> ConcreteScheduleNode::GetLoops(const BlockRV& block_rv) {
}

/******** Schedule: loops manipulation ********/

LoopRV ConcreteScheduleNode::Fuse(const Array<LoopRV>& loop_rvs) {
CHECK(!loop_rvs.empty()) << "ValueError: 'fuse' requires at least 1 loop(s)";
Array<StmtSRef> loop_srefs = this->GetSRefs(loop_rvs);
StmtSRef result{nullptr};
TVM_TIR_SCHEDULE_BEGIN();
result = tir::Fuse(state_, loop_srefs);
TVM_TIR_SCHEDULE_END("fuse", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(result);
}

Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
const Array<Optional<ExprRV>>& factor_rvs) {
class NotSingleInferFactorError : public ScheduleError {
public:
explicit NotSingleInferFactorError(IRModule mod) : mod_(mod) {}

String FastErrorString() const final {
return "ScheduleError: only one factor can be specified as -1 or none";
}

String DetailRenderTemplate() const final {
return "Only one factor can be specified as -1 or none";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {}; }

IRModule mod_;
};

class WrongFactorProductError : public ScheduleError {
public:
explicit WrongFactorProductError(IRModule mod, For loop) : mod_(mod), loop_(std::move(loop)) {}

String FastErrorString() const final {
return "ScheduleError: The product of factors is not larger than or equal to the extent of "
"loop";
}

String DetailRenderTemplate() const final {
return "The product of factors is not larger than or equal to the extent of loop {0}";
}

IRModule mod() const final { return mod_; }
Array<ObjectRef> LocationsOfInterest() const final { return {loop_}; }

IRModule mod_;
For loop_;
};
// Prepare for the splitting
StmtSRef loop_sref = this->GetSRef(loop_rv);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
Array<PrimExpr> factors;
factors.reserve(factor_rvs.size());
int infer_index = -1;
PrimExpr tot_length = 1;
Array<StmtSRef> results;
TVM_TIR_SCHEDULE_BEGIN();
// infer factor if needed and check validity of factors
for (size_t i = 0; i < factor_rvs.size(); i++) {
if (!factor_rvs[i].defined()) {
factors.push_back(Integer(-1));
if (infer_index == -1) {
infer_index = i;
} else {
throw NotSingleInferFactorError(state_->mod);
}
} else {
PrimExpr factor = this->Get(factor_rvs[i].value());
factors.push_back(factor);
tot_length *= factor;
}
}
if (infer_index != -1) {
factors.Set(infer_index,
this->analyzer_->Simplify(floordiv(loop->extent + tot_length - 1, tot_length)));
} else if (!this->analyzer_->CanProve(tot_length >= loop->extent)) {
throw WrongFactorProductError(state_->mod, GetRef<For>(loop));
}
results = tir::Split(state_, loop_sref, factors);
TVM_TIR_SCHEDULE_END("split", this->error_render_level_);
this->state_->DebugVerify();
return CreateRV<LoopRV>(results);
}

/******** Schedule: compute location ********/

void ConcreteScheduleNode::ComputeInline(const BlockRV& block_rv) {
Expand Down
Loading