Skip to content

Commit

Permalink
simplify extent of loop after fuse and add corresponding test case (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Courtesy-Xs authored Jan 4, 2024
1 parent 193fea3 commit 7b616c4
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
2 changes: 1 addition & 1 deletion paddle/cinn/ir/schedule/impl/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ Expr DyScheduleImpl::Fuse(const std::vector<Expr>& loops) {
for (int i = 0; i < loops_number; ++i) {
fused_extent = fused_extent * for_nodes[i]->extent;
}

fused_extent = cinn::common::AutoSimplify(fused_extent);
if (!fused_body.As<ir::Block>()) fused_body = Block::Make({fused_body});
Expr new_stmt = For::Make(fused_var,
Expr(0),
Expand Down
55 changes: 53 additions & 2 deletions test/cinn/ir/test_llir_schedule_fuse_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def elementwise_fuse_assign_loop(
def elementwise_fuse_assign_loop_gt(
X: DataArray((128, 128, 128)), Y: DataArray((128, 128, 128))
):
for i in range(((1 * 128) * 128) * 128):
for i in range(2097152):
with ir.ScheduleBlockContext("Y") as block_y:
i1_1, j1_1, k1_1 = ir.AxisMap(
"SSS", [(i / 128) / 128, (i / 128) % 128, i % 128]
Expand Down Expand Up @@ -148,7 +148,7 @@ def elementwise_fuse_assign_loop(
Y: DataArray((-1, 128, 128)),
N: ir.Var(),
):
for i_j_k_fused in range(((1 * N) * 128) * 128):
for i_j_k_fused in range(16384 * N):
with ir.ScheduleBlockContext("Y") as block_y:
i1, j1, k1 = ir.AxisMap(
"SSS",
Expand Down Expand Up @@ -207,9 +207,60 @@ def elementwise_split(
assert_llir_equal(origin.elementwise_split, expected.elementwise_split)


def test_fuse_split():
    """Check that fusing two loops and then splitting the result yields the expected IR.

    The origin kernel iterates a 64 x 128 x 128 space, fuses the ``i`` and
    ``j`` loops and splits the fused loop with ``factors=[2, 512, -1]``.
    The expected kernel spells the result out by hand as loops of extent
    2, 512 and 8 followed by the untouched ``k`` loop. Both are handed to
    ``assert_llir_equal``.
    """

    # Kernel under test: the schedule primitives are applied inside the
    # schedule block, between the axis binding and the store.
    @to_cinn_llir
    def elementwise_fuse_split_origin(
        X: DataArray((64, 128, 128)), Y: DataArray((64, 128, 128))
    ):
        for i in range(64):
            for j in range(128):
                for k in range(128):
                    with ir.ScheduleBlockContext("Y") as Y_block:
                        i1, j1, k1 = ir.AxisMap("SSS", [i, j, k])
                        # i (64) and j (128) fuse into 64 * 128 = 8192 iterations.
                        fused = sch.fuse([i, j])
                        # NOTE(review): the -1 factor is presumably inferred by the
                        # scheduler; the expected IR below uses 8 = 8192 / (2 * 512).
                        sch.split(fused, factors=[2, 512, -1])
                        Y[i1, j1, k1] = X[i1, j1, k1] * 2.0

    # Hand-written reference: the fused index is rebuilt as
    # 4096 * i_j_fused + 8 * i_j_fused_0 + i_j_fused_1 (4096 = 512 * 8), and
    # the first two axes are bound to that index divided by / modulo 128.
    @to_cinn_llir
    def elementwise_fuse_split_expected(
        X: DataArray((64, 128, 128)), Y: DataArray((64, 128, 128))
    ):
        for i_j_fused in range(2):
            for i_j_fused_0 in range(512):
                for i_j_fused_1 in range(8):
                    for k in range(128):
                        with ir.ScheduleBlockContext("Y") as Y_block:
                            i1, j1, k1 = ir.AxisMap(
                                "SSS",
                                [
                                    (
                                        (
                                            (4096 * i_j_fused)
                                            + ((8 * i_j_fused_0) + i_j_fused_1)
                                        )
                                        / 128
                                    ),
                                    (
                                        (
                                            (4096 * i_j_fused)
                                            + ((8 * i_j_fused_0) + i_j_fused_1)
                                        )
                                        % 128
                                    ),
                                    k,
                                ],
                            )
                            Y[i1, j1, k1] = X[i1, j1, k1] * 2.0

    assert_llir_equal(
        elementwise_fuse_split_origin, elementwise_fuse_split_expected
    )


if __name__ == "__main__":
    # Run the schedule tests one after another when this file is executed
    # directly as a script.
    test_fuse()
    test_split()
    test_fuse_split()
    test_split_predicate()
    test_fuse_dynamic()
    test_split_dynamic()

0 comments on commit 7b616c4

Please sign in to comment.