diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index c19735025ddc4..35f31ac9165cf 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -452,6 +452,9 @@ Array ConcreteScheduleNode::Split(const LoopRV& loop_rv, if (is_const_int(factor) && !is_positive_const(factor)) { throw NonPositiveFactorError(state_->mod, factor.as()->value, i); } + if (factor.dtype().bits() > loop->extent.dtype().bits()) { + factor = cast(loop->extent.dtype(), factor); + } factors.push_back(factor); tot_length *= factor; } diff --git a/tests/python/unittest/test_tir_schedule_split_fuse.py b/tests/python/unittest/test_tir_schedule_split_fuse.py index 0bfac4e425b95..9fd678174dc0c 100644 --- a/tests/python/unittest/test_tir_schedule_split_fuse.py +++ b/tests/python/unittest/test_tir_schedule_split_fuse.py @@ -20,6 +20,7 @@ import tvm.testing from tvm import te, tir from tvm.script import tir as T +from tvm.tir.expr import IntImm from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable @@ -637,5 +638,13 @@ def _create_prim_func(): ) +def test_split_int64_factors(): + sch = tir.Schedule(elementwise_symbolic, debug_mask="all") + block_b = sch.get_block("B") + _, _, k = sch.get_loops(block_b) + sch.split(k, factors=[IntImm(dtype="int64", value=10), None]) + tvm.ir.assert_structural_equal(elementwise_symbolic_split, sch.mod["main"]) + + if __name__ == "__main__": tvm.testing.main()