Skip to content

Commit

Permalink
testing: no more inout parameters in bottom-up test
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed May 12, 2024
1 parent 2cecbd2 commit 32b2396
Showing 1 changed file with 106 additions and 76 deletions.
182 changes: 106 additions & 76 deletions tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
// RUN: xdsl-opt -p convert-arith-to-riscv,convert-func-to-riscv-func,convert-memref-stream-to-snitch,reconcile-unrealized-casts,test-lower-snitch-stream-to-asm -t riscv-asm %s | filecheck %s

// x[ M x K ]
// y[ K x N ]
// g[ M x N ]
func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
%X: memref<1x1x8x8xf64>,
%Y: memref<1x1x3x3xf64>,
Expand All @@ -21,21 +18,23 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
%c9_val = arith.constant 9 : i32
%c9 = builtin.unrealized_conversion_cast %c9_val : i32 to !riscv.reg<>

%zero_float = arith.sitofp %c0_val : i32 to f64
%zero_reg = builtin.unrealized_conversion_cast %zero_float : f64 to !riscv.freg<>

memref_stream.streaming_region {
patterns = [
#memref_stream.stride_pattern<ub = [1, 1, 6, 6, 1, 3, 3], index_map = (d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>,
#memref_stream.stride_pattern<ub = [1, 1, 6, 6, 1, 3, 3], index_map = (d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
#memref_stream.stride_pattern<ub = [1, 1, 6, 6, 1, 3, 3], index_map = (d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>,
#memref_stream.stride_pattern<ub = [1, 1, 6, 6], index_map = (d0, d1, d2, d3) -> (d0, d1, d2, d3)>
]
} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) {
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.readable<f64>):
} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) {
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.readable<f64>, %z_stream : !stream.writable<f64>):

%c288_val = arith.constant 288 : i32
%c288 = builtin.unrealized_conversion_cast %c288_val : i32 to !riscv.reg<>
riscv_scf.for %z_i : !riscv.reg<> = %c0 to %c288 step %c8 {
%Z_dest = riscv.add %Z_moved, %z_i : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
%c = riscv.fld %Z_dest, 0 : (!riscv.reg<>) -> !riscv.freg<>

%z = riscv_scf.for %i : !riscv.reg<> = %c0 to %c9 step %c1 iter_args(%acc = %c) -> (!riscv.freg<>) {
%init = riscv.fmv.d %zero_reg : (!riscv.freg<>) -> !riscv.freg<>
%z = riscv_scf.for %i : !riscv.reg<> = %c0 to %c9 step %c1 iter_args(%acc = %init) -> (!riscv.freg<>) {
%x = memref_stream.read from %x_stream : f64
%y = memref_stream.read from %y_stream : f64
%acc_val = builtin.unrealized_conversion_cast %acc : !riscv.freg<> to f64
Expand All @@ -45,7 +44,9 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
riscv_scf.yield %res : !riscv.freg<>
}

riscv.fsd %Z_dest, %z, 0 : (!riscv.reg<>, !riscv.freg<>) -> ()
%res = riscv.fmv.d %z : (!riscv.freg<>) -> !riscv.freg<>
%z_val = builtin.unrealized_conversion_cast %res : !riscv.freg<> to f64
memref_stream.write %z_val to %z_stream : f64

riscv_scf.yield
}
Expand All @@ -60,8 +61,9 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
// CHECK-NEXT: .p2align 2
// CHECK-NEXT: conv_2d_nchw_fchw_d1_s1_3x3:
// CHECK-NEXT: mv t4, a0
// CHECK-NEXT: mv t2, a1
// CHECK-NEXT: mv t0, a2
// CHECK-NEXT: mv t3, a1
// CHECK-NEXT: mv t1, a2
// CHECK-NEXT: fcvt.d.w ft3, zero
// CHECK-NEXT: li t5, 8
// CHECK-NEXT: li a5, 2
// CHECK-NEXT: li t6, 2
Expand Down Expand Up @@ -90,21 +92,25 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
// CHECK-NEXT: scfgwi t5, 225
// CHECK-NEXT: li t5, -64
// CHECK-NEXT: scfgwi t5, 257
// CHECK-NEXT: li t5, 8
// CHECK-NEXT: li t6, 35
// CHECK-NEXT: scfgwi t6, 66
// CHECK-NEXT: scfgwi t5, 194
// CHECK-NEXT: scfgwi t4, 864
// CHECK-NEXT: scfgwi t2, 833
// CHECK-NEXT: scfgwi t3, 833
// CHECK-NEXT: scfgwi t1, 898
// CHECK-NEXT: csrrsi zero, 1984, 1
// CHECK-NEXT: li t2, 288
// CHECK-NEXT: mv t1, zero
// CHECK-NEXT: li t1, 288
// CHECK-NEXT: mv t0, zero
// CHECK-NEXT: # Constant folded riscv_cf.bge
// CHECK-NEXT: scf_body_{{\d+}}_for:
// CHECK-NEXT: add t4, t0, t1
// CHECK-NEXT: fld ft3, 0(t4)
// CHECK-NEXT: li t5, 8
// CHECK-NEXT: frep.o t5, 1, 0, 0
// CHECK-NEXT: fmadd.d ft3, ft0, ft1, ft3
// CHECK-NEXT: fsd ft3, 0(t4)
// CHECK-NEXT: addi t1, t1, 8
// CHECK-NEXT: blt t1, t2, scf_body_{{\d+}}_for
// CHECK-NEXT: fmv.d ft4, ft3
// CHECK-NEXT: li t3, 8
// CHECK-NEXT: frep.o t3, 1, 0, 0
// CHECK-NEXT: fmadd.d ft4, ft0, ft1, ft4
// CHECK-NEXT: fmv.d ft2, ft4
// CHECK-NEXT: addi t0, t0, 8
// CHECK-NEXT: blt t0, t1, scf_body_{{\d+}}_for
// CHECK-NEXT: scf_body_end_{{\d+}}_for:
// CHECK-NEXT: csrrci zero, 1984, 1
// CHECK-NEXT: ret
Expand All @@ -124,7 +130,9 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
]
} ins(%X, %Y : memref<128xf64>, memref<128xf64>) {
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.readable<f64>):
%init = riscv.fld %G_moved, 0 : (!riscv.reg<>) -> !riscv.freg<>
%zero_int = arith.constant 0 : i32
%zero_float = arith.sitofp %zero_int : i32 to f64
%init = builtin.unrealized_conversion_cast %zero_float : f64 to !riscv.freg<>

%c0 = riscv.li 0: () -> !riscv.reg<>
%c1 = riscv.li 1: () -> !riscv.reg<>
Expand Down Expand Up @@ -160,7 +168,7 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
// CHECK-NEXT: scfgwi t2, 768
// CHECK-NEXT: scfgwi t1, 769
// CHECK-NEXT: csrrsi zero, 1984, 1
// CHECK-NEXT: fld ft3, 0(t0)
// CHECK-NEXT: fcvt.d.w ft3, zero
// CHECK-NEXT: li t1, 127
// CHECK-NEXT: frep.o t1, 1, 0, 0
// CHECK-NEXT: fmadd.d ft3, ft0, ft1, ft3
Expand Down Expand Up @@ -429,7 +437,7 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
// CHECK-NEXT: csrrsi zero, 1984, 1
// CHECK-NEXT: mv t0, zero
// CHECK-NEXT: # Constant folded riscv_cf.bge
// CHECK-NEXT: scf_body_2_for:
// CHECK-NEXT: scf_body_{{\d+}}_for:
// CHECK-NEXT: fmul.d ft6, ft0, ft1
// CHECK-NEXT: fmul.d ft5, ft0, ft1
// CHECK-NEXT: fmul.d ft4, ft0, ft1
Expand All @@ -445,20 +453,22 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
// CHECK-NEXT: fmadd.d ft2, ft0, ft1, ft4
// CHECK-NEXT: fmadd.d ft2, ft0, ft1, ft3
// CHECK-NEXT: addi t0, t0, 4
// CHECK-NEXT: blt t0, t1, scf_body_2_for
// CHECK-NEXT: scf_body_end_2_for:
// CHECK-NEXT: blt t0, t1, scf_body_{{\d+}}_for
// CHECK-NEXT: scf_body_end_{{\d+}}_for:
// CHECK-NEXT: csrrci zero, 1984, 1
// CHECK-NEXT: ret

// Data section defining `.min_val`: the double -1.0E+4 (IEEE-754 bit pattern
// 0xc0c3880000000000). The max-pooling kernel below loads it with `riscv.fld`
// and uses it to seed the running maximum of every pooling window.
// NOTE(review): this is a finite stand-in for -infinity, so a window whose
// inputs are all below -1.0E+4 would yield -1.0E+4 instead of its true
// maximum; confirm that is acceptable for the data this test is meant to model.
riscv.assembly_section ".data" {
riscv.label ".min_val"
riscv.directive ".quad" "0xc0c3880000000000" {"comment" = "double -1.0E+4"}
}
// X[ 1 x 1 x 16 x 16 ]  input
// Y[ 1 x 1 x 7 x 7 ]    output: maximum over each 3 x 3 window of X, taken with stride 2
func.func public @pooling_nchw_max_d1_s2_3x3(
%X: memref<1x1x16x16xf64>,
%Y: memref<1x1x7x7xf64>
) -> () {
%Y_moved = builtin.unrealized_conversion_cast %Y : memref<1x1x7x7xf64> to !riscv.reg<>

%c0_val = arith.constant 0 : i32
%c0 = builtin.unrealized_conversion_cast %c0_val : i32 to !riscv.reg<>
%c1_val = arith.constant 1 : i32
Expand All @@ -470,27 +480,31 @@ func.func public @pooling_nchw_max_d1_s2_3x3(
%c512_val = arith.constant 512 : i32
%c512 = builtin.unrealized_conversion_cast %c512_val : i32 to !riscv.reg<>

%min_val = riscv.fld %c0, ".min_val" : (!riscv.reg<>) -> !riscv.freg<>

memref_stream.streaming_region {
patterns = [
#memref_stream.stride_pattern<ub = [1, 1, 7, 7, 3, 3], index_map = (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
#memref_stream.stride_pattern<ub = [1, 1, 7, 7, 3, 3], index_map = (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>,
#memref_stream.stride_pattern<ub = [1, 1, 7, 7], index_map = (d0, d1, d2, d3) -> (d0, d1, d2, d3)>
]
} ins(%X : memref<1x1x16x16xf64>) {
^0(%x_stream : !stream.readable<f64>):

} ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) {
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.writable<f64>):
%c392_val = arith.constant 392 : i32
%c392 = builtin.unrealized_conversion_cast %c392_val : i32 to !riscv.reg<>
riscv_scf.for %y_i : !riscv.reg<> = %c0 to %c392 step %c8 {
%Y_dest = riscv.add %Y_moved, %y_i : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
%init = riscv.fld %Y_dest, 0 : (!riscv.reg<>) -> !riscv.freg<>
%init = riscv.fmv.d %min_val : (!riscv.freg<>) -> !riscv.freg<>

%y = riscv_scf.for %i : !riscv.reg<> = %c0 to %c9 step %c1 iter_args(%acc = %init) -> (!riscv.freg<>) {
%res = riscv_scf.for %i : !riscv.reg<> = %c0 to %c9 step %c1 iter_args(%acc = %init) -> (!riscv.freg<>) {
%x_val = memref_stream.read from %x_stream : f64
%x = builtin.unrealized_conversion_cast %x_val : f64 to !riscv.freg<>
%res = riscv.fmax.d %x, %acc : (!riscv.freg<>, !riscv.freg<>) -> !riscv.freg<>
riscv_scf.yield %res : !riscv.freg<>
}

riscv.fsd %Y_dest, %y, 0 : (!riscv.reg<>, !riscv.freg<>) -> ()
%y = riscv.fmv.d %res : (!riscv.freg<>) -> !riscv.freg<>
%y_val = builtin.unrealized_conversion_cast %y : !riscv.freg<> to f64

memref_stream.write %y_val to %y_stream : f64

riscv_scf.yield
}
Expand All @@ -500,12 +514,16 @@ func.func public @pooling_nchw_max_d1_s2_3x3(
}


// CHECK: .text
// CHECK-NEXT: .data
// CHECK-NEXT: .min_val:
// CHECK-NEXT: .quad 0xc0c3880000000000
// CHECK-NEXT: .text
// CHECK-NEXT: .globl pooling_nchw_max_d1_s2_3x3
// CHECK-NEXT: .p2align 2
// CHECK-NEXT: pooling_nchw_max_d1_s2_3x3:
// CHECK-NEXT: mv t2, a0
// CHECK-NEXT: mv t0, a1
// CHECK-NEXT: mv t3, a0
// CHECK-NEXT: mv t1, a1
// CHECK-NEXT: fld ft3, .min_val, zero
// CHECK-NEXT: li t4, 8
// CHECK-NEXT: li a3, 2
// CHECK-NEXT: li a2, 2
Expand All @@ -522,20 +540,24 @@ func.func public @pooling_nchw_max_d1_s2_3x3(
// CHECK-NEXT: scfgwi t4, 256
// CHECK-NEXT: li t4, -112
// CHECK-NEXT: scfgwi t4, 288
// CHECK-NEXT: scfgwi t2, 864
// CHECK-NEXT: li t4, 8
// CHECK-NEXT: li t5, 48
// CHECK-NEXT: scfgwi t5, 65
// CHECK-NEXT: scfgwi t4, 193
// CHECK-NEXT: scfgwi t3, 864
// CHECK-NEXT: scfgwi t1, 897
// CHECK-NEXT: csrrsi zero, 1984, 1
// CHECK-NEXT: li t2, 392
// CHECK-NEXT: mv t1, zero
// CHECK-NEXT: li t1, 392
// CHECK-NEXT: mv t0, zero
// CHECK-NEXT: # Constant folded riscv_cf.bge
// CHECK-NEXT: scf_body_{{\d+}}_for:
// CHECK-NEXT: add t4, t0, t1
// CHECK-NEXT: fld ft3, 0(t4)
// CHECK-NEXT: li t5, 8
// CHECK-NEXT: frep.o t5, 1, 0, 0
// CHECK-NEXT: fmax.d ft3, ft0, ft3
// CHECK-NEXT: fsd ft3, 0(t4)
// CHECK-NEXT: addi t1, t1, 8
// CHECK-NEXT: blt t1, t2, scf_body_{{\d+}}_for
// CHECK-NEXT: fmv.d ft4, ft3
// CHECK-NEXT: li t3, 8
// CHECK-NEXT: frep.o t3, 1, 0, 0
// CHECK-NEXT: fmax.d ft4, ft0, ft4
// CHECK-NEXT: fmv.d ft1, ft4
// CHECK-NEXT: addi t0, t0, 8
// CHECK-NEXT: blt t0, t1, scf_body_{{\d+}}_for
// CHECK-NEXT: scf_body_end_{{\d+}}_for:
// CHECK-NEXT: csrrci zero, 1984, 1
// CHECK-NEXT: ret
Expand Down Expand Up @@ -601,8 +623,6 @@ func.func public @pooling_nchw_sum_d1_s2_3x3(
%X: memref<1x1x16x16xf64>,
%Y: memref<1x1x7x7xf64>
) -> () {
%Y_moved = builtin.unrealized_conversion_cast %Y : memref<1x1x7x7xf64> to !riscv.reg<>

%c0_val = arith.constant 0 : i32
%c0 = builtin.unrealized_conversion_cast %c0_val : i32 to !riscv.reg<>
%c1_val = arith.constant 1 : i32
Expand All @@ -614,28 +634,33 @@ func.func public @pooling_nchw_sum_d1_s2_3x3(
%c512_val = arith.constant 512 : i32
%c512 = builtin.unrealized_conversion_cast %c512_val : i32 to !riscv.reg<>

%zero_float = arith.sitofp %c0_val : i32 to f64
%zero_reg = builtin.unrealized_conversion_cast %zero_float : f64 to !riscv.freg<>

memref_stream.streaming_region {
patterns = [
#memref_stream.stride_pattern<ub = [1, 1, 7, 7, 3, 3], index_map = (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>
#memref_stream.stride_pattern<ub = [1, 1, 7, 7, 3, 3], index_map = (d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>,
#memref_stream.stride_pattern<ub = [1, 1, 7, 7], index_map = (d0, d1, d2, d3) -> (d0, d1, d2, d3)>
]
} ins(%X : memref<1x1x16x16xf64>) {
^0(%x_stream : !stream.readable<f64>):

} ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) {
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.writable<f64>):
%c392_val = arith.constant 392 : i32
%c392 = builtin.unrealized_conversion_cast %c392_val : i32 to !riscv.reg<>
riscv_scf.for %y_i : !riscv.reg<> = %c0 to %c392 step %c8 {
%Y_dest = riscv.add %Y_moved, %y_i : (!riscv.reg<>, !riscv.reg<>) -> !riscv.reg<>
%init = riscv.fld %Y_dest, 0 : (!riscv.reg<>) -> !riscv.freg<>
%init = riscv.fmv.d %zero_reg : (!riscv.freg<>) -> !riscv.freg<>

%y = riscv_scf.for %i : !riscv.reg<> = %c0 to %c9 step %c1 iter_args(%acc = %init) -> (!riscv.freg<>) {
%res = riscv_scf.for %i : !riscv.reg<> = %c0 to %c9 step %c1 iter_args(%acc = %init) -> (!riscv.freg<>) {
%x_val = memref_stream.read from %x_stream : f64
%acc_val = builtin.unrealized_conversion_cast %acc : !riscv.freg<> to f64
%res_val = arith.addf %x_val, %acc_val : f64
%res = builtin.unrealized_conversion_cast %res_val : f64 to !riscv.freg<>
riscv_scf.yield %res : !riscv.freg<>
}

riscv.fsd %Y_dest, %y, 0 : (!riscv.reg<>, !riscv.freg<>) -> ()
%y = riscv.fmv.d %res : (!riscv.freg<>) -> !riscv.freg<>
%y_val = builtin.unrealized_conversion_cast %y : !riscv.freg<> to f64

memref_stream.write %y_val to %y_stream : f64

riscv_scf.yield
}
Expand All @@ -650,8 +675,9 @@ func.func public @pooling_nchw_sum_d1_s2_3x3(
// CHECK-NEXT: .globl pooling_nchw_sum_d1_s2_3x3
// CHECK-NEXT: .p2align 2
// CHECK-NEXT: pooling_nchw_sum_d1_s2_3x3:
// CHECK-NEXT: mv t2, a0
// CHECK-NEXT: mv t0, a1
// CHECK-NEXT: mv t3, a0
// CHECK-NEXT: mv t1, a1
// CHECK-NEXT: fcvt.d.w ft3, zero
// CHECK-NEXT: li t4, 8
// CHECK-NEXT: li a3, 2
// CHECK-NEXT: li a2, 2
Expand All @@ -668,20 +694,24 @@ func.func public @pooling_nchw_sum_d1_s2_3x3(
// CHECK-NEXT: scfgwi t4, 256
// CHECK-NEXT: li t4, -112
// CHECK-NEXT: scfgwi t4, 288
// CHECK-NEXT: scfgwi t2, 864
// CHECK-NEXT: li t4, 8
// CHECK-NEXT: li t5, 48
// CHECK-NEXT: scfgwi t5, 65
// CHECK-NEXT: scfgwi t4, 193
// CHECK-NEXT: scfgwi t3, 864
// CHECK-NEXT: scfgwi t1, 897
// CHECK-NEXT: csrrsi zero, 1984, 1
// CHECK-NEXT: li t2, 392
// CHECK-NEXT: mv t1, zero
// CHECK-NEXT: li t1, 392
// CHECK-NEXT: mv t0, zero
// CHECK-NEXT: # Constant folded riscv_cf.bge
// CHECK-NEXT: scf_body_{{\d+}}_for:
// CHECK-NEXT: add t4, t0, t1
// CHECK-NEXT: fld ft3, 0(t4)
// CHECK-NEXT: li t5, 8
// CHECK-NEXT: frep.o t5, 1, 0, 0
// CHECK-NEXT: fadd.d ft3, ft0, ft3
// CHECK-NEXT: fsd ft3, 0(t4)
// CHECK-NEXT: addi t1, t1, 8
// CHECK-NEXT: blt t1, t2, scf_body_{{\d+}}_for
// CHECK-NEXT: fmv.d ft4, ft3
// CHECK-NEXT: li t3, 8
// CHECK-NEXT: frep.o t3, 1, 0, 0
// CHECK-NEXT: fadd.d ft4, ft0, ft4
// CHECK-NEXT: fmv.d ft1, ft4
// CHECK-NEXT: addi t0, t0, 8
// CHECK-NEXT: blt t0, t1, scf_body_{{\d+}}_for
// CHECK-NEXT: scf_body_end_{{\d+}}_for:
// CHECK-NEXT: csrrci zero, 1984, 1
// CHECK-NEXT: ret

0 comments on commit 32b2396

Please sign in to comment.