Skip to content

Commit 60fb9ca

Browse files
andylytensorflower-gardener
authored andcommitted
Fix ResourceOpLiftingPass to handle call ops that return tf.ReadVariableOp results.
It is possible for a tf.StatefulPartitionedCall to return the result of a tf.ReadVariableOp (with some potential forwarding through ops like tf.Identity). As return op operands are captured prior to replacing tf.ReadVariableOp results with function args, the new function return operands may not be correct. Instead, when replacing tf.ReadVariableOp results with function args, the operands of the new return are updated. PiperOrigin-RevId: 320597502 Change-Id: I81f614e0b89670c978da376d5810ff82502f601f
1 parent 9065899 commit 60fb9ca

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

tensorflow/compiler/mlir/tensorflow/tests/resource_op_lifting.mlir

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,3 +710,29 @@ func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.res
710710
%0 = "tf._Unknown_"() : () -> tensor<*x!tf.resource<tensor<f32>>>
711711
return %0 : tensor<*x!tf.resource<tensor<f32>>>
712712
}
713+
714+
// -----
715+
716+
// Tests call op where it's result is the result of a tf.ReadVariableOp.
717+
718+
// CHECK-LABEL: func @call_with_forwarded_read_only_result
719+
// CHECK-SAME: (%[[RESOURCE_ARG0:.*]]: tensor<*x!tf.resource<tensor<f32>>>)
720+
func @call_with_forwarded_read_only_result(%arg0: tensor<*x!tf.resource<tensor<f32>>>) {
721+
// CHECK: %[[READ:.*]] = "tf.ReadVariableOp"(%[[RESOURCE_ARG0]])
722+
%0 = "tf_device.cluster"() ( {
723+
// CHECK: %[[CALL:.*]] = "tf.StatefulPartitionedCall"(%[[READ]])
724+
%1 = "tf.StatefulPartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
725+
// CHECK-NEXT: tf_device.return %[[CALL]]
726+
tf_device.return %1 : tensor<f32>
727+
}) {} : () -> tensor<f32>
728+
return
729+
}
730+
731+
func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
732+
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32>
733+
%1 = "tf.Identity"(%0) {device = ""} : (tensor<f32>) -> tensor<f32>
734+
return %1 : tensor<f32>
735+
}
736+
737+
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>) -> tensor<f32>
738+
// CHECK-NEXT: return %[[A0]]

tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -880,23 +880,34 @@ LogicalResult HandlePartitionedCallOpCallee(
880880
result->arg_data_type_and_updated_output_index[entry.getFirst()] = {
881881
entry.getSecond(), -1};
882882
}
883-
llvm::SmallVector<Value, 4> new_retvals;
884-
for (auto val : callee.front().getTerminator()->getOperands()) {
885-
// Remove resource type outputs.
886-
if (getElementTypeOrSelf(val.getType()).isa<TF::ResourceType>()) continue;
887-
new_retvals.push_back(val);
883+
llvm::SmallVector<int64_t, 4> retval_indices_to_preserve;
884+
for (auto& val : callee.front().getTerminator()->getOpOperands()) {
885+
// Store indices of results that are not resources.
886+
if (!getElementTypeOrSelf(val.get().getType()).isa<TF::ResourceType>())
887+
retval_indices_to_preserve.push_back(val.getOperandNumber());
888888
}
889+
int64_t num_retvals = retval_indices_to_preserve.size();
890+
llvm::SmallVector<Value, 4> new_retvals;
889891
// Lift resources.
890892
LiftArgRetResourcesForFunction(
891893
callee, remaining_resource_data_types, [&](int64_t index, Value value) {
892894
result->arg_data_type_and_updated_output_index[index].second =
893-
new_retvals.size();
895+
num_retvals++;
894896
new_retvals.push_back(value);
895897
});
898+
896899
auto old_return = callee.front().getTerminator();
900+
llvm::SmallVector<Value, 4> old_and_new_retvals;
901+
old_and_new_retvals.reserve(retval_indices_to_preserve.size() +
902+
new_retvals.size());
903+
for (int64_t retval_index : retval_indices_to_preserve)
904+
old_and_new_retvals.push_back(old_return->getOperand(retval_index));
905+
906+
old_and_new_retvals.append(new_retvals.begin(), new_retvals.end());
897907
// Replace old return with the new ones with update values.
898908
OpBuilder builder(old_return);
899-
auto new_return = builder.create<ReturnOp>(old_return->getLoc(), new_retvals);
909+
auto new_return =
910+
builder.create<ReturnOp>(old_return->getLoc(), old_and_new_retvals);
900911
old_return->erase();
901912
callee.setType(FunctionType::get(
902913
callee.getType().getInputs(),

0 commit comments

Comments
 (0)