diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a75d2ac08157..a826b153e979f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1708,6 +1708,13 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
           .reifyResultShapes(rewriter, reifiedReturnShapes);
   (void)status; // prevent unused variable warning on non-assert builds
   assert(succeeded(status) && "failed to reify result shapes");
+
+  // If no input vector sizes are provided, infer them from the result tensor
+  // shape; the precondition guarantees that it is fully static.
+  if (inputVectorSizes.empty()) {
+    inputVectorSizes =
+        mlir::dyn_cast<ShapedType>(padOp->getResultTypes()[0]).getShape();
+  }
   auto maskedRead = vector::createReadOrMaskedRead(
       rewriter, loc, padOp.getSource(), inputVectorSizes, padValue,
       /*useInBoundsInsteadOfMasking=*/false);
@@ -1915,9 +1922,18 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
     return failure();
   }
 
+  // Without explicit vector sizes, the pad can be vectorized only when the
+  // source and result shapes are fully static.
+  if (inputVectorSizes.empty()) {
+    if (!mlir::dyn_cast<ShapedType>(padOp->getResultTypes()[0])
+             .hasStaticShape() ||
+        !padOp.getSourceType().hasStaticShape())
+      return failure();
+  }
   ArrayRef<int64_t> resultTensorShape = padOp.getResultType().getShape();
-  if (failed(vector::isValidMaskedInputVector(resultTensorShape,
-                                              inputVectorSizes)))
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultTensorShape,
+                                              inputVectorSizes)))
     return failure();
 
   if (llvm::any_of(padOp.getLow(), [](Value v) {
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
index 5d3c07c8e23c1..ea7b81871ede3 100644
--- a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
@@ -126,3 +126,23 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+func.func @test_pad_no_vectorize_dynamic_shape(%arg0: tensor<?x?xf32>, %arg1: tensor<4x16xf32>) -> tensor<4x16xf32> {
+  %f0 = arith.constant 0.0 : f32
+  // expected-error @+1 {{Attempted to vectorize, but failed}}
+  %pad = tensor.pad %arg0 low[0, 0] high[1, 1] {
+  ^bb0(%arg3: index, %arg4: index):
+    tensor.yield %f0 : f32
+  } : tensor<?x?xf32> to tensor<4x16xf32>
+  return %pad : tensor<4x16xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bbeccc7fecd68..e5e70ca40f300 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1055,3 +1055,35 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: test_vectorize_pad_no_vector_sizes
+func.func @test_vectorize_pad_no_vector_sizes(%arg0: tensor<63x63xf32>) -> tensor<64x64xf32> {
+  %f0 = arith.constant 0.0 : f32
+  %pad = tensor.pad %arg0 low[0, 0] high[1, 1] {
+  ^bb0(%arg1: index, %arg2: index):
+    tensor.yield %f0 : f32
+  } : tensor<63x63xf32> to tensor<64x64xf32>
+  return %pad : tensor<64x64xf32>
+}
+// CHECK-DAG: %[[cst:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[c63:.*]] = arith.constant 63 : index
+// CHECK-DAG: %[[c63_0:.*]] = arith.constant 63 : index
+// CHECK: %[[mask:.*]] = vector.create_mask %[[c63]], %[[c63_0]] : vector<64x64xi1>
+// CHECK: %[[read:.*]] = vector.mask %[[mask]] { vector.transfer_read {{.*}}, %[[cst]] {in_bounds = [true, true]}
+// CHECK-SAME:   : tensor<63x63xf32>, vector<64x64xf32> } : vector<64x64xi1> -> vector<64x64xf32>
+// CHECK: %[[empty:.*]] = tensor.empty() : tensor<64x64xf32>
+// CHECK: %[[c0_1:.*]] = arith.constant 0 : index
+// CHECK: %[[result:.*]] = vector.transfer_write %[[read]], {{.*}} {in_bounds = [true, true]} :
+// CHECK-SAME:   vector<64x64xf32>, tensor<64x64xf32>
+// CHECK: return %[[result]] : tensor<64x64xf32>
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
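
Reviewer note, not part of the patch: a minimal sketch of how the two modes of
transform.structured.vectorize relate after this change, assuming the
vector_sizes clause of the transform op keeps its current spelling. With
explicit vector sizes, the masked-input validation in the precondition still
runs:

  // Existing behavior: sizes are given and validated against the result shape.
  transform.structured.vectorize %0 vector_sizes [64, 64] : !transform.any_op

Omitting vector_sizes, as in the new tests, now infers [64, 64] from the static
result type tensor<64x64xf32>:

  // New behavior: legal only when source and result shapes are fully static.
  transform.structured.vectorize %0 : !transform.any_op

Assuming the test files keep their existing RUN lines, the new tests can be
exercised locally with something like:

  mlir-opt -transform-interpreter -split-input-file mlir/test/Dialect/Linalg/vectorization.mlir | FileCheck mlir/test/Dialect/Linalg/vectorization.mlir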