diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h index 08c5d5ddbc82f9..570719eff64d56 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -29,8 +29,7 @@ class PatternRewriter; namespace shape { /// Alias type for extent tensors. -RankedTensorType getExtentTensorType(MLIRContext *ctx, - int64_t rank = ShapedType::kDynamicSize); +RankedTensorType getExtentTensorType(MLIRContext *ctx); // Check if a type is an extent tensor, e.g., tensor. bool isExtentTensorType(Type); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index ac67a62a0aef5f..fd012aa84d1c8f 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -27,8 +27,8 @@ namespace { #include "ShapeCanonicalization.inc" } -RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) { - return RankedTensorType::get({rank}, IndexType::get(ctx)); +RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) { + return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); } bool shape::isExtentTensorType(Type type) { @@ -660,42 +660,11 @@ struct CanonicalizeCastExtentTensorOperandsPattern return success(); } }; - -struct BroadcastConcretizeResultTypePattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BroadcastOp op, - PatternRewriter &rewriter) const override { - // Only concretize dynamic extent tensor result types. - auto resultTy = op.getType().dyn_cast(); - if (!resultTy || !resultTy.isDynamicDim(0)) - return failure(); - - // Infer resulting shape rank if possible. - int64_t maxRank = 0; - for (Value shape : op.shapes()) { - if (auto extentTensorTy = shape.getType().dyn_cast()) { - // Cannot infer resulting shape rank if any operand is dynamically - // ranked. - if (extentTensorTy.isDynamicDim(0)) - return failure(); - maxRank = std::max(maxRank, extentTensorTy.getDimSize(0)); - } - } - - auto newOp = rewriter.create( - op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes()); - rewriter.replaceOpWithNewOp(op, op.getType(), newOp); - return success(); - } -}; } // namespace void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { - patterns.add, RemoveDuplicateOperandsPattern, diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 6e02438391320f..367ce7f6ba1ac3 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1344,8 +1344,7 @@ func @cast_extent_tensor_operands(%arg0 : tensor, %arg1 : tensor<3xindex>) -> (!shape.witness, tensor) { // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor to tensor<3xindex> // CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> - // CHECK: %[[UNCAST_RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex> - // CHECK: %[[RES:.*]] = tensor.cast %[[UNCAST_RES]] : tensor<3xindex> to tensor + // CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> // CHECK: return %[[WIT]], %[[RES]] %0 = tensor.cast %arg0 : tensor to tensor<3xindex> %1 = tensor.cast %arg1 : tensor<3xindex> to tensor @@ -1354,17 +1353,3 @@ func @cast_extent_tensor_operands(%arg0 : tensor, -> tensor return %2, %3 : !shape.witness, tensor } - -// ----- - -// CHECK-LABEL: @concretize_broadcast_result_type -// CHECK-SAME: (%[[ARG0:.*]]: tensor<2xindex>, %[[ARG1:.*]]: tensor<3xindex>) -func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>, - %arg1 : tensor<3xindex>) -> tensor { - // CHECK: %[[CONCR:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex> - // CHECK: %[[RES:.*]] = tensor.cast %[[CONCR]] : tensor<3xindex> to tensor - // CHECK: return %[[RES]] - %0 = shape.broadcast %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> - -> tensor - return %0 : tensor -}