diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp index e1ed5d81625d8..74382b027c2f4 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDropLeadUnitDim.cpp @@ -73,7 +73,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim VectorType oldDstType = extractOp.getType(); VectorType newDstType = VectorType::get(oldDstType.getShape().drop_front(dropCount), - oldDstType.getElementType()); + oldDstType.getElementType(), + oldDstType.getScalableDims().drop_front(dropCount)); Location loc = extractOp.getLoc(); diff --git a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir index f601be0416814..bb2d30f209243 100644 --- a/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-dropleadunitdim-transforms.mlir @@ -206,6 +206,16 @@ func.func @cast_away_extract_strided_slice_leading_one_dims(%arg0: vector<1x8x8x return %0: vector<1x1x8xf16> } +// CHECK-LABEL: func @cast_away_extract_strided_slice_leading_one_dims_scalable +func.func @cast_away_extract_strided_slice_leading_one_dims_scalable(%arg0: vector<1x8x[8]xf16>) -> vector<1x1x[8]xf16> { + // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16> + // CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[SRC]] {offsets = [4], sizes = [1], strides = [1]} : vector<8x[8]xf16> to vector<1x[8]xf16> + %0 = vector.extract_strided_slice %arg0 {offsets = [0, 4], sizes = [1, 1], strides = [1, 1]} : vector<1x8x[8]xf16> to vector<1x1x[8]xf16> + // CHECK: %[[RET:.+]] = vector.broadcast %[[EXTRACT]] : vector<1x[8]xf16> to vector<1x1x[8]xf16> + // CHECK: return %[[RET]] + return %0: vector<1x1x[8]xf16> +} + // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16>, %arg1: vector<1x8x8xf16>) -> vector<1x8x8xf16> { // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<8xf16> from vector<1x8xf16> @@ -217,6 +227,17 @@ func.func @cast_away_insert_strided_slice_leading_one_dims(%arg0: vector<1x8xf16 return %0: vector<1x8x8xf16> } +// CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_scalable +func.func @cast_away_insert_strided_slice_leading_one_dims_scalable(%arg0: vector<1x[8]xf16>, %arg1: vector<1x8x[8]xf16>) -> vector<1x8x[8]xf16> { + // CHECK: %[[SRC:.+]] = vector.extract %{{.*}}[0] : vector<[8]xf16> from vector<1x[8]xf16> + // CHECK: %[[DST:.+]] = vector.extract %{{.*}}[0] : vector<8x[8]xf16> from vector<1x8x[8]xf16> + // CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[SRC]], %[[DST]] {offsets = [0, 0], strides = [1]} : vector<[8]xf16> into vector<8x[8]xf16> + %0 = vector.insert_strided_slice %arg0, %arg1 {offsets = [0, 0, 0], strides = [1, 1]} : vector<1x[8]xf16> into vector<1x8x[8]xf16> + // CHECK: %[[RET:.+]] = vector.broadcast %[[INSERT]] : vector<8x[8]xf16> to vector<1x8x[8]xf16> + // CHECK: return %[[RET]] + return %0: vector<1x8x[8]xf16> +} + // CHECK-LABEL: func @cast_away_insert_strided_slice_leading_one_dims_one_element // CHECK-SAME: %[[ARG0:.+]]: vector<1x1xf16>, %{{.+}}: vector<1x1x1xf16> func.func @cast_away_insert_strided_slice_leading_one_dims_one_element(%arg0: vector<1x1xf16>, %arg1: vector<1x1x1xf16>) -> vector<1x1x1xf16> {