Skip to content

Commit

Permalink
backend: (riscv) Add prologue epilogue insertion (#2752)
Browse files Browse the repository at this point in the history
The RISC-V call ABI requires a few registers, if clobbered, to be
restored by the callee prior to the function returning. This has so far
not been implemented meaning that any C compiler calling an xDSL
generated function may cause undefined behaviour.

This PR adds a pass performing the required register saving and
restoring by pushing and popping them from the stack. The registers are
saved in the prologue at the beginning of the function and an epilogue
restoring them is emitted before every return operation.

The pass is required to run after register allocation to see the
clobbering of callee-preserved registers. It itself does not require
register allocation to be run afterward.
  • Loading branch information
zero9178 authored Jun 20, 2024
1 parent 79eeff0 commit bc1b4b7
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 0 deletions.
52 changes: 52 additions & 0 deletions tests/filecheck/backend/riscv/prologue_epilogue_insertion.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: xdsl-opt --split-input-file -p "riscv-prologue-epilogue-insertion" %s | filecheck %s
// RUN: xdsl-opt --split-input-file -p "riscv-prologue-epilogue-insertion{flen=4}" %s | filecheck %s --check-prefix=CHECK-SMALL-FLEN

// CHECK: func @main
riscv_func.func @main() {
// CHECK-NEXT: get_register
// CHECK-SAME: -> !riscv.reg<sp>
// CHECK-NEXT: addi %{{.*}}, -12
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.reg<sp>
// CHECK-NEXT: get_float_register
// CHECK-SAME: -> !riscv.freg<fs2>
// CHECK-NEXT: fsd %{{.*}}, %{{.*}}, 0
// CHECK-SAME: (!riscv.reg<sp>, !riscv.freg<fs2>) -> ()
// CHECK-NEXT: get_register
// CHECK-SAME: -> !riscv.reg<s5>
// CHECK-NEXT: sw %{{.*}}, %{{.*}}, 8
// CHECK-SAME: (!riscv.reg<sp>, !riscv.reg<s5>) -> ()

%fs0 = riscv.get_float_register : () -> !riscv.freg<fs0>
%fs1 = riscv.get_float_register : () -> !riscv.freg<fs1>
// Clobber only fs2.
%sum1 = riscv.fadd.s %fs0, %fs1 : (!riscv.freg<fs0>, !riscv.freg<fs1>) -> !riscv.freg<fs2>
%zero = riscv.get_register : () -> !riscv.reg<zero>
// Clobber s5.
%0 = riscv.mv %zero : (!riscv.reg<zero>) -> !riscv.reg<s5>
riscv_cf.blt %0 : !riscv.reg<s5>, %zero : !riscv.reg<zero>, ^0(), ^1()
^1:
// CHECK: label "l1"
riscv.label "l1"
// CHECK-NEXT: fld %{{.*}}, 0
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<fs2>
// CHECK-NEXT: lw %{{.*}}, 8
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<s5>
// CHECK-NEXT: addi %{{.*}}, 12
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<sp>
// CHECK-NEXT: return
riscv_func.return
^0:
// CHECK: label "l0"
riscv.label "l0"
// CHECK-NEXT: fld %{{.*}}, 0
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<fs2>
// CHECK-NEXT: lw %{{.*}}, 8
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<s5>
// CHECK-NEXT: addi %{{.*}}, 12
// CHECK-SAME: (!riscv.reg<sp>) -> !riscv.freg<sp>
riscv_func.return
}

// CHECK-SMALL-FLEN: func @main
// CHECK-SMALL-FLEN: addi %{{.*}}, -8
// CHECK-SMALL-FLEN-SAME: (!riscv.reg<sp>) -> !riscv.reg<sp>
6 changes: 6 additions & 0 deletions tests/interactive/test_get_all_available_passes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from xdsl.backend.riscv import prologue_epilogue_insertion
from xdsl.backend.riscv.lowering import (
convert_arith_to_riscv,
convert_func_to_riscv_func,
Expand Down Expand Up @@ -37,6 +38,11 @@ def test_get_all_available_passes():
module_pass=reconcile_unrealized_casts.ReconcileUnrealizedCastsPass,
pass_spec=None,
),
AvailablePass(
display_name="riscv-prologue-epilogue-insertion",
module_pass=prologue_epilogue_insertion.PrologueEpilogueInsertion,
pass_spec=None,
),
)
)

Expand Down
93 changes: 93 additions & 0 deletions xdsl/backend/riscv/prologue_epilogue_insertion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from dataclasses import dataclass, field

from ordered_set import OrderedSet

from xdsl.builder import Builder
from xdsl.context import MLContext
from xdsl.dialects import builtin, riscv, riscv_func
from xdsl.dialects.riscv import (
IntRegisterType,
Registers,
RISCVRegisterType,
)
from xdsl.passes import ModulePass


@dataclass(frozen=True)
class PrologueEpilogueInsertion(ModulePass):
"""
Pass inserting a prologue and epilogue according to the RISC-V ABI.
The prologues and epilogues are responsible for saving any callee-preserved
registers.
In RISC-V these are 's0' to 's11' and 'fs0' to `fs11'.
The stack pointer 'sp' must also be restored to its original value.
This pass should be run late in the pipeline after register allocation.
It does not itself require register allocation nor invalidate the result of the
register allocator.
"""

name = "riscv-prologue-epilogue-insertion"
xlen: int = field(default=4)
flen: int = field(default=8)

def _process_function(self, func: riscv_func.FuncOp) -> None:
# Find all callee-preserved registers that are clobbered. We define clobbered
# as it being the result of some operation and therefore written to.
used_callee_preserved_registers = OrderedSet(
res.type
for op in func.walk()
if not isinstance(op, riscv.GetRegisterOp | riscv.GetFloatRegisterOp)
for res in op.results
if res.type in Registers.S or res.type in Registers.FS
)

def get_register_size(r: RISCVRegisterType):
if isinstance(r, IntRegisterType):
return self.xlen
return self.flen

# Build the prologue at the beginning of the function.
builder = Builder.at_start(func.body.blocks[0])
sp_register = builder.insert(riscv.GetRegisterOp(Registers.SP))
stack_size = sum(get_register_size(r) for r in used_callee_preserved_registers)
builder.insert(riscv.AddiOp(sp_register, -stack_size, rd=Registers.SP))
offset = 0
for reg in used_callee_preserved_registers:
if isinstance(reg, IntRegisterType):
reg_op = builder.insert(riscv.GetRegisterOp(reg))
op = riscv.SwOp(rs1=sp_register, rs2=reg_op, immediate=offset)
else:
reg_op = builder.insert(riscv.GetFloatRegisterOp(reg))
op = riscv.FSdOp(rs1=sp_register, rs2=reg_op, immediate=offset)

builder.insert(op)
offset += get_register_size(reg)

# Now build the epilogue right before every return operation.
for block in func.body.blocks:
ret_op = block.last_op
if not isinstance(ret_op, riscv_func.ReturnOp):
continue

builder = Builder.before(ret_op)
offset = 0
for reg in used_callee_preserved_registers:
if isinstance(reg, IntRegisterType):
op = riscv.LwOp(rs1=sp_register, rd=reg, immediate=offset)
else:
op = riscv.FLdOp(rs1=sp_register, rd=reg, immediate=offset)
builder.insert(op)
offset += get_register_size(reg)

builder.insert(riscv.AddiOp(sp_register, stack_size, rd=Registers.SP))

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
for func in op.walk():
if not isinstance(func, riscv_func.FuncOp):
continue

if len(func.body.blocks) == 0:
continue

self._process_function(func)
6 changes: 6 additions & 0 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,11 @@ def get_snitch_register_allocation():

return snitch_register_allocation.SnitchRegisterAllocation

def get_riscv_prologue_epilogue_insertion():
from xdsl.backend.riscv import prologue_epilogue_insertion

return prologue_epilogue_insertion.PrologueEpilogueInsertion

def get_convert_arith_to_riscv():
from xdsl.backend.riscv.lowering import convert_arith_to_riscv

Expand Down Expand Up @@ -350,6 +355,7 @@ def get_test_lower_snitch_stream_to_asm():
"riscv-scf-loop-range-folding": get_riscv_scf_loop_range_folding,
"scf-parallel-loop-tiling": get_scf_parallel_loop_tiling,
"snitch-allocate-registers": get_snitch_register_allocation,
"riscv-prologue-epilogue-insertion": get_riscv_prologue_epilogue_insertion,
"stencil-shape-inference": get_stencil_shape_inference,
"stencil-storage-materialization": get_stencil_storage_materialization,
"stencil-tensorize-z-dimension": get_stencil_tensorize_z_dimension,
Expand Down

0 comments on commit bc1b4b7

Please sign in to comment.