diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4eb334f8bbbfa..e0fd5f1b14070 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3388,8 +3388,9 @@ DiagnosedSilenceableFailure transform::VectorizeOp::apply(
 
   // TODO: Check that the correct number of vectorSizes was provided.
   for (Operation *target : targets) {
-    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
-            target)) {
+    if (!isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
+             tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::BitcastOp,
+             tensor::ConcatOp>(target)) {
       return mlir::emitSilenceableFailure(target->getLoc())
              << "Unsupported Op, cannot vectorize";
     }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
index 3a75d2ac08157..7a4db82749fd1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1718,6 +1718,209 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   return success();
 }
 
+/// Vectorize a `tensor.expand_shape`/`tensor.collapse_shape` into these 3 ops:
+///   vector::TransferReadOp  - Reads a vector from the source tensor.
+///   vector::ShapeCastOp     - Reshapes the data based on the target shape.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///                             destination tensor.
+static LogicalResult lowerTensorReshape(RewriterBase &rewriter,
+                                        Operation *inputOp,
+                                        ArrayRef<int64_t> inputVectorSizes,
+                                        SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(inputOp);
+  auto src = inputOp->getOperand(0);
+  auto srcType = mlir::dyn_cast<ShapedType>(src.getType());
+  auto result = inputOp->getResults()[0];
+  auto resultType = mlir::dyn_cast<ShapedType>(result.getType());
+  ArrayRef<int64_t> resultShape = resultType.getShape();
+  ArrayRef<int64_t> srcShape = srcType.getShape();
+  Location loc = inputOp->getLoc();
+
+  llvm::SmallVector<int64_t> srcVectorizedShape;
+  llvm::SmallDenseMap<int64_t, int64_t> shapeScales;
+
+  auto getVectorizeShape = [&](ArrayRef<int64_t> &retShape,
+                               ArrayRef<int64_t> &inputShape) {
+    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+
+    int64_t cur = 1, resultIdx = 0;
+    for (auto [srcIdx, ss] : llvm::enumerate(inputShape)) {
+      cur *= ss;
+      if (!isResultShapeBigger) {
+        // Collapse: group source dims until they match the result dim.
+        srcVectorizedShape.emplace_back(ss);
+        if (cur == retShape[resultIdx]) {
+          if (shapeScales.count(resultIdx)) {
+            srcVectorizedShape.back() *= shapeScales[resultIdx];
+          }
+          cur = 1;
+          resultIdx++;
+        }
+      } else {
+        // Expand: emit one dim per completed group of result dims.
+        if (cur == retShape[resultIdx]) {
+          srcVectorizedShape.emplace_back(cur);
+          if (shapeScales.count(srcIdx)) {
+            srcVectorizedShape.back() *= shapeScales[srcIdx];
+          }
+          cur = 1;
+          resultIdx++;
+        }
+      }
+    }
+  };
+  if (!inputVectorSizes.empty()) {
+    for (auto [idx, vs] : llvm::enumerate(inputVectorSizes)) {
+      if (vs != resultShape[idx])
+        shapeScales[idx] = vs / resultShape[idx];
+    }
+
+    bool isResultShapeBigger = srcType.getRank() < resultType.getRank();
+    if (!isResultShapeBigger) {
+      getVectorizeShape(resultShape, srcShape);
+    } else {
+      getVectorizeShape(srcShape, resultShape);
+    }
+  } else {
+    srcVectorizedShape.assign(srcShape.begin(), srcShape.end());
+  }
+  // Read.
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(srcType.getElementType()));
+  Value readResult = vector::createReadOrMaskedRead(
+      rewriter, loc, src,
+      inputVectorSizes.empty() ? srcType.getShape() : srcVectorizedShape,
+      padValue, false);
+
+  auto shapeCastType =
+      VectorType::get(inputVectorSizes.empty() ? resultShape : inputVectorSizes,
+                      resultType.getElementType());
+  vector::ShapeCastOp shapeCastOp =
+      rewriter.create<vector::ShapeCastOp>(loc, shapeCastType, readResult);
+
+  // Write.
+  SmallVector<OpFoldResult> destSizes;
+  for (auto size : resultShape) {
+    destSizes.emplace_back(rewriter.getIndexAttr(size));
+  }
+  Operation *write = createWriteOrMaskedWrite(
+      rewriter, loc, shapeCastOp->getResults()[0], destSizes,
+      inputVectorSizes.empty() ? resultShape : inputVectorSizes, false);
+  newResults.push_back(write->getResult(0));
+  return success();
+}
+
+/// Vectorize a `tensor.bitcast` into these 3 ops:
+///   vector::TransferReadOp  - Reads a vector from the source tensor.
+///   vector::BitCastOp       - Bitcasts the data to the target element type.
+///   vector::TransferWriteOp - Writes the result vector back to the
+///                             destination tensor.
+static LogicalResult lowerTensorBitcastOp(RewriterBase &rewriter,
+                                          tensor::BitcastOp bitCastOp,
+                                          ArrayRef<int64_t> inputVectorSizes,
+                                          SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(bitCastOp);
+
+  auto sourceType = bitCastOp.getSource().getType();
+  auto resultType = bitCastOp.getResult().getType();
+  auto resultShape = resultType.getShape();
+  if (inputVectorSizes.empty()) {
+    inputVectorSizes = resultShape;
+  }
+  Location loc = bitCastOp->getLoc();
+
+  // Read.
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(sourceType.getElementType()));
+  Value readResult = vector::createReadOrMaskedRead(
+      rewriter, loc, bitCastOp.getSource(), inputVectorSizes, padValue, false);
+
+  // Bitcast.
+  auto resultVectorType =
+      VectorType::get(inputVectorSizes, resultType.getElementType());
+  vector::BitCastOp vectorbitCastOp =
+      rewriter.create<vector::BitCastOp>(loc, resultVectorType, readResult);
+
+  // Write.
+  llvm::SmallVector<OpFoldResult> destSizes;
+  for (auto size : resultShape)
+    destSizes.emplace_back(rewriter.getIndexAttr(size));
+  auto write =
+      createWriteOrMaskedWrite(rewriter, loc, vectorbitCastOp->getResult(0),
+                               destSizes, inputVectorSizes, false);
+  newResults.push_back(write->getResults()[0]);
+  return success();
+}
+
+/// Vectorize a `tensor.concat` by writing each vectorized input into the
+/// destination:
+///   tensor::EmptyOp         - Provides the result tensor.
+///   vector::TransferReadOp  - Reads a vector from each input tensor.
+///   vector::TransferWriteOp - Writes each vector into the destination tensor
+///                             at the proper offset along the concat dim.
+static LogicalResult lowerTensorConcatOp(RewriterBase &rewriter,
+                                         tensor::ConcatOp concatOp,
+                                         ArrayRef<int64_t> inputVectorSizes,
+                                         SmallVectorImpl<Value> &newResults) {
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(concatOp);
+
+  Location loc = concatOp.getLoc();
+  FailureOr<Value> dest =
+      tensor::getOrCreateDestination(rewriter, loc, concatOp->getResult(0));
+  if (failed(dest))
+    return failure();
+
+  auto empty = dest->getDefiningOp<tensor::EmptyOp>();
+  if (!empty)
+    return failure();
+
+  // Compute the partial sums for the slice offsets.
+  auto dim = concatOp.getDim();
+  Value dimValue =
+      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(dim));
+
+  int64_t rank = concatOp.getResultType().getRank();
+  auto srcType =
+      mlir::dyn_cast<ShapedType>(concatOp->getResultTypes()[0]);
+  auto padValue = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getZeroAttr(srcType.getElementType()));
+
+  // Construct the chain of transfer_write ops into the destination.
+  Value result = *dest;
+  Value previous_offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  for (auto [idx, input] : llvm::enumerate(concatOp.getInputs())) {
+
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(rewriter, loc, input);
+    SmallVector<int64_t> readMaskShape;
+    auto inputType = mlir::dyn_cast<ShapedType>(input.getType());
+    auto sourceShape = inputType.getShape();
+
+    readMaskShape.append(sourceShape.begin(), sourceShape.end());
+    Value readResult = vector::createReadOrMaskedRead(
+        rewriter, loc, input, sourceShape, padValue, false);
+    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+    SmallVector<Value> indices(rank, zero);
+    indices[dim] = previous_offset;
+    result = rewriter
+                 .create<vector::TransferWriteOp>(
+                     loc, readResult, result, indices,
+                     rewriter.getMultiDimIdentityMap(rank))
+                 ->getResults()[0];
+    if (idx != concatOp.getNumOperands() - 1) {
+      auto dimOp = rewriter.create<tensor::DimOp>(loc, input, dimValue);
+      previous_offset =
+          rewriter.create<arith::AddIOp>(loc, dimOp, previous_offset);
+    }
+  }
+
+  newResults.push_back(result);
+  return success();
+}
+
 // TODO: probably need some extra checks for reduction followed by consumer
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
@@ -1931,6 +2134,108 @@ vectorizePadOpPrecondition(tensor::PadOp padOp,
   return success();
 }
 
+static LogicalResult
+lowerExpandOpPrecondition(tensor::ExpandShapeOp expandOp,
+                          ArrayRef<int64_t> inputVectorSizes) {
+  auto resultType = expandOp->getResultTypes()[0];
+  auto resultShape = mlir::dyn_cast<ShapedType>(resultType);
+  // Check the reassociation.
+  llvm::SmallVector<int64_t> associateIndices;
+  for (auto attr : expandOp.getReassociation()) {
+    for (auto indice : mlir::dyn_cast<ArrayAttr>(attr)) {
+      associateIndices.push_back(mlir::dyn_cast<IntegerAttr>(indice).getInt());
+    }
+  }
+
+  if (llvm::any_of(associateIndices,
+                   [](int64_t x) { return x == ShapedType::kDynamic; })) {
+    LDBG("Reassociation must be static: " << expandOp << "\n");
+    return failure();
+  }
+  // Check the input and output shapes.
+  if (!resultShape.hasStaticShape() ||
+      !expandOp.getSrcType().hasStaticShape()) {
+    LDBG("Input and output shape must be static: " << expandOp << "\n");
+    return failure();
+  }
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultShape.getShape(),
+                                              inputVectorSizes)))
+    return failure();
+
+  return success();
+}
+
+static LogicalResult
+lowerBitcastOpPrecondition(tensor::BitcastOp bitCastOp,
+                           ArrayRef<int64_t> inputVectorSizes) {
+  auto resultType = bitCastOp->getResultTypes()[0];
+  auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+  auto srcType = bitCastOp.getSource().getType();
+  auto srcShapeType = mlir::dyn_cast<ShapedType>(srcType);
+
+  bool isStaticInputOutput =
+      resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+  if (!isStaticInputOutput) {
+    LDBG("Input and output shape must be static: " << bitCastOp << "\n");
+    return failure();
+  }
+
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+                                              inputVectorSizes)))
+    return failure();
+  return success();
+}
+
+static LogicalResult
+lowerCollapseShapeOpPrecondition(tensor::CollapseShapeOp collapseOp,
+                                 ArrayRef<int64_t> inputVectorSizes) {
+  auto resultType = collapseOp->getResultTypes()[0];
+  auto resultShapeType = mlir::dyn_cast<ShapedType>(resultType);
+  auto srcShapeType = collapseOp.getSrcType();
+
+  bool isStaticInputOutput =
+      resultShapeType.hasStaticShape() && srcShapeType.hasStaticShape();
+  if (!isStaticInputOutput) {
+    LDBG("Input and output shape must be static: " << collapseOp << "\n");
+    return failure();
+  }
+
+  if (!inputVectorSizes.empty() &&
+      failed(vector::isValidMaskedInputVector(resultShapeType.getShape(),
+                                              inputVectorSizes)))
+    return failure();
+  return success();
+}
+
+static LogicalResult
+lowerConcatOpPrecondition(tensor::ConcatOp concatOp,
+                          ArrayRef<int64_t> inputVectorSizes) {
+  if (!inputVectorSizes.empty()) {
+    LDBG("Concat operation does not support specifying inputVectorSizes: "
+         << concatOp << "\n");
+  }
+  for (auto x : concatOp->getOperands()) {
+    auto type = mlir::dyn_cast<ShapedType>(x.getType());
+    if (!type) {
+      LDBG("Operation type error: " << concatOp << "\n");
+      return failure();
+    }
+    if (!type.hasStaticShape()) {
+      LDBG("Type must be static: " << concatOp << "\n");
+      return failure();
+    }
+  }
+  auto dim = concatOp.getDim();
+  if (dim >= (uint64_t)concatOp.getResultType().getRank()) {
+    LDBG("Invalid dim: " << concatOp << "\n");
+    return failure();
+  }
+
+  return success();
+}
+
 /// Preconditions for scalable vectors.
 static LogicalResult
 vectorizeScalableVectorPrecondition(Operation *op,
@@ -1976,6 +2281,19 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::UnPackOp>([&](auto unpackOp) {
         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
       })
+      .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+        return lowerExpandOpPrecondition(expandShapeOp, inputVectorSizes);
+      })
+      .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+        return lowerCollapseShapeOpPrecondition(collapseShapeOp,
+                                                inputVectorSizes);
+      })
+      .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+        return lowerBitcastOpPrecondition(bitCastOp, inputVectorSizes);
+      })
+      .Case<tensor::ConcatOp>([&](auto concatOp) {
+        return lowerConcatOpPrecondition(concatOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -2075,6 +2393,22 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                              inputVectorSizes, results);
           })
+          .Case<tensor::ExpandShapeOp>([&](auto expandShapeOp) {
+            return lowerTensorReshape(rewriter, expandShapeOp, inputVectorSizes,
+                                      results);
+          })
+          .Case<tensor::CollapseShapeOp>([&](auto collapseShapeOp) {
+            return lowerTensorReshape(rewriter, collapseShapeOp,
+                                      inputVectorSizes, results);
+          })
+          .Case<tensor::BitcastOp>([&](auto bitCastOp) {
+            return lowerTensorBitcastOp(rewriter, bitCastOp, inputVectorSizes,
+                                        results);
+          })
+          .Case<tensor::ConcatOp>([&](auto concatOp) {
+            return lowerTensorConcatOp(rewriter, concatOp, inputVectorSizes,
+                                       results);
+          })
           .Default([](auto) { return failure(); });
 
   if (failed(vectorizeResult)) {
diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir
index bbeccc7fecd68..114815b4e3de8 100644
--- a/mlir/test/Dialect/Linalg/vectorization.mlir
+++ b/mlir/test/Dialect/Linalg/vectorization.mlir
@@ -1055,3 +1055,195 @@ func.func @test_vectorize_unpack_no_vector_sizes_permute(%source: tensor<4x7x4xf
     transform.yield
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_collapseshape
+func.func @test_vectorize_collapseshape(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK: %[[C80:.*]] = arith.constant 8 : index
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[C16:.*]] = arith.constant 16 : index
+  // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x32x32xi1>
+  // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<8x8x32x32xi1> -> vector<8x8x32x32xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x32xf32> to vector<64x1024xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+  // CHECK: %[[C01:.*]] = arith.constant 0 : index
+  // CHECK: %[[C64:.*]] = arith.constant 64 : index
+  // CHECK: %[[C512:.*]] = arith.constant 512 : index
+  // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+  // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<64x1024xi1> -> tensor<64x512xf32>
+  // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+  return %collapsed : tensor<64x512xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [64, 1024] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_collapseshape_no_vector_size
+func.func @test_vectorize_collapseshape_no_vector_size(%source: tensor<8x8x32x16xf32>, %dest: tensor<64x512xf32>) -> tensor<64x512xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true, true, true]} : tensor<8x8x32x16xf32>, vector<8x8x32x16xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<8x8x32x16xf32> to vector<64x512xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+  // CHECK: %[[C01:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true, true]} : vector<64x512xf32>, tensor<64x512xf32>
+  // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %collapsed = tensor.collapse_shape %source [[0, 1], [2, 3]] : tensor<8x8x32x16xf32> into tensor<64x512xf32>
+  return %collapsed : tensor<64x512xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.collapse_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_expandshape
+func.func @test_vectorize_expandshape(%source: tensor<64x512xf32>, %dest: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C64:.*]] = arith.constant 64 : index
+  // CHECK: %[[C512:.*]] = arith.constant 512 : index
+  // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+  // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<64x1024xi1> -> vector<64x1024xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<64x1024xf32> to vector<8x8x32x32xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<8x8x32x16xf32>
+  // CHECK: %[[C01:.*]] = arith.constant 0 : index
+  // CHECK: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK: %[[C80:.*]] = arith.constant 8 : index
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[C16:.*]] = arith.constant 16 : index
+  // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C8]], %[[C80]], %[[C32]], %[[C16]] : vector<8x8x32x32xi1>
+  // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<8x8x32x32xi1> -> tensor<8x8x32x16xf32>
+  // CHECK: return %[[WRIT]] : tensor<8x8x32x16xf32>
+  %expanded = tensor.expand_shape %source [[0, 1], [2, 3]] output_shape [8, 8, 32, 16] : tensor<64x512xf32> into tensor<8x8x32x16xf32>
+  return %expanded : tensor<8x8x32x16xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.expand_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [8, 8, 32, 32] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_expandshape_no_vector_size
+func.func @test_vectorize_expandshape_no_vector_size(%source: tensor<64x512xf32>, %dest: tensor<8x8x32x16xf32>) -> tensor<8x8x32x16xf32> {
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true]} : tensor<64x512xf32>, vector<64x512xf32>
+  // CHECK: %[[SHAPC:.*]] = vector.shape_cast %[[READ0]] : vector<64x512xf32> to vector<8x8x32x16xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<8x8x32x16xf32>
+  // CHECK: %[[C01:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true, true, true, true]} : vector<8x8x32x16xf32>, tensor<8x8x32x16xf32>
+  // CHECK: return %[[WRIT]] : tensor<8x8x32x16xf32>
+  %expanded = tensor.expand_shape %source [[0, 1], [2, 3]] output_shape [8, 8, 32, 16] : tensor<64x512xf32> into tensor<8x8x32x16xf32>
+  return %expanded : tensor<8x8x32x16xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.expand_shape"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_bitcast
+func.func @test_vectorize_bitcast(%source: tensor<64x512xi32>) -> tensor<64x512xf32> {
+  // CHECK: %[[C0i32:.*]] = arith.constant 0 : i32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C64:.*]] = arith.constant 64 : index
+  // CHECK: %[[C512:.*]] = arith.constant 512 : index
+  // CHECK: %[[MSK0:.*]] = vector.create_mask %[[C64]], %[[C512]] : vector<64x1024xi1>
+  // CHECK: %[[READ0:.*]] = vector.mask %[[MSK0]] {{.*}} : vector<64x1024xi1> -> vector<64x1024xi32>
+  // CHECK: %[[CAST:.*]] = vector.bitcast %[[READ0]] : vector<64x1024xi32> to vector<64x1024xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[C641:.*]] = arith.constant 64 : index
+  // CHECK: %[[C5121:.*]] = arith.constant 512 : index
+  // CHECK: %[[WRITEMSK:.*]] = vector.create_mask %[[C641]], %[[C5121]] : vector<64x1024xi1>
+  // CHECK: %[[WRIT:.*]] = vector.mask %[[WRITEMSK]] {{.*}} : vector<64x1024xi1> -> tensor<64x512xf32>
+  // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %0 = tensor.bitcast %source : tensor<64x512xi32> to tensor<64x512xf32>
+  return %0 : tensor<64x512xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.bitcast"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 vector_sizes [64, 1024] : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_bitcast_no_vector_size
+func.func @test_vectorize_bitcast_no_vector_size(%source: tensor<64x512xi32>) -> tensor<64x512xf32> {
+  // CHECK: %[[C0i32:.*]] = arith.constant 0 : i32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[C0i32]] {in_bounds = [true, true]} : tensor<64x512xi32>, vector<64x512xi32>
+  // CHECK: %[[CAST:.*]] = vector.bitcast %[[READ0]] : vector<64x512xi32> to vector<64x512xf32>
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x512xf32>
+  // CHECK: %[[C00:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT:.*]] = vector.transfer_write {{.*}}, {{.*}} {in_bounds = [true, true]} : vector<64x512xf32>, tensor<64x512xf32>
+  // CHECK: return %[[WRIT]] : tensor<64x512xf32>
+  %0 = tensor.bitcast %source : tensor<64x512xi32> to tensor<64x512xf32>
+  return %0 : tensor<64x512xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.bitcast"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK-LABEL: func @test_vectorize_concat_no_vector_size
+func.func @test_vectorize_concat_no_vector_size(%arg0: tensor<64x512xf32>, %arg1: tensor<64x512xf32>) -> tensor<64x1024xf32> {
+  // CHECK: %[[EMPT:.*]] = tensor.empty() : tensor<64x1024xf32>
+  // CHECK: %[[C1:.*]] = arith.constant 1 : index
+  // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[C0_0:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ0:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true]} : tensor<64x512xf32>, vector<64x512xf32>
+  // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT0:.*]] = vector.transfer_write %[[READ0]], %[[EMPT]][{{.*}}, {{.*}}] : vector<64x512xf32>, tensor<64x1024xf32>
+  // CHECK: %[[DIM:.*]] = tensor.dim {{.*}}, {{.*}} : tensor<64x512xf32>
+  // CHECK: %[[ADD:.*]] = arith.addi %[[DIM]], {{.*}} : index
+  // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+  // CHECK: %[[READ1:.*]] = vector.transfer_read {{.*}}, %[[CST]] {in_bounds = [true, true]} : tensor<64x512xf32>, vector<64x512xf32>
+  // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
+  // CHECK: %[[WRIT1:.*]] = vector.transfer_write %[[READ1]], %[[WRIT0]][{{.*}}, {{.*}}] : vector<64x512xf32>, tensor<64x1024xf32>
+  // CHECK: return %[[WRIT1]] : tensor<64x1024xf32>
+  %0 = tensor.concat dim(1) %arg0, %arg1
+       : (tensor<64x512xf32>, tensor<64x512xf32>) -> tensor<64x1024xf32>
+  return %0 : tensor<64x1024xf32>
+}
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["tensor.concat"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+    transform.structured.vectorize %0 : !transform.any_op
+    transform.yield
+  }
+}
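Note on the read-shape computation (not part of the patch): the least obvious piece of lowerTensorReshape is how it derives the shape of the masked transfer_read when explicit vector sizes are given. For the collapse case it walks the source dims, closes a group whenever the running product equals the next collapsed result dim, and multiplies the innermost dim of that group by the per-result-dimension scale (vector size divided by result dim). The standalone C++ sketch below models only that collapse branch, using the shapes from @test_vectorize_collapseshape; the helper name groupCollapseReadShape and the use of std::map/std::vector instead of the LLVM containers are illustrative assumptions, not code from the patch.

// Standalone sketch: reproduces the dimension grouping that the patch's
// lowerTensorReshape performs for the collapse case, in plain C++.
#include <cassert>
#include <cstdint>
#include <map>
#include <vector>

static std::vector<int64_t>
groupCollapseReadShape(const std::vector<int64_t> &srcShape,
                       const std::vector<int64_t> &resultShape,
                       const std::vector<int64_t> &vectorSizes) {
  // Per-result-dimension scale, e.g. vector size 1024 for a result dim of 512
  // gives a scale of 2.
  std::map<int64_t, int64_t> scales;
  for (size_t i = 0; i < vectorSizes.size(); ++i)
    if (vectorSizes[i] != resultShape[i])
      scales[i] = vectorSizes[i] / resultShape[i];

  std::vector<int64_t> readShape;
  int64_t cur = 1, resultIdx = 0;
  for (int64_t s : srcShape) {
    cur *= s;
    readShape.push_back(s);
    // Once the running product matches the collapsed dim, the group is
    // complete; the scale is applied to the innermost dim of the group.
    if (cur == resultShape[resultIdx]) {
      if (scales.count(resultIdx))
        readShape.back() *= scales[resultIdx];
      cur = 1;
      ++resultIdx;
    }
  }
  return readShape;
}

int main() {
  // Mirrors @test_vectorize_collapseshape: tensor<8x8x32x16xf32> collapsed to
  // tensor<64x512xf32>, vectorized with vector_sizes [64, 1024].
  std::vector<int64_t> readShape =
      groupCollapseReadShape({8, 8, 32, 16}, {64, 512}, {64, 1024});
  // The masked transfer_read in the test is created with vector<8x8x32x32>.
  assert((readShape == std::vector<int64_t>{8, 8, 32, 32}));
  return 0;
}

Any C++17 compiler builds this; the assert documents the read shape that the corresponding test expects.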