Skip to content

Commit 7e647e4

Browse files
committed
dialects (riscv): add fastmath flag RdRsRs Float Float Int operations, with filecheck test
1 parent f2ba5f7 commit 7e647e4

File tree

2 files changed

+75
-4
lines changed

2 files changed

+75
-4
lines changed

tests/filecheck/dialects/riscv/riscv_ops.mlir

+9
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,12 @@
273273
// CHECK-NEXT: %{{.*}} = riscv.flt.s %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.reg
274274
%fle_s = riscv.fle.s %f0, %f1 : (!riscv.freg, !riscv.freg) -> !riscv.reg
275275
// CHECK-NEXT: %{{.*}} = riscv.fle.s %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.reg
276+
%feq_s_fm = riscv.feq.s %f0, %f1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
277+
// CHECK-NEXT: %{{.*}} = riscv.feq.s %{{.*}}, %{{.*}} fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
278+
%flt_s_fm = riscv.flt.s %f0, %f1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
279+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %{{.*}}, %{{.*}} fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
280+
%fle_s_fm = riscv.fle.s %f0, %f1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
281+
// CHECK-NEXT: %{{.*}} = riscv.fle.s %{{.*}}, %{{.*}} fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
276282
%fclass_s = riscv.fclass.s %f0 : (!riscv.freg) -> !riscv.reg
277283
// CHECK-NEXT: %{{.*}} = riscv.fclass.s %{{.*}} : (!riscv.freg) -> !riscv.reg
278284
%fcvt_s_w = riscv.fcvt.s.w %0 : (!riscv.reg) -> !riscv.freg
@@ -460,6 +466,9 @@
460466
// CHECK-GENERIC-NEXT: %feq_s = "riscv.feq.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg
461467
// CHECK-GENERIC-NEXT: %flt_s = "riscv.flt.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg
462468
// CHECK-GENERIC-NEXT: %fle_s = "riscv.fle.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg
469+
// CHECK-GENERIC-NEXT: %feq_s_fm = "riscv.feq.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<fast>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
470+
// CHECK-GENERIC-NEXT: %flt_s_fm = "riscv.flt.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<fast>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
471+
// CHECK-GENERIC-NEXT: %fle_s_fm = "riscv.fle.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<fast>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
463472
// CHECK-GENERIC-NEXT: %fclass_s = "riscv.fclass.s"(%f0) : (!riscv.freg) -> !riscv.reg
464473
// CHECK-GENERIC-NEXT: %fcvt_s_w = "riscv.fcvt.s.w"(%0) : (!riscv.reg) -> !riscv.freg
465474
// CHECK-GENERIC-NEXT: %fcvt_s_wu = "riscv.fcvt.s.wu"(%0) : (!riscv.reg) -> !riscv.freg

xdsl/dialects/riscv.py

+66-4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
irdl_op_definition,
4444
operand_def,
4545
opt_attr_def,
46+
opt_prop_def,
4647
region_def,
4748
result_def,
4849
var_operand_def,
@@ -2532,7 +2533,7 @@ def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
25322533
def custom_print_attributes(self, printer: Printer) -> Set[str]:
25332534
printer.print(" ")
25342535
print_immediate_value(printer, self.immediate)
2535-
return {"immediate"}
2536+
return {"immediate", "fastmath"}
25362537

25372538
@classmethod
25382539
def parse_op_type(
@@ -3016,6 +3017,67 @@ def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
30163017
return self.rd, self.rs1, self.rs2
30173018

30183019

3020+
class RdRsRsFloatFloatIntegerOperationWithFastMath(RISCVInstruction, ABC):
3021+
"""
3022+
A base class for RISC-V operations that have two source floating-point
3023+
registers with an integer destination register, and can be annotated with fastmath flags.
3024+
3025+
This is called R-Type in the RISC-V specification.
3026+
"""
3027+
3028+
rd = result_def(IntRegisterType)
3029+
rs1 = operand_def(FloatRegisterType)
3030+
rs2 = operand_def(FloatRegisterType)
3031+
fastmath = opt_prop_def(FastMathFlagsAttr)
3032+
3033+
def __init__(
3034+
self,
3035+
rs1: Operation | SSAValue,
3036+
rs2: Operation | SSAValue,
3037+
*,
3038+
rd: IntRegisterType | str | None = None,
3039+
fastmath: FastMathFlagsAttr = FastMathFlagsAttr("none"),
3040+
comment: str | StringAttr | None = None,
3041+
):
3042+
if rd is None:
3043+
rd = IntRegisterType.unallocated()
3044+
elif isinstance(rd, str):
3045+
rd = IntRegisterType(rd)
3046+
if isinstance(comment, str):
3047+
comment = StringAttr(comment)
3048+
3049+
super().__init__(
3050+
operands=[rs1, rs2],
3051+
properties={
3052+
"fastmath": fastmath,
3053+
},
3054+
attributes={
3055+
"comment": comment,
3056+
},
3057+
result_types=[rd],
3058+
)
3059+
3060+
def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
3061+
return self.rd, self.rs1, self.rs2
3062+
3063+
@classmethod
3064+
def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
3065+
attributes = dict[str, Attribute]()
3066+
fast = FastMathFlagsAttr("none")
3067+
if parser.parse_optional_keyword("fastmath") is not None:
3068+
fast = FastMathFlagsAttr(FastMathFlagsAttr.parse_parameter(parser))
3069+
if fast != FastMathFlagsAttr("none"):
3070+
attributes["fastmath"] = fast
3071+
cls.fastmath = fast
3072+
return attributes
3073+
3074+
def custom_print_attributes(self, printer: Printer) -> Set[str]:
3075+
if self.fastmath is not None and self.fastmath != FastMathFlagsAttr("none"):
3076+
printer.print(" fastmath")
3077+
self.fastmath.print_parameter(printer)
3078+
return {"fastmath"}
3079+
3080+
30193081
class RsRsImmFloatOperation(RISCVInstruction, ABC):
30203082
"""
30213083
A base class for RV32F operations that have two source registers
@@ -3352,7 +3414,7 @@ class FMvXWOp(RdRsOperation[IntRegisterType, FloatRegisterType]):
33523414

33533415

33543416
@irdl_op_definition
3355-
class FeqSOP(RdRsRsFloatFloatIntegerOperation):
3417+
class FeqSOP(RdRsRsFloatFloatIntegerOperationWithFastMath):
33563418
"""
33573419
Performs a quiet equal comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd.
33583420
Only signaling NaN inputs cause an Invalid Operation exception.
@@ -3367,7 +3429,7 @@ class FeqSOP(RdRsRsFloatFloatIntegerOperation):
33673429

33683430

33693431
@irdl_op_definition
3370-
class FltSOP(RdRsRsFloatFloatIntegerOperation):
3432+
class FltSOP(RdRsRsFloatFloatIntegerOperationWithFastMath):
33713433
"""
33723434
Performs a quiet less comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd.
33733435
Only signaling NaN inputs cause an Invalid Operation exception.
@@ -3382,7 +3444,7 @@ class FltSOP(RdRsRsFloatFloatIntegerOperation):
33823444

33833445

33843446
@irdl_op_definition
3385-
class FleSOP(RdRsRsFloatFloatIntegerOperation):
3447+
class FleSOP(RdRsRsFloatFloatIntegerOperationWithFastMath):
33863448
"""
33873449
Performs a quiet less or equal comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd.
33883450
Only signaling NaN inputs cause an Invalid Operation exception.

0 commit comments

Comments
 (0)