Skip to content

Commit 1af6acd

Browse files
committed
dialects (riscv): add fastmath flag to Li and RdRsRs Float Float Int operations, with filecheck test
1 parent b205332 commit 1af6acd

File tree

2 files changed

+88
-4
lines changed

2 files changed

+88
-4
lines changed

tests/filecheck/dialects/riscv/riscv_ops.mlir

+12
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@
175175

176176
%li = riscv.li 1 : !riscv.reg
177177
// CHECK-NEXT: %{{.*}} = riscv.li 1 : !riscv.reg
178+
%li_fm = riscv.li 1 fastmath<fast> : !riscv.reg
179+
// CHECK-NEXT: %{{.*}} = riscv.li 1 fastmath<fast> : !riscv.reg
178180
// Environment Call and Breakpoints
179181
riscv.ecall
180182
// CHECK-NEXT: riscv.ecall
@@ -273,6 +275,12 @@
273275
// CHECK-NEXT: %{{.*}} = riscv.flt.s %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.reg
274276
%fle_s = riscv.fle.s %f0, %f1 : (!riscv.freg, !riscv.freg) -> !riscv.reg
275277
// CHECK-NEXT: %{{.*}} = riscv.fle.s %{{.*}}, %{{.*}} : (!riscv.freg, !riscv.freg) -> !riscv.reg
278+
%feq_s_fm = riscv.feq.s %f0, %f1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
279+
// CHECK-NEXT: %{{.*}} = riscv.feq.s %{{.*}}, %{{.*}} fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
280+
%flt_s_fm = riscv.flt.s %f0, %f1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
281+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %{{.*}}, %{{.*}} fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
282+
%fle_s_fm = riscv.fle.s %f0, %f1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
283+
// CHECK-NEXT: %{{.*}} = riscv.fle.s %{{.*}}, %{{.*}} fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
276284
%fclass_s = riscv.fclass.s %f0 : (!riscv.freg) -> !riscv.reg
277285
// CHECK-NEXT: %{{.*}} = riscv.fclass.s %{{.*}} : (!riscv.freg) -> !riscv.reg
278286
%fcvt_s_w = riscv.fcvt.s.w %0 : (!riscv.reg) -> !riscv.freg
@@ -419,6 +427,7 @@
419427
// CHECK-GENERIC-NEXT: %rem = "riscv.rem"(%0, %1) : (!riscv.reg, !riscv.reg) -> !riscv.reg
420428
// CHECK-GENERIC-NEXT: %remu = "riscv.remu"(%0, %1) : (!riscv.reg, !riscv.reg) -> !riscv.reg
421429
// CHECK-GENERIC-NEXT: %li = "riscv.li"() {"immediate" = 1 : i32} : () -> !riscv.reg
430+
// CHECK-GENERIC-NEXT: %li_fm = "riscv.li"() {"immediate" = 1 : i32, "fastmath" = #riscv.fastmath<fast>} : () -> !riscv.reg
422431
// CHECK-GENERIC-NEXT: "riscv.ecall"() : () -> ()
423432
// CHECK-GENERIC-NEXT: "riscv.ebreak"() : () -> ()
424433
// CHECK-GENERIC-NEXT: "riscv.directive"() {"directive" = ".bss"} : () -> ()
@@ -460,6 +469,9 @@
460469
// CHECK-GENERIC-NEXT: %feq_s = "riscv.feq.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg
461470
// CHECK-GENERIC-NEXT: %flt_s = "riscv.flt.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg
462471
// CHECK-GENERIC-NEXT: %fle_s = "riscv.fle.s"(%f0, %f1) : (!riscv.freg, !riscv.freg) -> !riscv.reg
472+
// CHECK-GENERIC-NEXT: %feq_s_fm = "riscv.feq.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<fast>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
473+
// CHECK-GENERIC-NEXT: %flt_s_fm = "riscv.flt.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<fast>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
474+
// CHECK-GENERIC-NEXT: %fle_s_fm = "riscv.fle.s"(%f0, %f1) {"fastmath" = #riscv.fastmath<fast>} : (!riscv.freg, !riscv.freg) -> !riscv.reg
463475
// CHECK-GENERIC-NEXT: %fclass_s = "riscv.fclass.s"(%f0) : (!riscv.freg) -> !riscv.reg
464476
// CHECK-GENERIC-NEXT: %fcvt_s_w = "riscv.fcvt.s.w"(%0) : (!riscv.reg) -> !riscv.freg
465477
// CHECK-GENERIC-NEXT: %fcvt_s_wu = "riscv.fcvt.s.wu"(%0) : (!riscv.reg) -> !riscv.freg

xdsl/dialects/riscv.py

+76-4
Original file line numberDiff line numberDiff line change
@@ -2491,6 +2491,7 @@ class LiOp(RISCVInstruction, ABC):
24912491

24922492
rd = result_def(IntRegisterType)
24932493
immediate = attr_def(base(Imm32Attr) | base(LabelAttr))
2494+
fastmath = opt_attr_def(FastMathFlagsAttr)
24942495

24952496
traits = frozenset((Pure(), ConstantLike(), LiOpHasCanonicalizationPatternTrait()))
24962497

@@ -2499,6 +2500,7 @@ def __init__(
24992500
immediate: int | Imm32Attr | str | LabelAttr,
25002501
*,
25012502
rd: IntRegisterType | str | None = None,
2503+
fastmath: FastMathFlagsAttr | None = None,
25022504
comment: str | StringAttr | None = None,
25032505
):
25042506
if isinstance(immediate, int):
@@ -2517,6 +2519,7 @@ def __init__(
25172519
attributes={
25182520
"immediate": immediate,
25192521
"comment": comment,
2522+
"fastmath": fastmath,
25202523
},
25212524
)
25222525

@@ -2527,12 +2530,22 @@ def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
25272530
def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
25282531
attributes = dict[str, Attribute]()
25292532
attributes["immediate"] = parse_immediate_value(parser, i32)
2533+
fast = FastMathFlagsAttr("none")
2534+
if parser.parse_optional_keyword("fastmath") is not None:
2535+
fast = FastMathFlagsAttr(FastMathFlagsAttr.parse_parameter(parser))
2536+
attributes["fastmath"] = fast
25302537
return attributes
25312538

25322539
def custom_print_attributes(self, printer: Printer) -> Set[str]:
25332540
printer.print(" ")
25342541
print_immediate_value(printer, self.immediate)
2535-
return {"immediate"}
2542+
2543+
if "fastmath" in self.attributes and self.attributes[
2544+
"fastmath"
2545+
] != FastMathFlagsAttr("none"):
2546+
printer.print(" fastmath")
2547+
self.fastmath.print_parameter(printer)
2548+
return {"immediate", "fastmath"}
25362549

25372550
@classmethod
25382551
def parse_op_type(
@@ -3016,6 +3029,65 @@ def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
30163029
return self.rd, self.rs1, self.rs2
30173030

30183031

3032+
class RdRsRsFloatFloatIntegerOperationWithFastMath(RISCVInstruction, ABC):
3033+
"""
3034+
A base class for RISC-V operations that have two source floating-point
3035+
registers with an integer destination register, and can be annotated with fastmath flags.
3036+
3037+
This is called R-Type in the RISC-V specification.
3038+
"""
3039+
3040+
rd = result_def(IntRegisterType)
3041+
rs1 = operand_def(FloatRegisterType)
3042+
rs2 = operand_def(FloatRegisterType)
3043+
fastmath = opt_attr_def(FastMathFlagsAttr)
3044+
3045+
def __init__(
3046+
self,
3047+
rs1: Operation | SSAValue,
3048+
rs2: Operation | SSAValue,
3049+
*,
3050+
rd: IntRegisterType | str | None = None,
3051+
fastmath: FastMathFlagsAttr | None = None,
3052+
comment: str | StringAttr | None = None,
3053+
):
3054+
if rd is None:
3055+
rd = IntRegisterType.unallocated()
3056+
elif isinstance(rd, str):
3057+
rd = IntRegisterType(rd)
3058+
if isinstance(comment, str):
3059+
comment = StringAttr(comment)
3060+
3061+
super().__init__(
3062+
operands=[rs1, rs2],
3063+
attributes={
3064+
"fastmath": fastmath,
3065+
"comment": comment,
3066+
},
3067+
result_types=[rd],
3068+
)
3069+
3070+
def assembly_line_args(self) -> tuple[AssemblyInstructionArg, ...]:
3071+
return self.rd, self.rs1, self.rs2
3072+
3073+
@classmethod
3074+
def custom_parse_attributes(cls, parser: Parser) -> dict[str, Attribute]:
3075+
attributes = dict[str, Attribute]()
3076+
flags = FastMathFlagsAttr("none")
3077+
if parser.parse_optional_keyword("fastmath") is not None:
3078+
flags = FastMathFlagsAttr(FastMathFlagsAttr.parse_parameter(parser))
3079+
if flags != FastMathFlagsAttr("none"):
3080+
attributes["fastmath"] = flags
3081+
cls.fastmath = flags
3082+
return attributes
3083+
3084+
def custom_print_attributes(self, printer: Printer) -> Set[str]:
3085+
if self.fastmath is not None and self.fastmath != FastMathFlagsAttr("none"):
3086+
printer.print(" fastmath")
3087+
self.fastmath.print_parameter(printer)
3088+
return {"fastmath"}
3089+
3090+
30193091
class RsRsImmFloatOperation(RISCVInstruction, ABC):
30203092
"""
30213093
A base class for RV32F operations that have two source registers
@@ -3352,7 +3424,7 @@ class FMvXWOp(RdRsOperation[IntRegisterType, FloatRegisterType]):
33523424

33533425

33543426
@irdl_op_definition
3355-
class FeqSOP(RdRsRsFloatFloatIntegerOperation):
3427+
class FeqSOP(RdRsRsFloatFloatIntegerOperationWithFastMath):
33563428
"""
33573429
Performs a quiet equal comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd.
33583430
Only signaling NaN inputs cause an Invalid Operation exception.
@@ -3367,7 +3439,7 @@ class FeqSOP(RdRsRsFloatFloatIntegerOperation):
33673439

33683440

33693441
@irdl_op_definition
3370-
class FltSOP(RdRsRsFloatFloatIntegerOperation):
3442+
class FltSOP(RdRsRsFloatFloatIntegerOperationWithFastMath):
33713443
"""
33723444
Performs a quiet less comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd.
33733445
Only signaling NaN inputs cause an Invalid Operation exception.
@@ -3382,7 +3454,7 @@ class FltSOP(RdRsRsFloatFloatIntegerOperation):
33823454

33833455

33843456
@irdl_op_definition
3385-
class FleSOP(RdRsRsFloatFloatIntegerOperation):
3457+
class FleSOP(RdRsRsFloatFloatIntegerOperationWithFastMath):
33863458
"""
33873459
Performs a quiet less or equal comparison between floating-point registers rs1 and rs2 and record the Boolean result in integer register rd.
33883460
Only signaling NaN inputs cause an Invalid Operation exception.

0 commit comments

Comments
 (0)