-
Notifications
You must be signed in to change notification settings - Fork 83
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
backend: (riscv) Add prologue epilogue insertion (#2752)
The RISC-V call ABI requires a few registers, if clobbered, to be restored by the callee prior to the function returning. This has so far not been implemented meaning that any C compiler calling an xDSL generated function may cause undefined behaviour. This PR adds a pass performing the required register saving and restoring by pushing and popping them from the stack. The registers are saved in the prologue at the beginning of the function and an epilogue restoring them is emitted before every return operation. The pass is required to run after register allocation to see the clobbering of callee-preserved registers. It itself does not require register allocation to be run afterward.
- Loading branch information
Showing
4 changed files
with
157 additions
and
0 deletions.
There are no files selected for viewing
52 changes: 52 additions & 0 deletions
52
tests/filecheck/backend/riscv/prologue_epilogue_insertion.mlir
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
// RUN: xdsl-opt --split-input-file -p "riscv-prologue-epilogue-insertion" %s | filecheck %s | ||
// RUN: xdsl-opt --split-input-file -p "riscv-prologue-epilogue-insertion{flen=4}" %s | filecheck %s --check-prefix=CHECK-SMALL-FLEN | ||
|
||
// CHECK: func @main | ||
riscv_func.func @main() { | ||
// CHECK-NEXT: get_register | ||
// CHECK-SAME: -> !riscv.reg<sp> | ||
// CHECK-NEXT: addi %{{.*}}, -12 | ||
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.reg<sp> | ||
// CHECK-NEXT: get_float_register | ||
// CHECK-SAME: -> !riscv.freg<fs2> | ||
// CHECK-NEXT: fsd %{{.*}}, %{{.*}}, 0 | ||
// CHECK-SAME: (!riscv.reg<sp>, !riscv.freg<fs2>) -> () | ||
// CHECK-NEXT: get_register | ||
// CHECK-SAME: -> !riscv.reg<s5> | ||
// CHECK-NEXT: sw %{{.*}}, %{{.*}}, 8 | ||
// CHECK-SAME: (!riscv.reg<sp>, !riscv.reg<s5>) -> () | ||
|
||
%fs0 = riscv.get_float_register : () -> !riscv.freg<fs0> | ||
%fs1 = riscv.get_float_register : () -> !riscv.freg<fs1> | ||
// Clobber only fs2. | ||
%sum1 = riscv.fadd.s %fs0, %fs1 : (!riscv.freg<fs0>, !riscv.freg<fs1>) -> !riscv.freg<fs2> | ||
%zero = riscv.get_register : () -> !riscv.reg<zero> | ||
// Clobber s5. | ||
%0 = riscv.mv %zero : (!riscv.reg<zero>) -> !riscv.reg<s5> | ||
riscv_cf.blt %0 : !riscv.reg<s5>, %zero : !riscv.reg<zero>, ^0(), ^1() | ||
^1: | ||
// CHECK: label "l1" | ||
riscv.label "l1" | ||
// CHECK-NEXT: fld %{{.*}}, 0 | ||
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<fs2> | ||
// CHECK-NEXT: lw %{{.*}}, 8 | ||
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<s5> | ||
// CHECK-NEXT: addi %{{.*}}, 12 | ||
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<sp> | ||
// CHECK-NEXT: return | ||
riscv_func.return | ||
^0: | ||
// CHECK: label "l0" | ||
riscv.label "l0" | ||
// CHECK-NEXT: fld %{{.*}}, 0 | ||
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<fs2> | ||
// CHECK-NEXT: lw %{{.*}}, 8 | ||
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<s5> | ||
// CHECK-NEXT: addi %{{.*}}, 12 | ||
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<sp> | ||
riscv_func.return | ||
} | ||
|
||
// CHECK-SMALL-FLEN: func @main | ||
// CHECK-SMALL-FLEN: addi %{{.*}}, -8 | ||
// CHECK-SMALL-FLEN-SAME: (!riscv.reg<sp>) -> !riscv.reg<sp> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
from dataclasses import dataclass, field | ||
|
||
from ordered_set import OrderedSet | ||
|
||
from xdsl.builder import Builder | ||
from xdsl.context import MLContext | ||
from xdsl.dialects import builtin, riscv, riscv_func | ||
from xdsl.dialects.riscv import ( | ||
IntRegisterType, | ||
Registers, | ||
RISCVRegisterType, | ||
) | ||
from xdsl.passes import ModulePass | ||
|
||
|
||
@dataclass(frozen=True) | ||
class PrologueEpilogueInsertion(ModulePass): | ||
""" | ||
Pass inserting a prologue and epilogue according to the RISC-V ABI. | ||
The prologues and epilogues are responsible for saving any callee-preserved | ||
registers. | ||
In RISC-V these are 's0' to 's11' and 'fs0' to `fs11'. | ||
The stack pointer 'sp' must also be restored to its original value. | ||
This pass should be run late in the pipeline after register allocation. | ||
It does not itself require register allocation nor invalidate the result of the | ||
register allocator. | ||
""" | ||
|
||
name = "riscv-prologue-epilogue-insertion" | ||
xlen: int = field(default=4) | ||
flen: int = field(default=8) | ||
|
||
def _process_function(self, func: riscv_func.FuncOp) -> None: | ||
# Find all callee-preserved registers that are clobbered. We define clobbered | ||
# as it being the result of some operation and therefore written to. | ||
used_callee_preserved_registers = OrderedSet( | ||
res.type | ||
for op in func.walk() | ||
if not isinstance(op, riscv.GetRegisterOp | riscv.GetFloatRegisterOp) | ||
for res in op.results | ||
if res.type in Registers.S or res.type in Registers.FS | ||
) | ||
|
||
def get_register_size(r: RISCVRegisterType): | ||
if isinstance(r, IntRegisterType): | ||
return self.xlen | ||
return self.flen | ||
|
||
# Build the prologue at the beginning of the function. | ||
builder = Builder.at_start(func.body.blocks[0]) | ||
sp_register = builder.insert(riscv.GetRegisterOp(Registers.SP)) | ||
stack_size = sum(get_register_size(r) for r in used_callee_preserved_registers) | ||
builder.insert(riscv.AddiOp(sp_register, -stack_size, rd=Registers.SP)) | ||
offset = 0 | ||
for reg in used_callee_preserved_registers: | ||
if isinstance(reg, IntRegisterType): | ||
reg_op = builder.insert(riscv.GetRegisterOp(reg)) | ||
op = riscv.SwOp(rs1=sp_register, rs2=reg_op, immediate=offset) | ||
else: | ||
reg_op = builder.insert(riscv.GetFloatRegisterOp(reg)) | ||
op = riscv.FSdOp(rs1=sp_register, rs2=reg_op, immediate=offset) | ||
|
||
builder.insert(op) | ||
offset += get_register_size(reg) | ||
|
||
# Now build the epilogue right before every return operation. | ||
for block in func.body.blocks: | ||
ret_op = block.last_op | ||
if not isinstance(ret_op, riscv_func.ReturnOp): | ||
continue | ||
|
||
builder = Builder.before(ret_op) | ||
offset = 0 | ||
for reg in used_callee_preserved_registers: | ||
if isinstance(reg, IntRegisterType): | ||
op = riscv.LwOp(rs1=sp_register, rd=reg, immediate=offset) | ||
else: | ||
op = riscv.FLdOp(rs1=sp_register, rd=reg, immediate=offset) | ||
builder.insert(op) | ||
offset += get_register_size(reg) | ||
|
||
builder.insert(riscv.AddiOp(sp_register, stack_size, rd=Registers.SP)) | ||
|
||
def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None: | ||
for func in op.walk(): | ||
if not isinstance(func, riscv_func.FuncOp): | ||
continue | ||
|
||
if len(func.body.blocks) == 0: | ||
continue | ||
|
||
self._process_function(func) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters