Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects (arm): add assembly printing #3485

Merged
merged 7 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions tests/filecheck/dialects/arm/test_ops.mlir
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
// RUN: XDSL_ROUNDTRIP
// RUN: XDSL_GENERIC_ROUNDTRIP
// RUN: xdsl-opt -t arm-asm %s | filecheck %s --check-prefix=CHECK-ASM


// CHECK: %x1 = arm.get_register : !arm.reg<x1>
%x1 = arm.get_register : !arm.reg<x1>

// CHECK: %ds_mov = arm.ds.mov %x1 : (!arm.reg<x1>) -> !arm.reg<x2>
%ds_mov = arm.ds.mov %x1 : (!arm.reg<x1>) -> !arm.reg<x2>
// CHECK: %ds_mov = arm.ds.mov %x1 {"comment" = "move contents of s to d"} : (!arm.reg<x1>) -> !arm.reg<x2>
// CHECK-ASM: mov x2, x1 # move contents of s to d
%ds_mov = arm.ds.mov %x1 {"comment" = "move contents of s to d"} : (!arm.reg<x1>) -> !arm.reg<x2>

// CHECK-GENERIC: %x1 = "arm.get_register"() : () -> !arm.reg<x1>
// CHECK-GENERIC: %ds_mov = "arm.ds.mov"(%x1) : (!arm.reg<x1>) -> !arm.reg<x2>
// CHECK-GENERIC: %ds_mov = "arm.ds.mov"(%x1) {"comment" = "move contents of s to d"} : (!arm.reg<x1>) -> !arm.reg<x2>
14 changes: 13 additions & 1 deletion xdsl/dialects/arm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@
https://developer.arm.com/documentation/102374/0101/Overview
"""

from typing import IO

from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import Dialect

from .ops import DSMovOp, GetRegisterOp
from .ops import ARMOperation, DSMovOp, GetRegisterOp
from .register import IntRegisterType


def print_assembly(module: ModuleOp, output: IO[str]) -> None:
for op in module.body.walk():
assert isinstance(op, ARMOperation), f"{op}"
asm = op.assembly_line()
if asm is not None:
print(asm, file=output)


ARM = Dialect(
"arm",
[
Expand Down
38 changes: 38 additions & 0 deletions xdsl/dialects/arm/assembly.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import TypeAlias

from xdsl.dialects.arm.register import ARMRegisterType
from xdsl.dialects.builtin import StringAttr
from xdsl.ir import SSAValue

AssemblyInstructionArg: TypeAlias = SSAValue


def append_comment(line: str, comment: StringAttr | None) -> str:
if comment is None:
return line

padding = " " * max(0, 48 - len(line))

return f"{line}{padding} # {comment.data}"


def assembly_arg_str(arg: AssemblyInstructionArg) -> str:
if isinstance(arg.type, ARMRegisterType):
reg = arg.type.register_name
return reg
else:
raise ValueError(f"Unexpected register type {arg.type}")


def assembly_line(
name: str,
arg_str: str,
comment: StringAttr | None = None,
is_indented: bool = True,
) -> str:
code = " " if is_indented else ""
code += name
if arg_str:
code += f" {arg_str}"
code = append_comment(code, comment)
return code
48 changes: 46 additions & 2 deletions xdsl/dialects/arm/ops.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import ABC
from abc import ABC, abstractmethod

from xdsl.dialects.builtin import StringAttr
from xdsl.ir import Operation, SSAValue
Expand All @@ -10,6 +10,7 @@
result_def,
)

from .assembly import AssemblyInstructionArg, assembly_arg_str, assembly_line
from .register import IntRegisterType


Expand All @@ -18,14 +19,51 @@ class ARMOperation(IRDLOperation, ABC):
Base class for operations that can be a part of ARM assembly printing.
"""

@abstractmethod
def assembly_line(self) -> str | None:
raise NotImplementedError()


class ARMInstruction(ARMOperation, ABC):
"""
alexarice marked this conversation as resolved.
Show resolved Hide resolved
Base class for operations that can be a part of x86 assembly printing. Must
represent an instruction in the x86 instruction set.
The name of the operation will be used as the x86 assembly instruction name.
"""

comment = opt_attr_def(StringAttr)
"""
An optional comment that will be printed along with the instruction.
"""

@abstractmethod
def assembly_line_args(self) -> tuple[AssemblyInstructionArg | None, ...]:
"""
The arguments to the instruction, in the order they should be printed in the
assembly.
"""
raise NotImplementedError()

def assembly_instruction_name(self) -> str:
"""
By default, the name of the instruction is the same as the name of the operation.
"""

return self.name.split(".")[-1]

def assembly_line(self) -> str | None:
# default assembly code generator
instruction_name = self.assembly_instruction_name()
arg_str = ", ".join(
assembly_arg_str(arg)
for arg in self.assembly_line_args()
if arg is not None
)
return assembly_line(instruction_name, arg_str, self.comment)


@irdl_op_definition
class DSMovOp(ARMOperation):
class DSMovOp(ARMInstruction):
"""
Copies the value of s into d.

Expand Down Expand Up @@ -56,6 +94,9 @@ def __init__(
result_types=(d,),
)

def assembly_line_args(self):
return (self.d, self.s)


@irdl_op_definition
class GetRegisterOp(ARMOperation):
Expand All @@ -70,3 +111,6 @@ class GetRegisterOp(ARMOperation):

def __init__(self, register_type: IntRegisterType):
super().__init__(result_types=[register_type])

def assembly_line(self):
return None
6 changes: 6 additions & 0 deletions xdsl/xdsl_opt_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,11 @@ def register_all_targets(self):
Add other/additional targets by overloading this function.
"""

def _output_arm_asm(prog: ModuleOp, output: IO[str]):
from xdsl.dialects.arm import print_assembly

print_assembly(prog, output)

def _output_mlir(prog: ModuleOp, output: IO[str]):
printer = Printer(
stream=output,
Expand Down Expand Up @@ -241,6 +246,7 @@ def _print_to_csl(prog: ModuleOp, output: IO[str]):

print_to_csl(prog, output)

self.available_targets["arm-asm"] = _output_arm_asm
self.available_targets["mlir"] = _output_mlir
self.available_targets["riscv-asm"] = _output_riscv_asm
self.available_targets["x86-asm"] = _output_x86_asm
Expand Down
Loading