Skip to content

Commit d67b37d

Browse files
knickish and superlopuh authored
Arith to riscv lowering fastmath cmpf (#3277)
`backend` (lowering)-only part of #3272. Depends on #3275 and #3276. This should close #2725 once merged.
Co-authored-by: Sasha Lopoukhine <superlopuh@gmail.com>
1 parent 163fde2 commit d67b37d

File tree

2 files changed

+65
-18
lines changed

2 files changed

+65
-18
lines changed

tests/filecheck/backend/riscv/convert_arith_to_riscv.mlir

+45
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,51 @@ builtin.module {
219219
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf14_2, 1 : (!riscv.reg) -> !riscv.reg
220220
%cmpf15 = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 15 : i32} : (f32, f32) -> i1
221221
// CHECK-NEXT: %{{.*}} = riscv.li 1 : !riscv.reg
222+
223+
// tests with fastmath flags when set to "fast"
224+
%cmpf1_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 1 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
225+
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
226+
%cmpf2_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 2 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
227+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
228+
%cmpf3_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 3 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
229+
// CHECK-NEXT: %{{.*}} = riscv.fle.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
230+
%cmpf4_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 4 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
231+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
232+
%cmpf5_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 5 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
233+
// CHECK-NEXT: %{{.*}} = riscv.fle.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
234+
%cmpf6_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 6 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
235+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
236+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
237+
// CHECK-NEXT: %{{.*}} = riscv.or %cmpf6_fm_1, %cmpf6_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
238+
%cmpf7_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 7 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
239+
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
240+
// CHECK-NEXT: %{{.*}} = riscv.feq.s %rhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
241+
// CHECK-NEXT: %{{.*}} = riscv.and %cmpf7_fm_1, %cmpf7_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
242+
%cmpf8_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 8 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
243+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
244+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
245+
// CHECK-NEXT: %{{.*}} = riscv.or %cmpf8_fm_1, %cmpf8_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
246+
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf8_fm_2, 1 : (!riscv.reg) -> !riscv.reg
247+
%cmpf9_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 9 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
248+
// CHECK-NEXT: %{{.*}} = riscv.fle.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
249+
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf9_fm, 1 : (!riscv.reg) -> !riscv.reg
250+
%cmpf10_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 10 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
251+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
252+
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf10_fm, 1 : (!riscv.reg) -> !riscv.reg
253+
%cmpf11_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 11 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
254+
// CHECK-NEXT: %{{.*}} = riscv.fle.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
255+
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf11_fm, 1 : (!riscv.reg) -> !riscv.reg
256+
%cmpf12_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 12 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
257+
// CHECK-NEXT: %{{.*}} = riscv.flt.s %rhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
258+
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf12_fm, 1 : (!riscv.reg) -> !riscv.reg
259+
%cmpf13_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 13 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
260+
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
261+
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf13_fm, 1 : (!riscv.reg) -> !riscv.reg
262+
%cmpf14_fm = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 14 : i32, "fastmath" = #arith.fastmath<fast>} : (f32, f32) -> i1
263+
// CHECK-NEXT: %{{.*}} = riscv.feq.s %lhsf32_1, %lhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
264+
// CHECK-NEXT: %{{.*}} = riscv.feq.s %rhsf32_1, %rhsf32_1 fastmath<fast> : (!riscv.freg, !riscv.freg) -> !riscv.reg
265+
// CHECK-NEXT: %{{.*}} = riscv.and %cmpf14_fm_1, %cmpf14_fm : (!riscv.reg, !riscv.reg) -> !riscv.reg
266+
// CHECK-NEXT: %{{.*}} = riscv.xori %cmpf14_fm_2, 1 : (!riscv.reg) -> !riscv.reg
222267
%index_cast = "arith.index_cast"(%lhsindex) : (index) -> i32
223268
// CHECK-NEXT: }
224269
}

xdsl/backend/riscv/lowering/convert_arith_to_riscv.py

+20-18
Original file line numberDiff line numberDiff line change
@@ -352,29 +352,31 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None:
352352
lhs, rhs = cast_operands_to_regs(rewriter)
353353
cast_matched_op_results(rewriter)
354354

355+
fastmath = riscv.FastMathFlagsAttr(op.fastmath.data)
356+
355357
match op.predicate.value.data:
356358
# false
357359
case 0:
358360
rewriter.replace_matched_op([riscv.LiOp(0)])
359361
# oeq
360362
case 1:
361-
rewriter.replace_matched_op([riscv.FeqSOP(lhs, rhs)])
363+
rewriter.replace_matched_op([riscv.FeqSOP(lhs, rhs, fastmath=fastmath)])
362364
# ogt
363365
case 2:
364-
rewriter.replace_matched_op([riscv.FltSOP(rhs, lhs)])
366+
rewriter.replace_matched_op([riscv.FltSOP(rhs, lhs, fastmath=fastmath)])
365367
# oge
366368
case 3:
367-
rewriter.replace_matched_op([riscv.FleSOP(rhs, lhs)])
369+
rewriter.replace_matched_op([riscv.FleSOP(rhs, lhs, fastmath=fastmath)])
368370
# olt
369371
case 4:
370-
rewriter.replace_matched_op([riscv.FltSOP(lhs, rhs)])
372+
rewriter.replace_matched_op([riscv.FltSOP(lhs, rhs, fastmath=fastmath)])
371373
# ole
372374
case 5:
373-
rewriter.replace_matched_op([riscv.FleSOP(lhs, rhs)])
375+
rewriter.replace_matched_op([riscv.FleSOP(lhs, rhs, fastmath=fastmath)])
374376
# one
375377
case 6:
376-
flt1 = riscv.FltSOP(lhs, rhs)
377-
flt2 = riscv.FltSOP(rhs, lhs)
378+
flt1 = riscv.FltSOP(lhs, rhs, fastmath=fastmath)
379+
flt2 = riscv.FltSOP(rhs, lhs, fastmath=fastmath)
378380
rewriter.replace_matched_op(
379381
[
380382
flt1,
@@ -384,8 +386,8 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None:
384386
)
385387
# ord
386388
case 7:
387-
feq1 = riscv.FeqSOP(lhs, lhs)
388-
feq2 = riscv.FeqSOP(rhs, rhs)
389+
feq1 = riscv.FeqSOP(lhs, lhs, fastmath=fastmath)
390+
feq2 = riscv.FeqSOP(rhs, rhs, fastmath=fastmath)
389391
rewriter.replace_matched_op(
390392
[
391393
feq1,
@@ -395,34 +397,34 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None:
395397
)
396398
# ueq
397399
case 8:
398-
flt1 = riscv.FltSOP(lhs, rhs)
399-
flt2 = riscv.FltSOP(rhs, lhs)
400+
flt1 = riscv.FltSOP(lhs, rhs, fastmath=fastmath)
401+
flt2 = riscv.FltSOP(rhs, lhs, fastmath=fastmath)
400402
or_ = riscv.OrOp(flt2, flt1, rd=riscv.IntRegisterType.unallocated())
401403
rewriter.replace_matched_op([flt1, flt2, or_, riscv.XoriOp(or_, 1)])
402404
# ugt
403405
case 9:
404-
fle = riscv.FleSOP(lhs, rhs)
406+
fle = riscv.FleSOP(lhs, rhs, fastmath=fastmath)
405407
rewriter.replace_matched_op([fle, riscv.XoriOp(fle, 1)])
406408
# uge
407409
case 10:
408-
fle = riscv.FltSOP(lhs, rhs)
410+
fle = riscv.FltSOP(lhs, rhs, fastmath=fastmath)
409411
rewriter.replace_matched_op([fle, riscv.XoriOp(fle, 1)])
410412
# ult
411413
case 11:
412-
fle = riscv.FleSOP(rhs, lhs)
414+
fle = riscv.FleSOP(rhs, lhs, fastmath=fastmath)
413415
rewriter.replace_matched_op([fle, riscv.XoriOp(fle, 1)])
414416
# ule
415417
case 12:
416-
flt = riscv.FltSOP(rhs, lhs)
418+
flt = riscv.FltSOP(rhs, lhs, fastmath=fastmath)
417419
rewriter.replace_matched_op([flt, riscv.XoriOp(flt, 1)])
418420
# une
419421
case 13:
420-
feq = riscv.FeqSOP(lhs, rhs)
422+
feq = riscv.FeqSOP(lhs, rhs, fastmath=fastmath)
421423
rewriter.replace_matched_op([feq, riscv.XoriOp(feq, 1)])
422424
# uno
423425
case 14:
424-
feq1 = riscv.FeqSOP(lhs, lhs)
425-
feq2 = riscv.FeqSOP(rhs, rhs)
426+
feq1 = riscv.FeqSOP(lhs, lhs, fastmath=fastmath)
427+
feq2 = riscv.FeqSOP(rhs, rhs, fastmath=fastmath)
426428
and_ = riscv.AndOp(feq2, feq1, rd=riscv.IntRegisterType.unallocated())
427429
rewriter.replace_matched_op([feq1, feq2, and_, riscv.XoriOp(and_, 1)])
428430
# true

0 commit comments

Comments
 (0)