Skip to content

Add fetchVariable method to Session to get value of resource variable #261

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
package org.tensorflow;

import static org.tensorflow.Graph.resolveOutputs;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrType;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun;
import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig;

Expand All @@ -38,8 +36,12 @@
import org.tensorflow.internal.c_api.TF_SessionOptions;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.TF_Tensor;
import org.tensorflow.internal.types.registry.TensorTypeRegistry;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.ReadVariableOp;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.DataType;
import org.tensorflow.proto.framework.RunMetadata;
import org.tensorflow.proto.framework.RunOptions;
import org.tensorflow.proto.util.SaverDef;
Expand Down Expand Up @@ -192,6 +194,11 @@ public Runner feed(String operation, int index, Tensor t) {
* @return this session runner
*/
public Runner feed(Operand<?> operand, Tensor t) {
if (operand.env() != graph) {
throw new IllegalStateException("Can't feed value for operand " + operand + ", it is from " +
(operand.env().isEager() ? "an eager session" : "a different graph") + ".");
}

inputs.add(operand.asOutput());
inputTensors.add(t);
return this;
Expand All @@ -200,6 +207,8 @@ public Runner feed(Operand<?> operand, Tensor t) {
/**
* Make {@link #run()} return the output of {@code operation}.
*
* If the output is a resource variable, will fetch the value.
*
* @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code
* fetch(operation, 0)}, or it is a string of the form
* <tt>operation_name:output_index</tt> , in which case this method acts like {@code
Expand All @@ -215,6 +224,8 @@ public Runner fetch(String operation) {
/**
* Make {@link #run()} return the {@code index}-th output of {@code operation}.
*
* If the output is a resource variable, will fetch the value.
*
* <p>Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which
* one to return.
*
Expand All @@ -225,24 +236,61 @@ public Runner fetch(String operation) {
*/
public Runner fetch(String operation, int index) {
Operation op = graph.operationOrThrow(operation);
outputs.add(op.output(index));
return this;
return fetch(op.output(index));
}

/**
* Makes {@link #run()} return the Tensor referred to by {@code output}.
*
* If {@code output} is a resource variable, will fetch the value.
*
* @param output the node to fetch the tensor from
* @return this session runner
*/
public Runner fetch(Output<?> output) {
outputs.add(output);
if (output.env() != graph) {
throw new IllegalStateException("Can't fetch output " + output + ", it is from " +
(output.env().isEager() ? "an eager session" : "a different graph") + ".");
}

if (output.dataType() == DataType.DT_RESOURCE) {
int[] rawDt = new int[1];

GraphOperation graphOp = (GraphOperation) output.op();

try (PointerScope scope = new PointerScope()) {
TF_Status status = TF_Status.newStatus();
TF_OperationGetAttrType(graphOp.getUnsafeNativeHandle(), "dtype", rawDt, status);
status.throwExceptionIfNotOK();
}

DataType valueDt = DataType.forNumber(rawDt[0]);

Operand<?> read = null;
for (GraphOperation op : graphOp.consumers()) {
if (op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)) {
read = op.output(0);
break;
}
}

if (read == null) {
read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read")
.readVariableOp(output, TensorTypeRegistry.find(valueDt).type());
}

outputs.add(read.asOutput());
} else {
outputs.add(output);
}
return this;
}

/**
* Makes {@link #run()} return the Tensor referred to by the output of {@code operand}.
*
* If {@code operand} is a resource variable, will fetch the value.
*
* @param operand the node to fetch the tensor from, as an operand
* @return this session runner
*/
Expand All @@ -258,9 +306,7 @@ public Runner fetch(Operand<?> operand) {
* @throws IllegalArgumentException if no operation exists with the provided name
*/
public Runner addTarget(String operation) {
GraphOperation op = graph.operationOrThrow(operation);
targets.add(op);
return this;
return addTarget(graph.operationOrThrow(operation));
}

/**
Expand All @@ -269,13 +315,12 @@ public Runner addTarget(String operation) {
* @param operation the operation to execute
* @return this session runner
* @throws IllegalArgumentException if the operation is not a {@link GraphOperation}
* @throws IllegalStateException if the operation is not from the session's graph.
*/
public Runner addTarget(Operation operation) {
if (!(operation instanceof GraphOperation)) {
throw new IllegalArgumentException(
"Operation of type "
+ operation.getClass().getName()
+ " is not supported in graph sessions");
if (operation.env() != graph) {
throw new IllegalStateException("Can't target operation " + operation + ", it is from " +
(operation.env().isEager() ? "an eager session" : "a different graph") + ".");
}
targets.add((GraphOperation) operation);
return this;
Expand Down Expand Up @@ -594,12 +639,12 @@ private static void delete(TF_Session handle) {
*
* @param handle to the C API TF_Session object (Session.nativeHandle)
* @param runOptions A RunOptions protocol buffer, or null
* @param inputOpHandles (see inputOpIndices)
* @param inputOpIndices (see inputTensorHandles)
* @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed"
* (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a
* Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus,
* it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length.
* @param inputOpHandles (see inputOpIndices)
* @param inputOpIndices (see inputTensorHandles)
* @param outputOpHandles (see outputOpIndices)
* @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The
* outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length ==
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.Map;
import java.util.HashMap;
import java.util.Map;
import org.junit.jupiter.api.Test;
import org.tensorflow.exceptions.TensorFlowException;
import org.tensorflow.ndarray.FloatNdArray;
Expand Down Expand Up @@ -292,21 +292,29 @@ public void pythonTfFunction() {
ConcreteFunction add = bundle.function("add");
Map<String, Tensor> args = new HashMap();
try (TFloat32 a = TFloat32.scalarOf(10.0f);
TFloat32 b = TFloat32.scalarOf(15.5f)) {
TFloat32 b = TFloat32.scalarOf(15.5f)) {
args.put("a", a);
args.put("b", b);
Map<String, Tensor> result = add.call(args);
assertEquals(result.size(), 1);
try (TFloat32 c = (TFloat32)result.values().iterator().next()) {
try (TFloat32 c = (TFloat32) result.values().iterator().next()) {
assertEquals(25.5f, c.getFloat());
}
}

// variable unwrapping happens in Session, which is used by ConcreteFunction.call
ConcreteFunction getVariable = bundle.function("get_variable");
try (TFloat32 v = (TFloat32) getVariable.call(new HashMap<>())
.get(getVariable.signature().outputNames().iterator().next())) {
assertEquals(2f, v.getFloat());
}

}
}

private static Signature buildGraphWithVariables(Ops tf, Shape xShape) {
Placeholder<TFloat32> x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape));
Variable<TFloat32> y = tf
Variable<TFloat32> y = tf.withName("variable")
.variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class));
ReduceSum<TFloat32> z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1));
Init init = tf.init();
Expand Down
Loading