From e894888a852b7cb685ba734f7a6e4df7f512b6b6 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sat, 6 Jul 2024 18:00:51 +0100 Subject: [PATCH 1/2] backend: reserve registers not preserved across function calls in RISC-V --- tests/backend/riscv/test_preallocated.py | 15 +++++++++++++++ xdsl/backend/riscv/register_allocation.py | 8 +++++++- 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/backend/riscv/test_preallocated.py b/tests/backend/riscv/test_preallocated.py index 01bcaddd9e..2b97a8eb73 100644 --- a/tests/backend/riscv/test_preallocated.py +++ b/tests/backend/riscv/test_preallocated.py @@ -1,6 +1,7 @@ from xdsl.backend.riscv.register_allocation import gather_allocated from xdsl.builder import Builder from xdsl.dialects import riscv, riscv_func +from xdsl.dialects.builtin import SymbolRefAttr def test_gather_allocated(): @@ -56,3 +57,17 @@ def multiple_preallocated_body() -> None: ) assert len(pa_regs) == 2 + + @Builder.implicit_region + def func_call_preallocated_body() -> None: + reg1 = riscv.IntRegisterType.unallocated() + v1 = riscv.GetRegisterOp(reg1).res + v2 = riscv.GetRegisterOp(riscv.Registers.S0).res + riscv_func.CallOp(SymbolRefAttr("hello"), (v1, v2), ()) + + pa_regs = gather_allocated( + riscv_func.FuncOp("foo", func_call_preallocated_body, ((), ())) + ) + + assert len(pa_regs) == 36 + assert riscv.Registers.S0 in pa_regs diff --git a/xdsl/backend/riscv/register_allocation.py b/xdsl/backend/riscv/register_allocation.py index 7704b2943d..a2fc35a0e0 100644 --- a/xdsl/backend/riscv/register_allocation.py +++ b/xdsl/backend/riscv/register_allocation.py @@ -4,7 +4,7 @@ from ordered_set import OrderedSet from xdsl.backend.riscv.register_queue import RegisterQueue -from xdsl.dialects import riscv_func, riscv_scf, riscv_snitch +from xdsl.dialects import riscv, riscv_func, riscv_scf, riscv_snitch from xdsl.dialects.riscv import ( FloatRegisterType, IntRegisterType, @@ -26,6 +26,12 @@ def gather_allocated(func: riscv_func.FuncOp) -> set[RISCVRegisterType]: if not isinstance(op, RISCVAsmOperation): continue + if isinstance(op, riscv_func.CallOp): + allocated.update(riscv.Registers.A) + allocated.update(riscv.Registers.T) + allocated.update(riscv.Registers.FA) + allocated.update(riscv.Registers.FT) + for param in chain(op.operands, op.results): if isinstance(param.type, RISCVRegisterType) and param.type.is_allocated: if not param.type.register_name.startswith("j"): From f4b49232f9a48b54c99a4d2f777e97e31d75313c Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Sat, 6 Jul 2024 23:02:08 +0100 Subject: [PATCH 2/2] add a comment --- xdsl/backend/riscv/register_allocation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xdsl/backend/riscv/register_allocation.py b/xdsl/backend/riscv/register_allocation.py index a2fc35a0e0..43fc29e01b 100644 --- a/xdsl/backend/riscv/register_allocation.py +++ b/xdsl/backend/riscv/register_allocation.py @@ -27,6 +27,9 @@ def gather_allocated(func: riscv_func.FuncOp) -> set[RISCVRegisterType]: continue if isinstance(op, riscv_func.CallOp): + # These registers are not guaranteed to hold the same values when the callee + # returns, according to the RISC-V calling convention. + # https://riscv.org/wp-content/uploads/2015/01/riscv-calling.pdf allocated.update(riscv.Registers.A) allocated.update(riscv.Registers.T) allocated.update(riscv.Registers.FA)