Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

misc: fix arguments for riscv_lowering #3900

Merged
merged 10 commits into from
Feb 15, 2025
40 changes: 20 additions & 20 deletions tests/filecheck/backend/riscv/convert_func_to_riscv_func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

builtin.module {
func.func @main() {
%0, %1 = "test.op"() : () -> (i32, i32)
%2, %3 = func.call @foo(%0, %1) : (i32, i32) -> (i32, i32)
%0, %1 = "test.op"() : () -> (i32, f32)
%2, %3 = func.call @foo(%0, %1) : (i32, f32) -> (i64, f64)
func.return
}

func.func @foo(%arg0 : i32, %arg1 : i32) -> (i32, i32) {
%res0, %res1 = "test.op"(%arg0, %arg1) : (i32, i32) -> (i32, i32)
func.return %res0, %res1 : i32, i32
func.func @foo(%arg0 : i32, %arg1 : f32) -> (i64, f64) {
%res0, %res1 = "test.op"(%arg0, %arg1) : (i32, f32) -> (i64, f64)
func.return %res0, %res1 : i64, f64
}

func.func @foo_float(%farg0 : f32, %farg1 : f32) -> (f32, f32) {
Expand All @@ -35,33 +35,33 @@ builtin.module {
// CHECK-NEXT: riscv.directive ".globl" "main"
// CHECK-NEXT: riscv.directive ".p2align" "2"
// CHECK-NEXT: riscv_func.func @main() {
// CHECK-NEXT: %0, %1 = "test.op"() : () -> (i32, i32)
// CHECK-NEXT: %0, %1 = "test.op"() : () -> (i32, f32)
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %0 : i32 to !riscv.reg
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %1 : i32 to !riscv.reg
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %1 : f32 to !riscv.freg
// CHECK-NEXT: %{{.*}} = riscv.mv %{{.*}} : (!riscv.reg) -> !riscv.reg<a0>
// CHECK-NEXT: %{{.*}} = riscv.mv %{{.*}} : (!riscv.reg) -> !riscv.reg<a1>
// CHECK-NEXT: %{{.*}}, %{{.*}} = riscv_func.call @foo(%{{.*}}, %{{.*}}) : (!riscv.reg<a0>, !riscv.reg<a1>) -> (!riscv.reg<a0>, !riscv.reg<a1>)
// CHECK-NEXT: %{{.*}} = riscv.fmv.s %{{.*}} : (!riscv.freg) -> !riscv.freg<fa0>
// CHECK-NEXT: %{{.*}}, %{{.*}} = riscv_func.call @foo(%{{.*}}, %{{.*}}) : (!riscv.reg<a0>, !riscv.freg<fa0>) -> (!riscv.reg<a0>, !riscv.freg<fa0>)
// CHECK-NEXT: %{{.*}} = riscv.mv %{{.*}} : (!riscv.reg<a0>) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.mv %{{.*}} : (!riscv.reg<a1>) -> !riscv.reg
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !riscv.reg to i32
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !riscv.reg to i32
// CHECK-NEXT: %{{.*}} = riscv.fmv.d %{{.*}} : (!riscv.freg<fa0>) -> !riscv.freg
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !riscv.reg to i64
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %{{.*}} : !riscv.freg to f64
// CHECK-NEXT: riscv_func.return
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: riscv.assembly_section ".text" {
// CHECK-NEXT: riscv.directive ".globl" "foo"
// CHECK-NEXT: riscv.directive ".p2align" "2"
// CHECK-NEXT: riscv_func.func @foo(%arg0 : !riscv.reg<a0>, %arg1 : !riscv.reg<a1>) -> (!riscv.reg<a0>, !riscv.reg<a1>) {
// CHECK-NEXT: riscv_func.func @foo(%arg0 : !riscv.reg<a0>, %arg1 : !riscv.freg<fa0>) -> (!riscv.reg<a0>, !riscv.freg<fa0>) {
// CHECK-NEXT: %{{.*}} = riscv.mv %arg0 : (!riscv.reg<a0>) -> !riscv.reg
// CHECK-NEXT: %arg0_1 = builtin.unrealized_conversion_cast %{{.*}} : !riscv.reg to i32
// CHECK-NEXT: %{{.*}} = riscv.mv %arg1 : (!riscv.reg<a1>) -> !riscv.reg
// CHECK-NEXT: %arg1_1 = builtin.unrealized_conversion_cast %{{.*}} : !riscv.reg to i32
// CHECK-NEXT: %res0, %res1 = "test.op"(%arg0_1, %arg1_1) : (i32, i32) -> (i32, i32)
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %res0 : i32 to !riscv.reg
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %res1 : i32 to !riscv.reg
// CHECK-NEXT: %{{.*}} = riscv.fmv.s %arg1 : (!riscv.freg<fa0>) -> !riscv.freg
// CHECK-NEXT: %arg1_1 = builtin.unrealized_conversion_cast %{{.*}} : !riscv.freg to f32
// CHECK-NEXT: %res0, %res1 = "test.op"(%arg0_1, %arg1_1) : (i32, f32) -> (i64, f64)
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %res0 : i64 to !riscv.reg
// CHECK-NEXT: %{{.*}} = builtin.unrealized_conversion_cast %res1 : f64 to !riscv.freg
// CHECK-NEXT: %{{.*}} = riscv.mv %{{.*}} : (!riscv.reg) -> !riscv.reg<a0>
// CHECK-NEXT: %{{.*}} = riscv.mv %{{.*}} : (!riscv.reg) -> !riscv.reg<a1>
// CHECK-NEXT: riscv_func.return %{{.*}}, %{{.*}} : !riscv.reg<a0>, !riscv.reg<a1>
// CHECK-NEXT: %{{.*}} = riscv.fmv.d %{{.*}} : (!riscv.freg) -> !riscv.freg<fa0>
// CHECK-NEXT: riscv_func.return %{{.*}}, %{{.*}} : !riscv.reg<a0>, !riscv.freg<fa0>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: riscv.assembly_section ".text" {
Expand Down
5 changes: 4 additions & 1 deletion xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,13 @@ def match_and_rewrite(self, op: func.CallOp, rewriter: PatternRewriter) -> None:
move_operand_ops, moved_operands = move_to_a_regs(
register_operands, operand_types
)

new_result_value_types = [result.type for result in op.results]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes a bug here that makes it raise an error for floating point registers:

elif isinstance(rd, riscv.FloatRegisterType):
match value_type:
case builtin.Float64Type():
mv_op = riscv.FMvDOp(value, rd=rd)
case builtin.Float32Type():
mv_op = riscv.FMVOp(value, rd=rd)
case _:
raise NotImplementedError(
f"Move operation for float register containing value of type {value.type} is not implemented"
)
return mv_op, mv_op.rd

value_type in the original code is neither of the two.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a FloatRegisterType.

new_result_types = list(a_regs(op.results))
new_op = riscv_func.CallOp(op.callee, moved_operands, new_result_types)

move_result_ops, moved_results = move_to_unallocated_regs(
new_op.results, operand_types
new_op.results, new_result_value_types
)
cast_result_ops = [
UnrealizedConversionCastOp.get((moved_result,), (old_result.type,))
Expand Down
Loading