Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

backend: add RegisterConstraints and use in riscv backend #2930

Merged
merged 2 commits into from
Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions tests/backend/riscv/test_register_allocation.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import re

import pytest
from typing_extensions import Self

from xdsl.backend.register_allocatable import RegisterConstraints
from xdsl.backend.riscv.register_allocation import RegisterAllocatorLivenessBlockNaive
from xdsl.backend.riscv.register_queue import RegisterQueue
from xdsl.dialects import riscv
from xdsl.irdl import IRDLOperation, irdl_op_definition, operand_def, result_def
from xdsl.utils.exceptions import DiagnosticException
from xdsl.utils.test_value import TestSSAValue

Expand Down Expand Up @@ -72,3 +75,56 @@ def j(index: int):
),
):
register_allocator.allocate_same((e0, e1, e2))


def test_allocate_with_inout_constraints():

@irdl_op_definition
class MyInstruction(riscv.RISCVAsmOperation, IRDLOperation):

name = "riscv.my_instruction"

rs0 = operand_def()
rs1 = operand_def()
rd0 = result_def()
rd1 = result_def()

@classmethod
def get(cls, rs0: str, rs1: str, rd0: str, rd1: str) -> Self:
return cls.build(
operands=(
TestSSAValue(riscv.IntRegisterType(rs0)),
TestSSAValue(riscv.IntRegisterType(rs1)),
),
result_types=(
riscv.IntRegisterType(rd0),
riscv.IntRegisterType(rd1),
),
)

def get_register_constraints(self) -> RegisterConstraints:
return RegisterConstraints(
(self.rs0,), (self.rd0,), ((self.rs1, self.rd1),)
)

register_queue = RegisterQueue(
available_int_registers=[], available_float_registers=[]
)
register_allocator = RegisterAllocatorLivenessBlockNaive(register_queue)

# All new registers. The result register is reused by the allocator for the operand.
op0 = MyInstruction.get("", "", "", "")
register_allocator.process_riscv_op(op0)
assert op0.rs0.type == riscv.IntRegisterType("j1")
assert op0.rs1.type == riscv.IntRegisterType("j0")
assert op0.rd0.type == riscv.IntRegisterType("j1")
assert op0.rd1.type == riscv.IntRegisterType("j0")

# One register reserved for inout parameter, the allocator should allocate the output
# to the same register.
op1 = MyInstruction.get("", "", "", "a0")
register_allocator.process_riscv_op(op1)
assert op1.rs0.type == riscv.IntRegisterType("j2")
assert op1.rs1.type == riscv.IntRegisterType("a0")
assert op1.rd0.type == riscv.IntRegisterType("j2")
assert op1.rd1.type == riscv.IntRegisterType("a0")
28 changes: 28 additions & 0 deletions xdsl/backend/register_allocatable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import abc
from collections.abc import Sequence
from typing import NamedTuple

from xdsl.ir import SSAValue


class RegisterConstraints(NamedTuple):
"""
Values used by an instruction.
A collection of operations in `inouts` represents the constraint that they must be
allocated to the same register.
"""

ins: Sequence[SSAValue]
outs: Sequence[SSAValue]
inouts: Sequence[Sequence[SSAValue]]


class HasRegisterConstraints(abc.ABC):

@abc.abstractmethod
def get_register_constraints(self) -> RegisterConstraints:
"""
The values with register types used by this operation, for use in register
allocation.
"""
raise NotImplementedError()
10 changes: 8 additions & 2 deletions xdsl/backend/riscv/register_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,22 @@ def process_riscv_op(self, op: RISCVAsmOperation) -> None:
"""
Allocate registers for RISC-V Instruction.
"""
ins, outs, inouts = op.get_register_constraints()

for result in op.results:
# Allocate registers to inout operand groups since they are defined further up
# in the use-def SSA chain
for operand_group in inouts:
self.allocate_same(operand_group)

for result in outs:
# Allocate registers to result if not already allocated
self.allocate(result)
# Free the register since the SSA value is created here
self._free(result)

# Allocate registers to operands since they are defined further up
# in the use-def SSA chain
for operand in op.operands:
for operand in ins:
self.allocate(operand)

def allocate_for_loop(self, loop: riscv_scf.ForOp) -> None:
Expand Down
9 changes: 8 additions & 1 deletion xdsl/dialects/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

from typing_extensions import Self

from xdsl.backend.register_allocatable import (
HasRegisterConstraints,
RegisterConstraints,
)
from xdsl.backend.register_type import RegisterType
from xdsl.dialects.builtin import (
AnyIntegerAttr,
Expand Down Expand Up @@ -347,11 +351,14 @@ def print_parameter(self, printer: Printer) -> None:
printer.print_string_literal(self.data)


class RISCVAsmOperation(IRDLOperation, ABC):
class RISCVAsmOperation(HasRegisterConstraints, IRDLOperation, ABC):
"""
Base class for operations that can be a part of RISC-V assembly printing.
"""

def get_register_constraints(self) -> RegisterConstraints:
return RegisterConstraints(self.operands, self.results, ())

@abstractmethod
def assembly_line(self) -> str | None:
raise NotImplementedError()
Expand Down
13 changes: 7 additions & 6 deletions xdsl/transforms/canonicalization_patterns/riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,10 +444,11 @@ def get_constant_value(value: SSAValue) -> riscv.Imm32Attr | None:
if value.type == riscv.Registers.ZERO:
return IntegerAttr.from_int_and_width(0, 32)

if isinstance(value.owner, riscv.MVOp):
return get_constant_value(value.owner.rs)
if not isinstance(value, OpResult):
return

if isinstance(value.owner, riscv.LiOp) and isinstance(
value.owner.immediate, IntegerAttr
):
return value.owner.immediate
if isinstance(value.op, riscv.MVOp):
return get_constant_value(value.op.rs)

if isinstance(value.op, riscv.LiOp) and isinstance(value.op.immediate, IntegerAttr):
return value.op.immediate
Loading