Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: (riscv_scf) add a pass to fuse perfectly nested loops #2540

Merged
merged 5 commits into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 150 additions & 0 deletions tests/filecheck/dialects/riscv_scf/loop_fusion.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
// RUN: xdsl-opt -p riscv-scf-loop-fusion %s | filecheck %s

// CHECK: builtin.module {

// Success case: the inner loop runs from 0 up to the outer loop's step (%c8),
// and each induction variable is used exactly once, by the same riscv.add.
// The expected output replaces the add with the fused induction variable.
// The constants defined here are reused by the remaining cases in this file.
%c0 = riscv.li 0 : () -> !riscv.reg<>
%c1 = riscv.li 1 : () -> !riscv.reg<>
%c8 = riscv.li 8 : () -> !riscv.reg<>
%c64 = riscv.li 64 : () -> !riscv.reg<>

riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
}

// CHECK-NEXT: %c0 = riscv.li 0 : () -> !riscv.reg<>
// CHECK-NEXT: %c1 = riscv.li 1 : () -> !riscv.reg<>
// CHECK-NEXT: %c8 = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %c64 = riscv.li 64 : () -> !riscv.reg<>
// CHECK-NEXT: riscv_scf.for %0 : !riscv.reg<> = %c0 to %c64 step %c1 {
// CHECK-NEXT: %1 = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%0) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }

// Cannot fuse outer loop with iteration arguments
%res0 = riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 iter_args(%arg0 = %c0) -> (!riscv.reg<>) {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
riscv_scf.yield %arg0 : !riscv.reg<>
}

// CHECK-NEXT: %{{.*}} = riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (!riscv.reg<>) {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: riscv_scf.yield %{{.*}} : !riscv.reg<>
// CHECK-NEXT: }

// Inner loop must be the only operation in the outer loop, aside from yield
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
%20 = riscv.li 42 : () -> !riscv.reg<>
}

// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: %{{.*}} = riscv.li 42 : () -> !riscv.reg<>
// CHECK-NEXT: }

// Cannot fuse inner loop with iteration arguments
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
%res1 = riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 iter_args(%arg1 = %c0) -> (!riscv.reg<>) {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
riscv_scf.yield %arg1 : !riscv.reg<>
}
}
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}) -> (!riscv.reg<>) {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: riscv_scf.yield %{{.*}} : !riscv.reg<>
// CHECK-NEXT: }
// CHECK-NEXT: }

// Cannot fuse inner loop with non-zero lb
// (the inner loop below starts at %c8, i.e. 8, rather than at %c0)
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c8 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
}

// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }


// Each induction variable must be used exactly once, and both must be used by
// the same riscv.add. Each of the three loop nests below breaks that rule in a
// different way, so none of them is fused.

// The outer induction variable %16 has a second use, in "test.op".
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19, %16) : (!riscv.reg<>, !riscv.reg<>) -> ()
}
}
// The inner induction variable %17 has a second use, in "test.op".
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.add %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19, %17) : (!riscv.reg<>, !riscv.reg<>) -> ()
}
}
// The induction variables are combined by riscv.mul instead of riscv.add.
riscv_scf.for %16 : !riscv.reg<> = %c0 to %c64 step %c8 {
riscv_scf.for %17 : !riscv.reg<> = %c0 to %c8 step %c1 {
%18 = riscv.li 8 : () -> !riscv.reg<>
%19 = riscv.mul %16, %17 : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
"test.op"(%19) : (!riscv.reg<>) -> ()
}
}

// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}) : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.add %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}, %{{.*}}) : (!riscv.reg<>, !riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: riscv_scf.for %{{.*}} : !riscv.reg<> = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = riscv.li 8 : () -> !riscv.reg<>
// CHECK-NEXT: %{{.*}} = riscv.mul %{{.*}}, %{{.*}} : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
// CHECK-NEXT: "test.op"(%{{.*}}) : (!riscv.reg<>) -> ()
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK-NEXT: }


6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,11 @@ def get_riscv_register_allocation():

return riscv_register_allocation.RISCVRegisterAllocation

def get_riscv_scf_loop_fusion():
    """Import the loop-fusion transform on demand and return its pass class."""
    from xdsl.transforms.riscv_scf_loop_fusion import RiscvScfLoopFusionPass

    return RiscvScfLoopFusionPass

def get_riscv_scf_loop_range_folding():
from xdsl.transforms import riscv_scf_loop_range_folding

Expand Down Expand Up @@ -603,6 +608,7 @@ def get_test_lower_snitch_stream_to_asm():
"replace-incompatible-fpga": get_replace_incompatible_fpga,
"riscv-allocate-registers": get_riscv_register_allocation,
"riscv-cse": get_riscv_cse,
"riscv-scf-loop-fusion": get_riscv_scf_loop_fusion,
"riscv-scf-loop-range-folding": get_riscv_scf_loop_range_folding,
"scf-parallel-loop-tiling": get_scf_parallel_loop_tiling,
"snitch-allocate-registers": get_snitch_register_allocation,
Expand Down
92 changes: 92 additions & 0 deletions xdsl/transforms/riscv_scf_loop_fusion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import cast

from xdsl.dialects import builtin, riscv, riscv_scf
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.transforms.canonicalization_patterns.riscv import get_constant_value


class FuseNestedLoopsPattern(RewritePattern):
    """
    Fuse a perfectly nested pair of `riscv_scf.for` loops into a single loop.

    The pattern matches a nest of the shape

        for i = lb to ub step S:
            for j = 0 to S step s:
                use(i + j)

    and rewrites it to

        for k = lb to ub step s:
            use(k)

    The rewrite is only applied when all of the following hold:
     - neither loop has iteration arguments,
     - the inner loop is the only operation in the outer body besides the yield,
     - the inner lower bound is the constant 0,
     - the inner upper bound is a constant equal to the constant outer step,
     - both steps are positive constants and the inner step divides the outer
       step, so the inner loop tiles one outer step without overshooting,
     - the outer bounds are constants whose distance is a multiple of the outer
       step, so the last outer iteration does not run past the upper bound,
     - each induction variable has exactly one use, and both are operands of the
       same `riscv.add`.
    """

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter) -> None:
        if op.iter_args:
            return

        outer_body = op.body.block
        if not isinstance(inner_loop := outer_body.first_op, riscv_scf.ForOp):
            # Outer loop must contain inner loop
            return
        if inner_loop is not cast(riscv_scf.YieldOp, outer_body.last_op).prev_op:
            # Outer loop must contain only inner loop and yield
            return
        if inner_loop.iter_args:
            return

        if (inner_lb := get_constant_value(inner_loop.lb)) is None:
            return
        if inner_lb.value.data != 0:
            return

        if (inner_ub := get_constant_value(inner_loop.ub)) is None:
            return
        if (outer_step := get_constant_value(op.step)) is None:
            return
        if inner_ub != outer_step:
            return

        # The inner step becomes the step of the fused loop, so it must be a known
        # positive value that divides the outer step. Otherwise the inner loop
        # stops short of the next outer iteration (e.g. steps 8 and 3 give
        # 0, 3, 6, 8, ... in the nest but 0, 3, 6, 9, ... after fusion).
        if (inner_step := get_constant_value(inner_loop.step)) is None:
            return
        outer_step_value = outer_step.value.data
        inner_step_value = inner_step.value.data
        if outer_step_value <= 0 or inner_step_value <= 0:
            return
        if outer_step_value % inner_step_value != 0:
            return

        # The nest always executes whole tiles of `outer_step` values, whereas the
        # fused loop stops at the outer upper bound. The two only agree when the
        # outer range is an exact number of tiles, which has to be provable here.
        if (outer_lb := get_constant_value(op.lb)) is None:
            return
        if (outer_ub := get_constant_value(op.ub)) is None:
            return
        if (outer_ub.value.data - outer_lb.value.data) % outer_step_value != 0:
            return

        outer_index = outer_body.args[0]
        inner_index = inner_loop.body.block.args[0]

        if len(outer_index.uses) != 1 or len(inner_index.uses) != 1:
            # If the induction variable is used more than once, we can't fold it
            return

        outer_user = next(iter(outer_index.uses)).operation
        inner_user = next(iter(inner_index.uses)).operation
        if outer_user is not inner_user:
            # Both induction variables must feed the same operation
            return

        user = outer_user

        if not isinstance(user, riscv.AddOp):
            return

        # We can fuse: the sum of the two indices is exactly the index of the
        # fused loop, so the inner block argument takes the place of the add.
        user.rd.replace_by(inner_index)
        rewriter.erase_op(user)
        # The inner body, including its block argument, becomes the fused body.
        moved_region = rewriter.move_region_contents_to_new_regions(inner_loop.body)
        rewriter.erase_op(inner_loop)

        rewriter.replace_matched_op(
            riscv_scf.ForOp(
                op.lb,
                op.ub,
                inner_loop.step,
                (),
                moved_region,
            )
        )


class RiscvScfLoopFusionPass(ModulePass):
    """
    Collapses a perfect two-level loop nest into one loop where that is possible.

    A nest is a candidate when the inner loop's upper bound equals the outer
    loop's step, i.e. the inner iteration space fills exactly one outer step.
    Further requirements:
     - each induction variable is used only once, by a shared add operation; that
       add is replaced by the induction variable of the fused loop,
     - the inner loop starts at 0,
     - neither loop carries iteration arguments.
    """

    name = "riscv-scf-loop-fusion"

    def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
        walker = PatternRewriteWalker(FuseNestedLoopsPattern())
        walker.rewrite_module(op)
Loading