diff --git a/tests/backend/riscv/test_preallocated.py b/tests/backend/riscv/test_preallocated.py index 01bcaddd9e..2b97a8eb73 100644 --- a/tests/backend/riscv/test_preallocated.py +++ b/tests/backend/riscv/test_preallocated.py @@ -1,6 +1,7 @@ from xdsl.backend.riscv.register_allocation import gather_allocated from xdsl.builder import Builder from xdsl.dialects import riscv, riscv_func +from xdsl.dialects.builtin import SymbolRefAttr def test_gather_allocated(): @@ -56,3 +57,17 @@ def multiple_preallocated_body() -> None: ) assert len(pa_regs) == 2 + + @Builder.implicit_region + def func_call_preallocated_body() -> None: + reg1 = riscv.IntRegisterType.unallocated() + v1 = riscv.GetRegisterOp(reg1).res + v2 = riscv.GetRegisterOp(riscv.Registers.S0).res + riscv_func.CallOp(SymbolRefAttr("hello"), (v1, v2), ()) + + pa_regs = gather_allocated( + riscv_func.FuncOp("foo", func_call_preallocated_body, ((), ())) + ) + + assert len(pa_regs) == 36 + assert riscv.Registers.S0 in pa_regs diff --git a/xdsl/backend/riscv/register_allocation.py b/xdsl/backend/riscv/register_allocation.py index 7704b2943d..43fc29e01b 100644 --- a/xdsl/backend/riscv/register_allocation.py +++ b/xdsl/backend/riscv/register_allocation.py @@ -4,7 +4,7 @@ from ordered_set import OrderedSet from xdsl.backend.riscv.register_queue import RegisterQueue -from xdsl.dialects import riscv_func, riscv_scf, riscv_snitch +from xdsl.dialects import riscv, riscv_func, riscv_scf, riscv_snitch from xdsl.dialects.riscv import ( FloatRegisterType, IntRegisterType, @@ -26,6 +26,15 @@ def gather_allocated(func: riscv_func.FuncOp) -> set[RISCVRegisterType]: if not isinstance(op, RISCVAsmOperation): continue + if isinstance(op, riscv_func.CallOp): + # These registers are not guaranteed to hold the same values when the callee + # returns, according to the RISC-V calling convention. + # https://riscv.org/wp-content/uploads/2015/01/riscv-calling.pdf + allocated.update(riscv.Registers.A) + allocated.update(riscv.Registers.T) + allocated.update(riscv.Registers.FA) + allocated.update(riscv.Registers.FT) + for param in chain(op.operands, op.results): if isinstance(param.type, RISCVRegisterType) and param.type.is_allocated: if not param.type.register_name.startswith("j"):