diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index 751ae785bda6f..17873444b2d71 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -558,7 +558,8 @@ static LogicalResult verifyConvOpErrorIf(T op) { return success(); const int64_t biasChannels = biasType.getDimSize(0); - const int64_t outputChannels = outputType.getDimSize(3); + const int64_t outputChannels = + outputType.getDimSize(outputType.getRank() - 1); if (biasChannels == ShapedType::kDynamic || outputChannels == ShapedType::kDynamic) // Skip following checks if biasChannels or outputChannels is dynamic dim diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 7374cfd1145b9..75126a11ac504 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -38,12 +38,12 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, % // ----- // CHECK-LABEL: conv3d -func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<21xf32>) -> tensor<1x4x8x21x34xf32> { +func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>) -> tensor<1x4x8x21x34xf32> { // CHECK: profiles: [ [pro_int, pro_fp] ] // CHECK: extensions: [ [int4, int16, fp8e4m3, fp8e5m2, bf16] ] %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32> + %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32> return %0 : tensor<1x4x8x21x34xf32> } diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index c862ae375f33b..bb0d3b46955a1 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -26,9 +26,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<8x1x1x4xi4>, %ar } // ----- -func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xi16>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<21xi48>, %arg3: tensor<1xi16>, %arg4: tensor<1xi8>) -> tensor<1x4x8x21x34xi48> { +func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xi16>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi48>, %arg3: tensor<1xi16>, %arg4: tensor<1xi8>) -> tensor<1x4x8x21x34xi48> { // expected-error@+1 {{'tosa.conv3d' op illegal: requires [int16] but not enabled in target}} - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i48, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xi16>, tensor<34x1x1x1x17xi8>, tensor<21xi48>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x4x8x21x34xi48> + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i48, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xi16>, tensor<34x1x1x1x17xi8>, tensor<34xi48>, tensor<1xi16>, tensor<1xi8>) -> tensor<1x4x8x21x34xi48> return %0 : tensor<1x4x8x21x34xi48> } @@ -445,10 +445,10 @@ func.func @test_conv2d_non_const_input_zp(%arg0: tensor<1x4x4x4xi8>, %arg1: tens // ----- -func.func @test_conv3d_non_const_weight_zp(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<21xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> { +func.func @test_conv3d_non_const_weight_zp(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi32>, %arg3: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> { %input_zp = "tosa.const"() {values = dense<0> : tensor<1xi8> } : () -> tensor<1xi8> // expected-error@+1 {{'tosa.conv3d' op expected compile time resolvable constant, but got variable value for operand #4}} - %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %arg3 {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<21xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi32> + %0 = tosa.conv3d %arg0, %arg1, %arg2, %input_zp, %arg3 {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi32> return %0 : tensor<1x4x8x21x34xi32> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index c1181825f0c97..b64074e412ed1 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -104,15 +104,15 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8> // ----- // CHECK-LABEL: conv3d -func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<21xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x8x21x34xf32> { - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32> +func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x8x21x34xf32> { + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32> return %0 : tensor<1x4x8x21x34xf32> } // ----- // CHECK-LABEL: conv3d_with_local_bound -func.func @test_conv3d_with_local_bound(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<21xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x8x21x34xf32> { - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<21xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32> +func.func @test_conv3d_with_local_bound(%arg0: tensor<1x4x8x21x17xf32>, %arg1: tensor<34x1x1x1x17xf32>, %arg2: tensor<34xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) -> tensor<1x4x8x21x34xf32> { + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array, local_bound = true} : (tensor<1x4x8x21x17xf32>, tensor<34x1x1x1x17xf32>, tensor<34xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x4x8x21x34xf32> return %0 : tensor<1x4x8x21x34xf32> } @@ -823,8 +823,8 @@ func.func @test_conv2d_f8E5M2(%arg0: tensor<1x4x4x4xf8E5M2>, %arg1: tensor<8x1x1 // ----- // CHECK-LABEL: conv3d_f8E5M2 -func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<21xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> { - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<21xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> +func.func @test_conv3d_f8E5M2(%arg0: tensor<1x4x8x21x17xf8E5M2>, %arg1: tensor<34x1x1x1x17xf8E5M2>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E5M2>, %arg4: tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> { + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf8E5M2>, tensor<34x1x1x1x17xf8E5M2>, tensor<34xf16>, tensor<1xf8E5M2>, tensor<1xf8E5M2>) -> tensor<1x4x8x21x34xf16> return %0 : tensor<1x4x8x21x34xf16> } @@ -968,8 +968,8 @@ func.func @test_conv2d_f8E4M3FN(%arg0: tensor<1x4x4x4xf8E4M3FN>, %arg1: tensor<8 // ----- // CHECK-LABEL: conv3d_f8E4M3FN -func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<21xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> { - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<21xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> +func.func @test_conv3d_f8E4M3FN(%arg0: tensor<1x4x8x21x17xf8E4M3FN>, %arg1: tensor<34x1x1x1x17xf8E4M3FN>, %arg2: tensor<34xf16>, %arg3: tensor<1xf8E4M3FN>, %arg4: tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> { + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f16, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf8E4M3FN>, tensor<34x1x1x1x17xf8E4M3FN>, tensor<34xf16>, tensor<1xf8E4M3FN>, tensor<1xf8E4M3FN>) -> tensor<1x4x8x21x34xf16> return %0 : tensor<1x4x8x21x34xf16> } diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index efbb9e9d1843f..72669c62c95ca 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -33,9 +33,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xf32>, %arg1: tensor<8x1x1x4xf32>, % } // ----- -func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf16>, %arg1: tensor<34x1x1x1x17xf16>, %arg2: tensor<21xf16>, %arg3: tensor<1xf16>, %arg4: tensor<1xf16>) -> tensor<1x4x8x21x34xf16> { +func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xf16>, %arg1: tensor<34x1x1x1x17xf16>, %arg2: tensor<34xf16>, %arg3: tensor<1xf16>, %arg4: tensor<1xf16>) -> tensor<1x4x8x21x34xf16> { // expected-error@+1 {{'tosa.conv3d' op illegal: requires [pro_fp] but not enabled in target}} - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf16>, tensor<34x1x1x1x17xf16>, tensor<21xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x4x8x21x34xf16> + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xf16>, tensor<34x1x1x1x17xf16>, tensor<34xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x4x8x21x34xf16> return %0 : tensor<1x4x8x21x34xf16> } diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir index b102eea5699dd..e98b906377b22 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -38,9 +38,9 @@ func.func @test_conv2d(%arg0: tensor<1x4x4x4xi8>, %arg1: tensor<8x1x1x4xi8>, %ar } // ----- -func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<21xi32>, %arg3: tensor<1xi8>, %arg4: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> { +func.func @test_conv3d(%arg0: tensor<1x4x8x21x17xi8>, %arg1: tensor<34x1x1x1x17xi8>, %arg2: tensor<34xi32>, %arg3: tensor<1xi8>, %arg4: tensor<1xi8>) -> tensor<1x4x8x21x34xi32> { // expected-error@+1 {{'tosa.conv3d' op illegal: requires [pro_int] but not enabled in target}} - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<21xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi32> + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x4x8x21x17xi8>, tensor<34x1x1x1x17xi8>, tensor<34xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x8x21x34xi32> return %0 : tensor<1x4x8x21x34xi32> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index c6ac8074c0326..1ad1e6c76c294 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -824,18 +824,18 @@ func.func @conv2d_strided(%input: tensor<1x13x15x1xf32>, %weights: tensor<1x1x1x // ----- // CHECK-LABEL: @conv3d_static -func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x4x3xf32>, %bias: tensor<7xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) -> () { +func.func @conv3d_static(%input: tensor<2x8x9x10x3xf32>, %weights: tensor<5x3x6x4x3xf32>, %bias: tensor<5xf32>, %input_zp: tensor<1xf32>, %weight_zp: tensor<1xf32>) -> () { // CHECK: -> tensor<2x6x4x7x5xf32> - %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor + %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor return } // ----- // CHECK-LABEL: @conv3d_dynamic_input -func.func @conv3d_dynamic_input(%arg0: tensor, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<7xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) { +func.func @conv3d_dynamic_input(%arg0: tensor, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) { // CHECK: -> tensor - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<5x3x6x4x3xf32>, tensor<7xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor, tensor<5x3x6x4x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor return } @@ -860,18 +860,18 @@ func.func @conv3d_dynamic_bias(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x // ----- // CHECK-LABEL: @conv3d_padded -func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<18xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) { +func.func @conv3d_padded(%arg0: tensor<2x8x9x10x3xf32>, %arg1: tensor<5x3x6x4x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) { // CHECK: -> tensor<2x9x11x18x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<18xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x8x9x10x3xf32>, tensor<5x3x6x4x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor return } // ----- // CHECK-LABEL: @conv3d_dilated -func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2x3xf32>, %arg2: tensor<12xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) { +func.func @conv3d_dilated(%arg0: tensor<2x12x14x16x3xf32>, %arg1: tensor<5x3x6x2x3xf32>, %arg2: tensor<5xf32>, %arg3: tensor<1xf32>, %arg4: tensor<1xf32>) { // CHECK: -> tensor<2x6x4x12x5xf32> - %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<12xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor + %0 = tosa.conv3d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<2x12x14x16x3xf32>, tensor<5x3x6x2x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor return }