Skip to content

Commit

Permalink
Revert "[MLIR][Shape] Concretize broadcast result type if possible"
Browse files Browse the repository at this point in the history
This reverts commit dca5361.
  • Loading branch information
frgossen committed Apr 28, 2021
1 parent b87219f commit 511ffe1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 52 deletions.
3 changes: 1 addition & 2 deletions mlir/include/mlir/Dialect/Shape/IR/Shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ class PatternRewriter;
namespace shape {

/// Alias type for extent tensors.
RankedTensorType getExtentTensorType(MLIRContext *ctx,
int64_t rank = ShapedType::kDynamicSize);
RankedTensorType getExtentTensorType(MLIRContext *ctx);

// Check if a type is an extent tensor, e.g., tensor<?xindex>.
bool isExtentTensorType(Type);
Expand Down
37 changes: 3 additions & 34 deletions mlir/lib/Dialect/Shape/IR/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ namespace {
#include "ShapeCanonicalization.inc"
}

RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
return RankedTensorType::get({rank}, IndexType::get(ctx));
RankedTensorType shape::getExtentTensorType(MLIRContext *ctx) {
return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx));
}

bool shape::isExtentTensorType(Type type) {
Expand Down Expand Up @@ -660,42 +660,11 @@ struct CanonicalizeCastExtentTensorOperandsPattern
return success();
}
};

struct BroadcastConcretizeResultTypePattern
: public OpRewritePattern<BroadcastOp> {
using OpRewritePattern<BroadcastOp>::OpRewritePattern;

LogicalResult matchAndRewrite(BroadcastOp op,
PatternRewriter &rewriter) const override {
// Only concretize dynamic extent tensor result types.
auto resultTy = op.getType().dyn_cast<RankedTensorType>();
if (!resultTy || !resultTy.isDynamicDim(0))
return failure();

// Infer resulting shape rank if possible.
int64_t maxRank = 0;
for (Value shape : op.shapes()) {
if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
// Cannot infer resulting shape rank if any operand is dynamically
// ranked.
if (extentTensorTy.isDynamicDim(0))
return failure();
maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
}
}

auto newOp = rewriter.create<BroadcastOp>(
op.getLoc(), getExtentTensorType(getContext(), maxRank), op.shapes());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
return success();
}
};
} // namespace

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add<BroadcastConcretizeResultTypePattern,
BroadcastFoldConstantOperandsPattern,
patterns.add<BroadcastFoldConstantOperandsPattern,
BroadcastForwardSingleOperandPattern,
CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
RemoveDuplicateOperandsPattern<BroadcastOp>,
Expand Down
17 changes: 1 addition & 16 deletions mlir/test/Dialect/Shape/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1344,8 +1344,7 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
%arg1 : tensor<3xindex>) -> (!shape.witness, tensor<?xindex>) {
// CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?xindex> to tensor<3xindex>
// CHECK: %[[WIT:.*]] = shape.cstr_broadcastable %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
// CHECK: %[[UNCAST_RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex> -> tensor<3xindex>
// CHECK: %[[RES:.*]] = tensor.cast %[[UNCAST_RES]] : tensor<3xindex> to tensor<?xindex>
// CHECK: %[[RES:.*]] = shape.broadcast %[[CAST_ARG0]], %[[ARG1]] : tensor<3xindex>, tensor<3xindex>
// CHECK: return %[[WIT]], %[[RES]]
%0 = tensor.cast %arg0 : tensor<?xindex> to tensor<3xindex>
%1 = tensor.cast %arg1 : tensor<3xindex> to tensor<?xindex>
Expand All @@ -1354,17 +1353,3 @@ func @cast_extent_tensor_operands(%arg0 : tensor<?xindex>,
-> tensor<?xindex>
return %2, %3 : !shape.witness, tensor<?xindex>
}

// -----

// CHECK-LABEL: @concretize_broadcast_result_type
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2xindex>, %[[ARG1:.*]]: tensor<3xindex>)
func @concretize_broadcast_result_type(%arg0 : tensor<2xindex>,
%arg1 : tensor<3xindex>) -> tensor<?xindex> {
// CHECK: %[[CONCR:.*]] = shape.broadcast %[[ARG0]], %[[ARG1]] : tensor<2xindex>, tensor<3xindex> -> tensor<3xindex>
// CHECK: %[[RES:.*]] = tensor.cast %[[CONCR]] : tensor<3xindex> to tensor<?xindex>
// CHECK: return %[[RES]]
%0 = shape.broadcast %arg0, %arg1 : tensor<2xindex>, tensor<3xindex>
-> tensor<?xindex>
return %0 : tensor<?xindex>
}

0 comments on commit 511ffe1

Please sign in to comment.