Skip to content

Commit 0155a11

Browse files
Michael Levesque-Diontensorflower-gardener
authored andcommitted
Generate mhlo.dynamic_reshape instead of chlo.dynamic_reshape for squeeze
The extra logic that chlo.dynamic_reshape implements is not needed for squeeze. PiperOrigin-RevId: 630150307
1 parent ce76d42 commit 0155a11

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2680,7 +2680,7 @@ func.func @squeeze_ranked(%arg0: tensor<?x?x?xf32>) -> tensor<?xf32> {
26802680
// CHECK: %[[C2:.*]] = arith.constant 2 : index
26812681
// CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x?xf32>
26822682
// CHECK: %[[T:.*]] = tensor.from_elements %[[D2]] : tensor<1xindex>
2683-
// CHECK: %[[R:.*]] = "chlo.dynamic_reshape"(%arg0, %[[T]]) : (tensor<?x?x?xf32>, tensor<1xindex>) -> tensor<?xf32>
2683+
// CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor<?x?x?xf32>, tensor<1xindex>) -> tensor<?xf32>
26842684
// CHECK: return %[[R]] : tensor<?xf32>
26852685
%0 = "tf.Squeeze"(%arg0) { squeeze_dims = [0, 1] }: (tensor<?x?x?xf32>) -> tensor<?xf32>
26862686
func.return %0 : tensor<?xf32>
@@ -2695,7 +2695,7 @@ func.func @squeeze_ranked_negative(%arg0: tensor<?x?x10xf32>) -> tensor<?x10xf32
26952695
// CHECK: %[[C2:.*]] = arith.constant 2 : index
26962696
// CHECK: %[[D2:.*]] = tensor.dim %arg0, %[[C2]] : tensor<?x?x10xf32>
26972697
// CHECK: %[[T:.*]] = tensor.from_elements %[[D0]], %[[D2]] : tensor<2xindex>
2698-
// CHECK: %[[R:.*]] = "chlo.dynamic_reshape"(%arg0, %[[T]]) : (tensor<?x?x10xf32>, tensor<2xindex>) -> tensor<?x10xf32>
2698+
// CHECK: %[[R:.*]] = mhlo.dynamic_reshape %arg0, %[[T]] : (tensor<?x?x10xf32>, tensor<2xindex>) -> tensor<?x10xf32>
26992699
// CHECK: return %[[R]] : tensor<?x10xf32>
27002700
%0 = "tf.Squeeze"(%arg0) { squeeze_dims = [-2] }: (tensor<?x?x10xf32>) -> tensor<?x10xf32>
27012701
func.return %0 : tensor<?x10xf32>

tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6457,9 +6457,7 @@ class ConvertDynamicSqueezeOp : public OpRewritePattern<TF::SqueezeOp> {
64576457

64586458
auto from_extents =
64596459
rewriter.create<tensor::FromElementsOp>(op.getLoc(), dims);
6460-
// chlo::DynamicReshapeOp checks if the reshape is legal and will fail if
6461-
// any non-1 dimension is squeezed.
6462-
rewriter.replaceOpWithNewOp<chlo::DynamicReshapeOp>(op, result_ty, input,
6460+
rewriter.replaceOpWithNewOp<mhlo::DynamicReshapeOp>(op, result_ty, input,
64636461
from_extents);
64646462
return success();
64656463
}

0 commit comments

Comments
 (0)