Skip to content

Commit

Permalink
[TIR] Avoid unnecessary dtype escalation in loop splitting (apache#12035
Browse files Browse the repository at this point in the history
)

This PR introduces a type check to cast loop split decisions (sometimes given as `int64`) back to a smaller datatype when the loop variable's data type is smaller. This issue usually happens during reloading a trace from disk using JSON database and causes the failure of `CompactBufferAllocation` pass.
  • Loading branch information
zxybazh authored and junrushao committed Jul 27, 2022
1 parent 6c4ddbd commit 3319a14
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,9 @@ Array<LoopRV> ConcreteScheduleNode::Split(const LoopRV& loop_rv,
if (is_const_int(factor) && !is_positive_const(factor)) {
throw NonPositiveFactorError(state_->mod, factor.as<IntImmNode>()->value, i);
}
if (factor.dtype().bits() > loop->extent.dtype().bits()) {
factor = cast(loop->extent.dtype(), factor);
}
factors.push_back(factor);
tot_length *= factor;
}
Expand Down
9 changes: 9 additions & 0 deletions tests/python/unittest/test_tir_schedule_split_fuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import tvm.testing
from tvm import te, tir
from tvm.script import tir as T
from tvm.tir.expr import IntImm
from tvm.tir.schedule.testing import verify_trace_roundtrip

# pylint: disable=no-member,invalid-name,unused-variable
Expand Down Expand Up @@ -637,5 +638,13 @@ def _create_prim_func():
)


def test_split_int64_factors():
sch = tir.Schedule(elementwise_symbolic, debug_mask="all")
block_b = sch.get_block("B")
_, _, k = sch.get_loops(block_b)
sch.split(k, factors=[IntImm(dtype="int64", value=10), None])
tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"])


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 3319a14

Please sign in to comment.