Skip to content

Commit

Permalink
transformations: (riscv) handle f64 lowering in int to float conversi…
Browse files Browse the repository at this point in the history
…on (#2561)
  • Loading branch information
superlopuh authored May 13, 2024
1 parent ad73bdb commit cdbe2da
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 5 deletions.
4 changes: 2 additions & 2 deletions tests/filecheck/backend/csl/print_csl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func.func @initialize() {
%y = memref.get_global @y : memref<4xf32>

scf.for %idx = %lb to %ub step %step {
%idx_f32 = "arith.sitofp"(%idx) : (i16) -> f32
%idx_f32 = arith.sitofp %idx : i16 to f32
%idx_index = "arith.index_cast"(%idx) : (i16) -> index
memref.store %idx_f32, %A[%idx_index] : memref<24xf32>
}
Expand Down Expand Up @@ -64,7 +64,7 @@ func.func @initialize() {
// CHECK-NEXT: //unknown op GetGlobal(%b = memref.get_global @b : memref<4xf32>)
// CHECK-NEXT: //unknown op GetGlobal(%y = memref.get_global @y : memref<4xf32>)
// CHECK-NEXT: for(@range(i16, lb, ub, step)) |idx| {
// CHECK-NEXT: //unknown op SIToFPOp(%idx_f32 = "arith.sitofp"(%idx) : (i16) -> f32)
// CHECK-NEXT: //unknown op SIToFPOp(%idx_f32 = arith.sitofp %idx : i16 to f32
// CHECK-NEXT: //unknown op IndexCastOp(%idx_index = "arith.index_cast"(%idx) : (i16) -> index)
// CHECK-NEXT: //unknown op Store(memref.store %idx_f32, %A[%idx_index] : memref<24xf32>)
// CHECK-NEXT: //unknown op Yield(scf.yield)
Expand Down
6 changes: 4 additions & 2 deletions tests/filecheck/backend/riscv/convert_arith_to_riscv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,12 @@ builtin.module {
%maxf64_fm_contract = "arith.maximumf"(%lhsf64, %rhsf64) {"fastmath" = #arith.fastmath<contract>} : (f64, f64) -> f64
// CHECK-NEXT: %{{.*}} = riscv.fmax.d %lhsf64_reg, %rhsf64_reg fastmath<contract> : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<>

%sitofp = "arith.sitofp"(%lhsi32) : (i32) -> f32
%sitofp32 = arith.sitofp %lhsi32 : i32 to f32
// CHECK-NEXT: %{{.*}} = riscv.fcvt.s.w %lhsi32 : (!riscv.reg<>) -> !riscv.freg<>
%fptosi = "arith.fptosi"(%lhsf32) : (f32) -> i32
%fp32tosi = arith.fptosi %lhsf32 : f32 to i32
// CHECK-NEXT: %{{.*}} = riscv.fcvt.w.s %lhsf32_1 : (!riscv.freg<>) -> !riscv.reg<>
%sitofp64 = arith.sitofp %lhsi32 : i32 to f64
// CHECK-NEXT: %{{.*}} = riscv.fcvt.d.w %lhsi32 : (!riscv.reg<>) -> !riscv.freg<>

%cmpf0 = "arith.cmpf"(%lhsf32, %rhsf32) {"predicate" = 0 : i32} : (f32, f32) -> i1
// CHECK-NEXT: %{{.*}} = riscv.li 0 : () -> !riscv.reg<>
Expand Down
10 changes: 9 additions & 1 deletion xdsl/backend/riscv/lowering/convert_arith_to_riscv.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,12 +411,20 @@ def match_and_rewrite(self, op: arith.Cmpf, rewriter: PatternRewriter) -> None:
class LowerArithSIToFPOp(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: arith.SIToFPOp, rewriter: PatternRewriter) -> None:
match op.result.type:
case Float32Type():
cls = riscv.FCvtSWOp
case Float64Type():
cls = riscv.FCvtDWOp
case _:
assert False, f"Unexpected float type {op.result.type}"

rewriter.replace_matched_op(
(
cast_input := UnrealizedConversionCastOp.get(
(op.input,), (_INT_REGISTER_TYPE,)
),
new_op := riscv.FCvtSWOp(
new_op := cls(
cast_input.results[0], rd=riscv.FloatRegisterType.unallocated()
),
UnrealizedConversionCastOp.get((new_op.rd,), (op.result.type,)),
Expand Down
4 changes: 4 additions & 0 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,8 @@ class FPToSIOp(IRDLOperation):
input: Operand = operand_def(AnyFloat)
result: OpResult = result_def(IntegerType)

assembly_format = "$input attr-dict `:` type($input) `to` type($result)"

def __init__(self, op: SSAValue | Operation, target_type: IntegerType):
return super().__init__(operands=[op], result_types=[target_type])

Expand All @@ -885,6 +887,8 @@ class SIToFPOp(IRDLOperation):
input: Operand = operand_def(IntegerType)
result: OpResult = result_def(AnyFloat)

assembly_format = "$input attr-dict `:` type($input) `to` type($result)"

def __init__(self, op: SSAValue | Operation, target_type: AnyFloat):
return super().__init__(operands=[op], result_types=[target_type])

Expand Down

0 comments on commit cdbe2da

Please sign in to comment.