Skip to content

Commit

Permalink
dialects: (riscv) Add HasInsntrait and implement trait for some opera…
Browse files Browse the repository at this point in the history
…tions (#2784)

In order to make progress on #2468 I factored out the insn
representation. This will hopefully also make it usable for the RISC-V
backend efforts.

---------

Co-authored-by: Joren Dumoulin <joren.dumoulin@kuleuven.be>
  • Loading branch information
AntonLydike and jorendumoulin authored Jun 27, 2024
1 parent ae1e3bf commit 772df11
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/dialects/test_riscv_snitch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from xdsl.dialects import riscv_snitch
from xdsl.dialects.riscv import RISCVInstruction
from xdsl.traits import HasInsnRepresentation

ground_truth = {
"dmsrc": ".insn r 0x2b, 0, 0, x0, {0}, {1}",
"dmdst": ".insn r 0x2b, 0, 1, x0, {0}, {1}",
"dmcpyi": ".insn r 0x2b, 0, 2, {0}, {1}, {2}",
"dmcpy": ".insn r 0x2b, 0, 3, {0}, {1}, {2}",
"dmstati": ".insn r 0x2b, 0, 4, {0}, {1}, {2}",
"dmstat": ".insn r 0x2b, 0, 5, {0}, {1}, {2}",
"dmstr": ".insn r 0x2b, 0, 6, x0, {0}, {1}",
"dmrep": ".insn r 0x2b, 0, 7, x0, {0}, x0",
}


@pytest.mark.parametrize(
"op",
(
riscv_snitch.DMSourceOp,
riscv_snitch.DMDestinationOp,
riscv_snitch.DMCopyImmOp,
riscv_snitch.DMCopyOp,
riscv_snitch.DMStatImmOp,
riscv_snitch.DMStatOp,
riscv_snitch.DMStrideOp,
riscv_snitch.DMRepOp,
),
)
def test_insn_repr(op: RISCVInstruction):
trait = op.get_trait(HasInsnRepresentation)
assert trait is not None
# Limitation of Pyright, see https://github.com/microsoft/pyright/issues/7105
# We are currently stuck on an older version of Pyright, the update is
# tracked in https://github.com/xdslproject/xdsl/issues/2791
assert (
trait.get_insn(op) # pyright: ignore[reportGeneralTypeIssues]
== ground_truth[op.name[13:]]
)
21 changes: 21 additions & 0 deletions xdsl/backend/riscv/traits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass, field

from xdsl.ir import Operation
from xdsl.traits import HasInsnRepresentation


@dataclass(frozen=True)
class StaticInsnRepresentation(HasInsnRepresentation):
"""
Returns the first parameter as an insn template string.
See https://sourceware.org/binutils/docs/as/RISC_002dV_002dDirectives.html for more information
"""

insn: str = field(kw_only=True)

def get_insn(self, op: Operation) -> str:
"""
Return the insn representation of the operation for printing.
"""
return self.insn
33 changes: 33 additions & 0 deletions xdsl/dialects/riscv_snitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from typing_extensions import Self

from xdsl.backend.riscv.traits import StaticInsnRepresentation
from xdsl.dialects import riscv, stream
from xdsl.dialects.builtin import (
IntAttr,
Expand Down Expand Up @@ -462,6 +463,10 @@ class DMSourceOp(IRDLOperation, RISCVInstruction):
ptrlo = operand_def(riscv.IntRegisterType)
ptrhi = operand_def(riscv.IntRegisterType)

traits = frozenset(
[StaticInsnRepresentation(insn=".insn r 0x2b, 0, 0, x0, {0}, {1}")]
)

def __init__(self, ptrlo: SSAValue | Operation, ptrhi: SSAValue | Operation):
super().__init__(operands=[ptrlo, ptrhi])

Expand All @@ -476,6 +481,10 @@ class DMDestinationOp(IRDLOperation, RISCVInstruction):
ptrlo = operand_def(riscv.IntRegisterType)
ptrhi = operand_def(riscv.IntRegisterType)

traits = frozenset(
[StaticInsnRepresentation(insn=".insn r 0x2b, 0, 1, x0, {0}, {1}")]
)

def __init__(self, ptrlo: SSAValue | Operation, ptrhi: SSAValue | Operation):
super().__init__(operands=[ptrlo, ptrhi])

Expand All @@ -490,6 +499,10 @@ class DMStrideOp(IRDLOperation, RISCVInstruction):
srcstrd = operand_def(riscv.IntRegisterType)
dststrd = operand_def(riscv.IntRegisterType)

traits = frozenset(
[StaticInsnRepresentation(insn=".insn r 0x2b, 0, 6, x0, {0}, {1}")]
)

def __init__(self, srcstrd: SSAValue | Operation, dststrd: SSAValue | Operation):
super().__init__(operands=[srcstrd, dststrd])

Expand All @@ -503,6 +516,10 @@ class DMRepOp(IRDLOperation, RISCVInstruction):

reps = operand_def(riscv.IntRegisterType)

traits = frozenset(
[StaticInsnRepresentation(insn=".insn r 0x2b, 0, 7, x0, {0}, x0")]
)

def __init__(self, reps: SSAValue | Operation):
super().__init__(operands=[reps])

Expand All @@ -518,6 +535,10 @@ class DMCopyOp(IRDLOperation, RISCVInstruction):
size = operand_def(riscv.IntRegisterType)
config = operand_def(riscv.IntRegisterType)

traits = frozenset(
[StaticInsnRepresentation(insn=".insn r 0x2b, 0, 3, {0}, {1}, {2}")]
)

def __init__(
self,
size: SSAValue | Operation,
Expand All @@ -537,6 +558,10 @@ class DMStatOp(IRDLOperation, RISCVInstruction):
dest = result_def(riscv.IntRegisterType)
status = operand_def(riscv.IntRegisterType)

traits = frozenset(
[StaticInsnRepresentation(insn=".insn r 0x2b, 0, 5, {0}, {1}, {2}")]
)

def __init__(
self,
status: SSAValue | Operation,
Expand All @@ -556,6 +581,10 @@ class DMCopyImmOp(IRDLOperation, RISCVInstruction):
size = operand_def(riscv.IntRegisterType)
config = prop_def(UImm5Attr)

traits = frozenset(
[StaticInsnRepresentation(insn=".insn r 0x2b, 0, 2, {0}, {1}, {2}")]
)

def __init__(
self,
size: SSAValue | Operation,
Expand Down Expand Up @@ -606,6 +635,10 @@ class DMStatImmOp(IRDLOperation, RISCVInstruction):
dest = result_def(riscv.IntRegisterType)
status = prop_def(UImm5Attr)

traits = frozenset(
[StaticInsnRepresentation(insn=".insn r 0x2b, 0, 4, {0}, {1}, {2}")]
)

def __init__(
self,
status: int | UImm5Attr,
Expand Down
18 changes: 18 additions & 0 deletions xdsl/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,3 +498,21 @@ def has_effects(cls, op: Operation) -> bool:

class Pure(NoMemoryEffect):
"""A trait that signals that an operation has no side effects."""


class HasInsnRepresentation(OpTrait, abc.ABC):
"""
A trait providing information on how to encode an operation using a .insn assember directive.
The returned string contains python string.format placeholders where formatted operands are inserted during
printing.
See https://sourceware.org/binutils/docs/as/RISC_002dV_002dDirectives.html for more information.
"""

@abc.abstractmethod
def get_insn(self, op: Operation) -> str:
"""
Return the insn representation of the operation for printing.
"""
raise NotImplementedError()

0 comments on commit 772df11

Please sign in to comment.