@@ -363,13 +363,11 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
363363
364364// CHECK-LABEL: @conv2d_i8
365365func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
366- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
367- // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
368366 // CHECK: %[[M_IN:.+]] = tensor.empty()
369367 // CHECK: %[[CST:.+]] = arith.constant 0
370368 // CHECK: %[[FILL:.+]] = linalg.fill
371369 // CHECK: %[[B_IN:.+]] = tensor.empty()
372- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] , %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8 >, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
370+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
373371 // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
374372 // CHECK: arith.extsi
375373 // CHECK: arith.addi
@@ -385,13 +383,11 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
385383
386384// CHECK-LABEL: @conv2d_f32
387385func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
388- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
389- // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
390386 // CHECK: %[[M_IN:.+]] = tensor.empty()
391387 // CHECK: %[[CST:.+]] = arith.constant 0
392388 // CHECK: %[[FILL:.+]] = linalg.fill
393389 // CHECK: %[[B_IN:.+]] = tensor.empty()
394- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
390+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
395391 // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
396392 // CHECK: arith.addf
397393 // CHECK: linalg.yield
@@ -408,13 +404,11 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
408404func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
409405 // CHECK: %[[C0:.+]] = arith.constant 0
410406 // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
411- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
412- // CHECK: %[[W:.+]] = tosa.transpose %arg1, %[[PERM]]
413407 // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
414408 // CHECK: %[[CST:.+]] = arith.constant 0
415409 // CHECK: %[[FILL:.+]] = linalg.fill
416410 // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
417- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[W]] : tensor<?x49x42x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
411+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
418412 // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
419413 // CHECK: %[[ADD:.+]] = arith.addf
420414 // CHECK: linalg.yield %[[ADD]] : f32
@@ -468,13 +462,11 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
468462 // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
469463
470464 // Running convolution
471- // CHECK: %[[PERM:.+]] = arith.constant dense<[1, 2, 3, 0]>
472- // CHECK: %[[WEIGHT:.+]] = tosa.transpose %arg1, %[[PERM]]
473465 // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
474466 // CHECK: %[[CST:.+]] = arith.constant 0
475467 // CHECK: %[[FILL:.+]] = linalg.fill
476468 // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
477- // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[WEIGHT]] : tensor<1x?x?x27xf32>, tensor<3x3x27x28xf32 >) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
469+ // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
478470 // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
479471 // CHECK: %[[ADD:.+]] = arith.addf
480472 // CHECK: linalg.yield %[[ADD]] : f32
@@ -489,7 +481,7 @@ func.func @conv2d_padded_f32(%input: tensor<1x47x40x28xf32>, %weights: tensor<28
489481 // CHECK: %[[C0:.+]] = arith.constant 0
490482 // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
491483 // CHECK: tensor.yield %[[C0]]
492- // CHECK: linalg.conv_2d_nhwc_hwcf
484+ // CHECK: linalg.conv_2d_nhwc_fhwc
493485 %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x47x40x28xf32>, tensor<28x3x3x28xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
494486 return
495487}
@@ -501,7 +493,7 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
501493 // CHECK: %[[C22:.+]] = arith.constant -22
502494 // CHECK: tensor.pad %arg0 low[0, 1, 1, 0] high[0, 1, 1, 0]
503495 // CHECK: tensor.yield %[[C22]]
504- // CHECK: linalg.conv_2d_nhwc_hwcf_q
496+ // CHECK: linalg.conv_2d_nhwc_fhwc_q
505497 %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, quantization_info = #tosa.conv_quant<input_zp = -22, weight_zp = 42>, stride = array<i64: 1, 1>} : (tensor<1x12x12x1xi8>, tensor<1024x3x3x1xi8>, tensor<1024xi32>) -> tensor<1x12x12x1024xi32>
506498 return
507499}