diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 603e86ca3d766..d1b73ff2dbd0c 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -433,18 +433,23 @@ struct ChainedTensorCast : public OpRewritePattern { // We can remove the intermediate cast if joining all three produces the // same result as just joining the source and result shapes. - auto firstJoin = - joinShapes(joinShapes(sourceType, intermediateType), resultType); + auto firstJoin = joinShapes(sourceType, intermediateType); // The join might not exist if the cast sequence would fail at runtime. if (!firstJoin) return failure(); + auto secondJoin = joinShapes(firstJoin, resultType); + + // The join might not exist if the cast sequence would fail at runtime. + if (!secondJoin) + return failure(); + // The newJoin always exists if the above join exists, it might just contain // less information. If so, we cannot drop the intermediate cast, as doing // so would remove runtime checks. auto newJoin = joinShapes(sourceType, resultType); - if (firstJoin != newJoin) + if (secondJoin != newJoin) return failure(); rewriter.replaceOpWithNewOp(tensorCast, resultType,