From e08eb06cd444ee3810943db82671c7669a43d22e Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 29 Mar 2021 19:48:17 -0700 Subject: [PATCH 1/8] Add fetchVariable method to Session to get value of resource variable Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/RawTensor.java | 17 +++ .../src/main/java/org/tensorflow/Session.java | 111 +++++++++++++++++- .../test/java/org/tensorflow/SessionTest.java | 47 ++++++++ 3 files changed, 170 insertions(+), 5 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index c332fd7f1d1..84a819f7d1f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -154,6 +154,23 @@ static RawTensor fromHandle(TF_Tensor handle) { return t; } + /** + * Create a Tensor object from a handle to the C TF_Tensor object. + * DOES NOT SET THE TYPE INFO, should only be passed directly to a {@link OperationBuilder#setAttr(String, Tensor)}. + * Will likely NPE otherwise. + * + *

Takes ownership of the handle. + */ + static RawTensor dangerousUntypedRawTensorFromHandle(TF_Tensor handle) { + RawTensor t = new RawTensor(null, Shape.of(shape(handle))); + try (PointerScope scope = new PointerScope()) { + scope.attach(handle); + t.tensorHandle = handle; + t.tensorScope = scope.extend(); + } + return t; + } + /** * Create an eager Tensor object from a handle to the C TF_Tensor object. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 6395b1770a9..100d1488c66 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -24,7 +24,10 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; +import java.util.Map; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -39,11 +42,15 @@ import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Constant; import org.tensorflow.proto.framework.ConfigProto; +import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; /** * Driver for {@link Graph} execution. @@ -225,8 +232,7 @@ public Runner fetch(String operation) { */ public Runner fetch(String operation, int index) { Operation op = graph.operationOrThrow(operation); - outputs.add(op.output(index)); - return this; + return fetch(op.output(index)); } /** @@ -236,6 +242,9 @@ public Runner fetch(String operation, int index) { * @return this session runner */ public Runner fetch(Output output) { + if(output.dataType() == DataType.DT_RESOURCE){ + throw new IllegalArgumentException("Output " + output + " is a resource variable, fetch using fetchVariable(), not fetch()."); + } outputs.add(output); return this; } @@ -250,6 +259,63 @@ public Runner fetch(Operand operand) { return fetch(operand.asOutput()); } + /** + * Make {@link #run()} return the value of the output of {@code operation}, which should be a resource variable. + * + * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code + * fetch(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * fetch(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} + * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. + * @return this session runner + * @throws IllegalArgumentException if no output exists with the provided name + */ + public Runner fetchVariable(String operation, Class type) { + return fetchVariable(graph.outputOrThrow(operation), type); + } + + /** + * Make {@link #run()} return the value of the {@code index}-th output of {@code operation} which should be a resource variable. + * + *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which + * one to return. + * + * @param operation the string name of the operation + * @return this session runner + * @throws IllegalArgumentException if no operation exists with the provided name + * @throws IndexOutOfBoundsException if the operation has no output with the given index + */ + public Runner fetchVariable(String operation, int index, Class type) { + Operation op = graph.operationOrThrow(operation); + return fetchVariable(op.output(index), type); + } + + /** + * Makes {@link #run()} return the value of the resource variable referred to by {@code output}. + * + * @param output the node to fetch the tensor from + * @return this session runner + */ + public Runner fetchVariable(Output output, Class type) { + if(output.dataType() != DataType.DT_RESOURCE){ + throw new IllegalArgumentException("Output " + output + " is not a resource variable, fetch using fetch(), not fetchVariable()."); + } + outputs.add(output); + variableTypes.put(this.outputs.size() - 1, type); + return this; + } + + /** + * Makes {@link #run()} return the value of the resource variable referred to by the output of {@code operand}. + * + * @param operand the node to fetch the tensor from, as an operand + * @return this session runner + */ + public Runner fetchVariable(Operand operand, Class type) { + fetchVariable(operand.asOutput(), type); + return this; + } + /** * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. * @@ -375,6 +441,7 @@ private Run runHelper(boolean wantMetadata) { Session.run( nativeHandle, runOptions, + variableTypes, inputTensorHandles, inputOpHandles, inputOpIndices, @@ -426,6 +493,7 @@ public void close() { private final ArrayList inputTensors = new ArrayList<>(); private final ArrayList> outputs = new ArrayList<>(); private final ArrayList targets = new ArrayList<>(); + private final HashMap> variableTypes = new HashMap<>(); private RunOptions runOptions = null; } @@ -614,6 +682,7 @@ private static void delete(TF_Session handle) { private static RunMetadata run( TF_Session handle, RunOptions runOptions, + Map> variableTypes, TF_Tensor[] inputTensorHandles, TF_Operation[] inputOpHandles, int[] inputOpIndices, @@ -659,9 +728,41 @@ private static RunMetadata run( status); status.throwExceptionIfNotOK(); - for (int i = 0; i < noutputs; ++i) { - TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(); - outputTensors.add(RawTensor.fromHandle(h).asTypedTensor()); + Ops reader = null; + EagerSession eagerSession = null; + if(!variableTypes.isEmpty()){ + eagerSession = EagerSession.create(); + reader = Ops.create(eagerSession); + } + + try { + for (int i = 0; i < noutputs; ++i) { + TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(); + Tensor value; + if (variableTypes.containsKey(i)) { + RawTensor variable = RawTensor.dangerousUntypedRawTensorFromHandle(h); + + OperationBuilder builder = reader.scope() + .env() + .opBuilder(Constant.OP_NAME, reader.scope().makeOpName(Constant.OP_NAME)) + .setAttr("value", variable) + .setAttr("dtype", DataType.DT_RESOURCE); + + reader.scope().apply(builder); + + Operation constant = builder.build(); + + Operand read = reader.readVariableOp(constant.output(0), variableTypes.get(i)); + value = read.asTensor(); + } else { + value = RawTensor.fromHandle(h).asTypedTensor(); + } + outputTensors.add(value); + } + } finally { + if(eagerSession != null){ + eagerSession.close(); + } } try { return runMetadata != null ? RunMetadata.parseFrom(runMetadata.dataAsByteBuffer()) : null; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index d7ea381d315..9a67a486c1c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -29,6 +29,7 @@ import java.util.Comparator; import org.junit.jupiter.api.Test; +import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Split; @@ -249,6 +250,52 @@ public void saveAndRestore() throws IOException { .forEach(File::delete); } + @Test + public static void testFetchVariable(){ + try(Graph g = new Graph(); + Session s = new Session(g)){ + Ops tf = Ops.create(g); + Operand variable = tf.varHandleOp(TInt32.class, Shape.scalar()); + Op assign = tf.assignVariableOp(variable, tf.constant(2)); + + try(TInt32 value = (TInt32) s.runner().addTarget(assign).fetchVariable(variable, TInt32.class).run().get(0)){ + assertEquals(2, value.getInt()); + } + + } + } + + @Test + public static void testFetchVariableException(){ + try(Graph g = new Graph(); + Session s = new Session(g)){ + Ops tf = Ops.create(g); + Operand variable = tf.varHandleOp(TInt32.class, Shape.scalar()); + Op assign = tf.assignVariableOp(variable, tf.constant(2)); + + try(TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)){ + fail(); + } catch (IllegalStateException e){ + assertTrue(e.getMessage().contains("is a resource variable")); + } + } + } + + @Test + public static void testFetchVariableNonVariableException(){ + try(Graph g = new Graph(); + Session s = new Session(g)){ + Ops tf = Ops.create(g); + Operand constant = tf.constant(2); + + try(TInt32 value = (TInt32) s.runner().fetchVariable(constant, TInt32.class).run().get(0)){ + fail(); + } catch (IllegalStateException e){ + assertTrue(e.getMessage().contains("is not a resource variable")); + } + } + } + private static RunOptions fullTraceRunOptions() { return RunOptions.newBuilder() .setTraceLevel(RunOptions.TraceLevel.FULL_TRACE) From f8f9a2f0f19bea1a9f7f9ab5157a45e385aef532 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 29 Mar 2021 19:55:38 -0700 Subject: [PATCH 2/8] Format Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 100d1488c66..2be002cedf4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -16,16 +16,12 @@ package org.tensorflow; import static org.tensorflow.Graph.resolveOutputs; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession; -import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; import java.util.HashMap; -import java.util.HashSet; import java.util.List; import java.util.Map; import org.bytedeco.javacpp.BytePointer; @@ -242,8 +238,9 @@ public Runner fetch(String operation, int index) { * @return this session runner */ public Runner fetch(Output output) { - if(output.dataType() == DataType.DT_RESOURCE){ - throw new IllegalArgumentException("Output " + output + " is a resource variable, fetch using fetchVariable(), not fetch()."); + if (output.dataType() == DataType.DT_RESOURCE) { + throw new IllegalArgumentException( + "Output " + output + " is a resource variable, fetch using fetchVariable(), not fetch()."); } outputs.add(output); return this; @@ -275,7 +272,8 @@ public Runner fetchVariable(String operation, Class type) { } /** - * Make {@link #run()} return the value of the {@code index}-th output of {@code operation} which should be a resource variable. + * Make {@link #run()} return the value of the {@code index}-th output of {@code operation} which should be a + * resource variable. * *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one to return. @@ -297,8 +295,9 @@ public Runner fetchVariable(String operation, int index, Class * @return this session runner */ public Runner fetchVariable(Output output, Class type) { - if(output.dataType() != DataType.DT_RESOURCE){ - throw new IllegalArgumentException("Output " + output + " is not a resource variable, fetch using fetch(), not fetchVariable()."); + if (output.dataType() != DataType.DT_RESOURCE) { + throw new IllegalArgumentException( + "Output " + output + " is not a resource variable, fetch using fetch(), not fetchVariable()."); } outputs.add(output); variableTypes.put(this.outputs.size() - 1, type); @@ -312,7 +311,7 @@ public Runner fetchVariable(Output output, Class type) { * @return this session runner */ public Runner fetchVariable(Operand operand, Class type) { - fetchVariable(operand.asOutput(), type); + fetchVariable(operand.asOutput(), type); return this; } @@ -730,7 +729,7 @@ private static RunMetadata run( Ops reader = null; EagerSession eagerSession = null; - if(!variableTypes.isEmpty()){ + if (!variableTypes.isEmpty()) { eagerSession = EagerSession.create(); reader = Ops.create(eagerSession); } @@ -760,7 +759,7 @@ private static RunMetadata run( outputTensors.add(value); } } finally { - if(eagerSession != null){ + if (eagerSession != null) { eagerSession.close(); } } From 951a2d9297f5a54c7b7db0de0e524f7839048dab Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 29 Mar 2021 19:56:25 -0700 Subject: [PATCH 3/8] More Formatting Signed-off-by: Ryan Nett --- .../test/java/org/tensorflow/SessionTest.java | 107 +++++++++--------- 1 file changed, 55 insertions(+), 52 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 9a67a486c1c..a278e876a95 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -20,15 +20,15 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; -import java.io.BufferedOutputStream; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; import java.util.Comparator; - import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; @@ -39,13 +39,12 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.GraphDef; import org.tensorflow.proto.framework.RunOptions; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.StdArrays; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.Session}. */ +/** + * Unit tests for {@link org.tensorflow.Session}. + */ public class SessionTest { @Test @@ -53,12 +52,12 @@ public void runUsingOperationNames() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + transpose_A_times_X(tf, new int[][]{{2}, {3}}); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}})); AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } } } @@ -68,14 +67,14 @@ public void runUsingOperationHandles() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); + transpose_A_times_X(tf, new int[][]{{2}, {3}}); Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}})); AutoCloseableList outputs = new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } } } @@ -89,18 +88,18 @@ public void runUsingColonSeparatedNames() { tf.math.add(split.output().get(0), split.output().get(1)); // Fetch using colon separated names. - try (TInt32 fetched = (TInt32)s.runner().fetch("Split:1").run().get(0)) { + try (TInt32 fetched = (TInt32) s.runner().fetch("Split:1").run().get(0)) { assertEquals(3, fetched.getInt(0)); assertEquals(4, fetched.getInt(1)); } // Feed using colon separated names. try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1); TInt32 fetched = (TInt32) s.runner() - .feed("Split:0", fed) - .feed("Split:1", fed) - .fetch("Add") - .run() - .get(0)) { + .feed("Split:0", fed) + .feed("Split:1", fed) + .fetch("Add") + .run() + .get(0)) { assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched); } } @@ -111,17 +110,17 @@ public void runWithMetadata() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { + transpose_A_times_X(tf, new int[][]{{2}, {3}}); + try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][]{{5}, {7}}))) { Session.Run result = s.runner() - .feed("X", x) - .fetch("Y") - .setOptions(fullTraceRunOptions()) - .runAndFetchMetadata(); + .feed("X", x) + .fetch("Y") + .setOptions(fullTraceRunOptions()) + .runAndFetchMetadata(); // Sanity check on outputs. AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); // Sanity check on metadata assertNotNull(result.metadata); assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); @@ -140,8 +139,8 @@ public void runMultipleOutputs() { AutoCloseableList outputs = new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); assertEquals(2, outputs.size()); - assertEquals(31415, ((TInt32)outputs.get(0)).getInt()); - assertEquals(2718, ((TInt32)outputs.get(1)).getInt()); + assertEquals(31415, ((TInt32) outputs.get(0)).getInt()); + assertEquals(2718, ((TInt32) outputs.get(1)).getInt()); outputs.close(); } } @@ -163,7 +162,8 @@ public void failOnUseAfterClose() { @Test public void createWithConfigProto() { try (Graph g = new Graph(); - Session s = new Session(g, singleThreadConfigProto())) {} + Session s = new Session(g, singleThreadConfigProto())) { + } } @Test @@ -214,12 +214,14 @@ public void runInitByName() { } @Test - public void saveAndRestore() throws IOException { + public void saveAndRestore() throws IOException { Path testFolder = Files.createTempDirectory("tf-session-save-restore-test"); try (Graph g = new Graph()) { Ops tf = Ops.create(g); - Variable x = tf.withName("x").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); - Variable y = tf.withName("y").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable x = tf.withName("x") + .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); + Variable y = tf.withName("y") + .variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); Init init = tf.init(); try (Session s = new Session(g)) { @@ -232,9 +234,10 @@ public void saveAndRestore() throws IOException { try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); try (AutoCloseableList oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); - AutoCloseableList newList = new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())){ - assertEquals(oldList.get(0),newList.get(0)); - assertEquals(oldList.get(1),newList.get(1)); + AutoCloseableList newList = new AutoCloseableList<>( + restoredSession.runner().fetch("x").fetch("y").run())) { + assertEquals(oldList.get(0), newList.get(0)); + assertEquals(oldList.get(1), newList.get(1)); } } } @@ -245,20 +248,20 @@ public void saveAndRestore() throws IOException { // Cleanup test dir Files.walk(testFolder) - .sorted(Comparator.reverseOrder()) - .map(Path::toFile) - .forEach(File::delete); + .sorted(Comparator.reverseOrder()) + .map(Path::toFile) + .forEach(File::delete); } @Test - public static void testFetchVariable(){ - try(Graph g = new Graph(); - Session s = new Session(g)){ + public static void testFetchVariable() { + try (Graph g = new Graph(); + Session s = new Session(g)) { Ops tf = Ops.create(g); Operand variable = tf.varHandleOp(TInt32.class, Shape.scalar()); Op assign = tf.assignVariableOp(variable, tf.constant(2)); - try(TInt32 value = (TInt32) s.runner().addTarget(assign).fetchVariable(variable, TInt32.class).run().get(0)){ + try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetchVariable(variable, TInt32.class).run().get(0)) { assertEquals(2, value.getInt()); } @@ -266,31 +269,31 @@ public static void testFetchVariable(){ } @Test - public static void testFetchVariableException(){ - try(Graph g = new Graph(); - Session s = new Session(g)){ + public static void testFetchVariableException() { + try (Graph g = new Graph(); + Session s = new Session(g)) { Ops tf = Ops.create(g); Operand variable = tf.varHandleOp(TInt32.class, Shape.scalar()); Op assign = tf.assignVariableOp(variable, tf.constant(2)); - try(TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)){ + try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) { fail(); - } catch (IllegalStateException e){ + } catch (IllegalStateException e) { assertTrue(e.getMessage().contains("is a resource variable")); } } } @Test - public static void testFetchVariableNonVariableException(){ - try(Graph g = new Graph(); - Session s = new Session(g)){ + public static void testFetchVariableNonVariableException() { + try (Graph g = new Graph(); + Session s = new Session(g)) { Ops tf = Ops.create(g); Operand constant = tf.constant(2); - try(TInt32 value = (TInt32) s.runner().fetchVariable(constant, TInt32.class).run().get(0)){ + try (TInt32 value = (TInt32) s.runner().fetchVariable(constant, TInt32.class).run().get(0)) { fail(); - } catch (IllegalStateException e){ + } catch (IllegalStateException e) { assertTrue(e.getMessage().contains("is not a resource variable")); } } From 6d9317f8e77aacd1b9b98540dfdff66298f90dd3 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 30 Mar 2021 16:25:00 -0700 Subject: [PATCH 4/8] Rework, automatically wrap variables in read when fetched Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 166 ++++++------------ .../test/java/org/tensorflow/SessionTest.java | 39 ++-- 2 files changed, 74 insertions(+), 131 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 2be002cedf4..9e94767d6df 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -16,14 +16,13 @@ package org.tensorflow; import static org.tensorflow.Graph.resolveOutputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_OperationGetAttrType; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun; import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -37,16 +36,16 @@ import org.tensorflow.internal.c_api.TF_SessionOptions; import org.tensorflow.internal.c_api.TF_Status; import org.tensorflow.internal.c_api.TF_Tensor; +import org.tensorflow.internal.types.registry.TensorTypeRegistry; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.ReadVariableOp; import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.DataType; import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; -import org.tensorflow.types.family.TType; /** * Driver for {@link Graph} execution. @@ -195,6 +194,10 @@ public Runner feed(String operation, int index, Tensor t) { * @return this session runner */ public Runner feed(Operand operand, Tensor t) { + if(operand.env() != graph){ + throw new IllegalStateException("Can't feed value to operand " + operand + ", it is from a different graph."); + } + inputs.add(operand.asOutput()); inputTensors.add(t); return this; @@ -203,6 +206,8 @@ public Runner feed(Operand operand, Tensor t) { /** * Make {@link #run()} return the output of {@code operation}. * + * If the output is a resource variable, will fetch the value. + * * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code * fetch(operation, 0)}, or it is a string of the form * operation_name:output_index , in which case this method acts like {@code @@ -218,6 +223,8 @@ public Runner fetch(String operation) { /** * Make {@link #run()} return the {@code index}-th output of {@code operation}. * + * If the output is a resource variable, will fetch the value. + * *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one to return. * @@ -234,85 +241,57 @@ public Runner fetch(String operation, int index) { /** * Makes {@link #run()} return the Tensor referred to by {@code output}. * + * If {@code output} is a resource variable, will fetch the value. + * * @param output the node to fetch the tensor from * @return this session runner */ public Runner fetch(Output output) { - if (output.dataType() == DataType.DT_RESOURCE) { - throw new IllegalArgumentException( - "Output " + output + " is a resource variable, fetch using fetchVariable(), not fetch()."); + if(output.env() != graph){ + throw new IllegalStateException("Can't fetch output " + output + ", it is from a different graph."); } - outputs.add(output); - return this; - } - /** - * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. - * - * @param operand the node to fetch the tensor from, as an operand - * @return this session runner - */ - public Runner fetch(Operand operand) { - return fetch(operand.asOutput()); - } + if (output.dataType() == DataType.DT_RESOURCE) { + int[] rawDt = new int[1]; - /** - * Make {@link #run()} return the value of the output of {@code operation}, which should be a resource variable. - * - * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code - * fetch(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * fetch(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} - * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. - * @return this session runner - * @throws IllegalArgumentException if no output exists with the provided name - */ - public Runner fetchVariable(String operation, Class type) { - return fetchVariable(graph.outputOrThrow(operation), type); - } + GraphOperation graphOp = (GraphOperation) output.op(); - /** - * Make {@link #run()} return the value of the {@code index}-th output of {@code operation} which should be a - * resource variable. - * - *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which - * one to return. - * - * @param operation the string name of the operation - * @return this session runner - * @throws IllegalArgumentException if no operation exists with the provided name - * @throws IndexOutOfBoundsException if the operation has no output with the given index - */ - public Runner fetchVariable(String operation, int index, Class type) { - Operation op = graph.operationOrThrow(operation); - return fetchVariable(op.output(index), type); - } + try(PointerScope scope = new PointerScope()) { + TF_Status status = TF_Status.newStatus(); + TF_OperationGetAttrType(graphOp.getUnsafeNativeHandle(), "dtype", rawDt, status); + status.throwExceptionIfNotOK(); + } - /** - * Makes {@link #run()} return the value of the resource variable referred to by {@code output}. - * - * @param output the node to fetch the tensor from - * @return this session runner - */ - public Runner fetchVariable(Output output, Class type) { - if (output.dataType() != DataType.DT_RESOURCE) { - throw new IllegalArgumentException( - "Output " + output + " is not a resource variable, fetch using fetch(), not fetchVariable()."); + DataType valueDt = DataType.forNumber(rawDt[0]); + + Operand read = null; + for(GraphOperation op : graphOp.consumers()){ + if(op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)){ + read = op.output(0); + } + } + + if(read == null){ + read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read").readVariableOp(output, TensorTypeRegistry.find(valueDt).type()); + } + + outputs.add(read.asOutput()); + } else { + outputs.add(output); } - outputs.add(output); - variableTypes.put(this.outputs.size() - 1, type); return this; } /** - * Makes {@link #run()} return the value of the resource variable referred to by the output of {@code operand}. + * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. + * + * If {@code operand} is a resource variable, will fetch the value. * * @param operand the node to fetch the tensor from, as an operand * @return this session runner */ - public Runner fetchVariable(Operand operand, Class type) { - fetchVariable(operand.asOutput(), type); - return this; + public Runner fetch(Operand operand) { + return fetch(operand.asOutput()); } /** @@ -323,9 +302,7 @@ public Runner fetchVariable(Operand operand, Class type) { * @throws IllegalArgumentException if no operation exists with the provided name */ public Runner addTarget(String operation) { - GraphOperation op = graph.operationOrThrow(operation); - targets.add(op); - return this; + return addTarget(graph.operationOrThrow(operation)); } /** @@ -334,13 +311,11 @@ public Runner addTarget(String operation) { * @param operation the operation to execute * @return this session runner * @throws IllegalArgumentException if the operation is not a {@link GraphOperation} + * @throws IllegalStateException if the operation is not from the session's graph. */ public Runner addTarget(Operation operation) { - if (!(operation instanceof GraphOperation)) { - throw new IllegalArgumentException( - "Operation of type " - + operation.getClass().getName() - + " is not supported in graph sessions"); + if(operation.env() != graph){ + throw new IllegalStateException("Can't fetch operation " + operation + ", it is from a different graph."); } targets.add((GraphOperation) operation); return this; @@ -440,7 +415,6 @@ private Run runHelper(boolean wantMetadata) { Session.run( nativeHandle, runOptions, - variableTypes, inputTensorHandles, inputOpHandles, inputOpIndices, @@ -492,7 +466,6 @@ public void close() { private final ArrayList inputTensors = new ArrayList<>(); private final ArrayList> outputs = new ArrayList<>(); private final ArrayList targets = new ArrayList<>(); - private final HashMap> variableTypes = new HashMap<>(); private RunOptions runOptions = null; } @@ -661,12 +634,12 @@ private static void delete(TF_Session handle) { * * @param handle to the C API TF_Session object (Session.nativeHandle) * @param runOptions A RunOptions protocol buffer, or null - * @param inputOpHandles (see inputOpIndices) - * @param inputOpIndices (see inputTensorHandles) * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed" * (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a * Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, * it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. + * @param inputOpHandles (see inputOpIndices) + * @param inputOpIndices (see inputTensorHandles) * @param outputOpHandles (see outputOpIndices) * @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The * outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length == @@ -681,7 +654,6 @@ private static void delete(TF_Session handle) { private static RunMetadata run( TF_Session handle, RunOptions runOptions, - Map> variableTypes, TF_Tensor[] inputTensorHandles, TF_Operation[] inputOpHandles, int[] inputOpIndices, @@ -727,41 +699,9 @@ private static RunMetadata run( status); status.throwExceptionIfNotOK(); - Ops reader = null; - EagerSession eagerSession = null; - if (!variableTypes.isEmpty()) { - eagerSession = EagerSession.create(); - reader = Ops.create(eagerSession); - } - - try { - for (int i = 0; i < noutputs; ++i) { - TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(); - Tensor value; - if (variableTypes.containsKey(i)) { - RawTensor variable = RawTensor.dangerousUntypedRawTensorFromHandle(h); - - OperationBuilder builder = reader.scope() - .env() - .opBuilder(Constant.OP_NAME, reader.scope().makeOpName(Constant.OP_NAME)) - .setAttr("value", variable) - .setAttr("dtype", DataType.DT_RESOURCE); - - reader.scope().apply(builder); - - Operation constant = builder.build(); - - Operand read = reader.readVariableOp(constant.output(0), variableTypes.get(i)); - value = read.asTensor(); - } else { - value = RawTensor.fromHandle(h).asTypedTensor(); - } - outputTensors.add(value); - } - } finally { - if (eagerSession != null) { - eagerSession.close(); - } + for (int i = 0; i < noutputs; ++i) { + TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(); + outputTensors.add(RawTensor.fromHandle(h).asTypedTensor()); } try { return runMetadata != null ? RunMetadata.parseFrom(runMetadata.dataAsByteBuffer()) : null; diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index a278e876a95..d6d4e6d419d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -15,6 +15,7 @@ package org.tensorflow; +import static java.util.Collections.singletonList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -25,6 +26,9 @@ import java.nio.file.Files; import java.nio.file.Path; import java.util.Comparator; +import java.util.Iterator; +import java.util.Spliterators; +import java.util.stream.StreamSupport; import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -261,41 +265,40 @@ public static void testFetchVariable() { Operand variable = tf.varHandleOp(TInt32.class, Shape.scalar()); Op assign = tf.assignVariableOp(variable, tf.constant(2)); - try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetchVariable(variable, TInt32.class).run().get(0)) { + try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) { assertEquals(2, value.getInt()); } } } + private static int numOperations(Graph g){ + int numOperations = 0; + for (Iterator it = g.operations(); it.hasNext(); ) { + Operation o = it.next(); + numOperations++; + } + return numOperations; + } + @Test - public static void testFetchVariableException() { + public static void testFetchVariableReusingRead() { try (Graph g = new Graph(); Session s = new Session(g)) { Ops tf = Ops.create(g); Operand variable = tf.varHandleOp(TInt32.class, Shape.scalar()); Op assign = tf.assignVariableOp(variable, tf.constant(2)); + Operand read = tf.readVariableOp(variable, TInt32.class); + + int ops = numOperations(g); + try (TInt32 value = (TInt32) s.runner().addTarget(assign).fetch(variable).run().get(0)) { - fail(); - } catch (IllegalStateException e) { - assertTrue(e.getMessage().contains("is a resource variable")); + assertEquals(2, value.getInt()); } - } - } - @Test - public static void testFetchVariableNonVariableException() { - try (Graph g = new Graph(); - Session s = new Session(g)) { - Ops tf = Ops.create(g); - Operand constant = tf.constant(2); + assertEquals(0, numOperations(g) - ops); - try (TInt32 value = (TInt32) s.runner().fetchVariable(constant, TInt32.class).run().get(0)) { - fail(); - } catch (IllegalStateException e) { - assertTrue(e.getMessage().contains("is not a resource variable")); - } } } From e5d85125ffa1d9680b57e434b7d1ac2a6e47c60a Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 30 Mar 2021 16:25:58 -0700 Subject: [PATCH 5/8] Forgot to format Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 17 +++++++++-------- .../test/java/org/tensorflow/SessionTest.java | 5 +---- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 9e94767d6df..465dd8238d9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -194,7 +194,7 @@ public Runner feed(String operation, int index, Tensor t) { * @return this session runner */ public Runner feed(Operand operand, Tensor t) { - if(operand.env() != graph){ + if (operand.env() != graph) { throw new IllegalStateException("Can't feed value to operand " + operand + ", it is from a different graph."); } @@ -247,7 +247,7 @@ public Runner fetch(String operation, int index) { * @return this session runner */ public Runner fetch(Output output) { - if(output.env() != graph){ + if (output.env() != graph) { throw new IllegalStateException("Can't fetch output " + output + ", it is from a different graph."); } @@ -256,7 +256,7 @@ public Runner fetch(Output output) { GraphOperation graphOp = (GraphOperation) output.op(); - try(PointerScope scope = new PointerScope()) { + try (PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); TF_OperationGetAttrType(graphOp.getUnsafeNativeHandle(), "dtype", rawDt, status); status.throwExceptionIfNotOK(); @@ -265,14 +265,15 @@ public Runner fetch(Output output) { DataType valueDt = DataType.forNumber(rawDt[0]); Operand read = null; - for(GraphOperation op : graphOp.consumers()){ - if(op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)){ + for (GraphOperation op : graphOp.consumers()) { + if (op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)) { read = op.output(0); } } - if(read == null){ - read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read").readVariableOp(output, TensorTypeRegistry.find(valueDt).type()); + if (read == null) { + read = Ops.create(graph).withSubScope("session_reads").withName(output.op().name() + "_read") + .readVariableOp(output, TensorTypeRegistry.find(valueDt).type()); } outputs.add(read.asOutput()); @@ -314,7 +315,7 @@ public Runner addTarget(String operation) { * @throws IllegalStateException if the operation is not from the session's graph. */ public Runner addTarget(Operation operation) { - if(operation.env() != graph){ + if (operation.env() != graph) { throw new IllegalStateException("Can't fetch operation " + operation + ", it is from a different graph."); } targets.add((GraphOperation) operation); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index d6d4e6d419d..4223a03ee23 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -15,7 +15,6 @@ package org.tensorflow; -import static java.util.Collections.singletonList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -27,8 +26,6 @@ import java.nio.file.Path; import java.util.Comparator; import java.util.Iterator; -import java.util.Spliterators; -import java.util.stream.StreamSupport; import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -272,7 +269,7 @@ public static void testFetchVariable() { } } - private static int numOperations(Graph g){ + private static int numOperations(Graph g) { int numOperations = 0; for (Iterator it = g.operations(); it.hasNext(); ) { Operation o = it.next(); From fb5c3196eed4a2a658362b0c13469da2f09d629f Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 31 Mar 2021 11:52:40 -0700 Subject: [PATCH 6/8] Remove obsolete method Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/RawTensor.java | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index 84a819f7d1f..c332fd7f1d1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -154,23 +154,6 @@ static RawTensor fromHandle(TF_Tensor handle) { return t; } - /** - * Create a Tensor object from a handle to the C TF_Tensor object. - * DOES NOT SET THE TYPE INFO, should only be passed directly to a {@link OperationBuilder#setAttr(String, Tensor)}. - * Will likely NPE otherwise. - * - *

Takes ownership of the handle. - */ - static RawTensor dangerousUntypedRawTensorFromHandle(TF_Tensor handle) { - RawTensor t = new RawTensor(null, Shape.of(shape(handle))); - try (PointerScope scope = new PointerScope()) { - scope.attach(handle); - t.tensorHandle = handle; - t.tensorScope = scope.extend(); - } - return t; - } - /** * Create an eager Tensor object from a handle to the C TF_Tensor object. * From b40fb0c662bb7c57fd7a0b9fceb9703a3459771d Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 3 Apr 2021 18:13:59 -0700 Subject: [PATCH 7/8] Small fixes Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 465dd8238d9..58fb62b5fee 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -195,7 +195,8 @@ public Runner feed(String operation, int index, Tensor t) { */ public Runner feed(Operand operand, Tensor t) { if (operand.env() != graph) { - throw new IllegalStateException("Can't feed value to operand " + operand + ", it is from a different graph."); + throw new IllegalStateException("Can't feed value for operand " + operand + ", it is from " + + (operand.env().isEager() ? "an eager session" : "a different graph") + "."); } inputs.add(operand.asOutput()); @@ -248,7 +249,8 @@ public Runner fetch(String operation, int index) { */ public Runner fetch(Output output) { if (output.env() != graph) { - throw new IllegalStateException("Can't fetch output " + output + ", it is from a different graph."); + throw new IllegalStateException("Can't fetch output " + output + ", it is from " + + (output.env().isEager() ? "an eager session" : "a different graph") + "."); } if (output.dataType() == DataType.DT_RESOURCE) { @@ -268,6 +270,7 @@ public Runner fetch(Output output) { for (GraphOperation op : graphOp.consumers()) { if (op.dtype(0) == valueDt && op.type().equals(ReadVariableOp.OP_NAME)) { read = op.output(0); + break; } } @@ -316,7 +319,8 @@ public Runner addTarget(String operation) { */ public Runner addTarget(Operation operation) { if (operation.env() != graph) { - throw new IllegalStateException("Can't fetch operation " + operation + ", it is from a different graph."); + throw new IllegalStateException("Can't target operation " + operation + ", it is from " + + (operation.env().isEager() ? "an eager session" : "a different graph") + "."); } targets.add((GraphOperation) operation); return this; From a6496ce050e3bc80e2be9b1370e827a08514a72b Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 3 Apr 2021 19:03:44 -0700 Subject: [PATCH 8/8] Python model loading + variable fetching test Signed-off-by: Ryan Nett --- .../org/tensorflow/SavedModelBundleTest.java | 16 ++++++++++++---- .../model/saved_model.pb | Bin 14026 -> 19119 bytes .../variables/variables.data-00000-of-00001 | Bin 42 -> 127 bytes .../model/variables/variables.index | Bin 144 -> 195 bytes .../saved_model_using_python/source_model.py | 11 +++++++++-- 5 files changed, 21 insertions(+), 6 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index cd8ac7e2ae4..ff93e317805 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -27,8 +27,8 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Collections; -import java.util.Map; import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.ndarray.FloatNdArray; @@ -292,21 +292,29 @@ public void pythonTfFunction() { ConcreteFunction add = bundle.function("add"); Map args = new HashMap(); try (TFloat32 a = TFloat32.scalarOf(10.0f); - TFloat32 b = TFloat32.scalarOf(15.5f)) { + TFloat32 b = TFloat32.scalarOf(15.5f)) { args.put("a", a); args.put("b", b); Map result = add.call(args); assertEquals(result.size(), 1); - try (TFloat32 c = (TFloat32)result.values().iterator().next()) { + try (TFloat32 c = (TFloat32) result.values().iterator().next()) { assertEquals(25.5f, c.getFloat()); } } + + // variable unwrapping happens in Session, which is used by ConcreteFunction.call + ConcreteFunction getVariable = bundle.function("get_variable"); + try (TFloat32 v = (TFloat32) getVariable.call(new HashMap<>()) + .get(getVariable.signature().outputNames().iterator().next())) { + assertEquals(2f, v.getFloat()); + } + } } private static Signature buildGraphWithVariables(Ops tf, Shape xShape) { Placeholder x = tf.placeholder(TFloat32.class, Placeholder.shape(xShape)); - Variable y = tf + Variable y = tf.withName("variable") .variable(tf.random.randomUniform(tf.constant(xShape), TFloat32.class)); ReduceSum z = tf.reduceSum(tf.math.add(x, y), tf.array(0, 1)); Init init = tf.init(); diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/saved_model.pb index 169e0095a3ec3d3a8bd421e46190b49dfc04cc28..d9498dd4b74db0225ba8436f188f8076b614fc46 100644 GIT binary patch literal 19119 zcmd^Hdv_DZ71zp=rA%`C7#l2>*Rq5Zlm^QyAVz6REteEX2pmH=r%l72m9#brvZP9D z8=9v5r(dh5AEO_mU!nKT-PxUeTDY_)q)sBO=6UCLA2avP-BdOH&%bHRfWLqJpWb>l~Z+zgH<(~I??3#j zDif8l()XCQ%S*cl!!9(K?~aDM{XITE^(Ot{o?heyQ)g6Yk0&E<6sbDPnET4@9(bdP zjkBngY@0nEI9>O}XwY*f$+bt83R9P-FB}+wzVuSdDAqLgYqqfCI=yGkr0;YF?xQhh z)rmVD9Zb4zZ55VH_(?n?VxaSOeh=HzRV6Gv1Jf}Z& zCo25r83uq(+@3vj_9Kjm3YiHQs>qv{*&;~S=m>iD+{1p?opM8n5sAb#y~qlJF0(v! zCx>pmY+Y|&->^;{;&Rk$+wFF5bZ+0ewY#yYp&r<*$EuQv+)&n1Kg||wQVLs; zG36yzY#D{R`RA}iRCntrwYQB~tXBePb>_+wwPiZC1YhMX`Rt3v&K6{}NnCZo+@ zrb4}h69_}}2RWi8N7NHbsEU^JEu%zhKy&6BCFH~1r9`LeuphQ+uFIm)&BCX{cP*YukVI3q+gcftY6=H`gG@h`|GE7 zpR7llvjY6{87n)zp6zf$F?01=0#GpG!j`GxAUk(v=?r!-!eKON#pHGn-yd`$kG?u} z+7N&ktEhoO(N0iia~AwM&(`ejXfSZQUjNXw_no1$=T7W?Z^}X0bJwKX6HqjJG#tE! zCo;@-e+YXV^=(8q4})J|6*dPfU()d5xAssSMVN*l)7V6hrUKZ z*R=!w*+&y+JO=7EZ-IuK-g7-59X3~>8n&!gH1{m9;{FohiGX!EwnKpOC4kMCRQJuo zPY&r%H=;;SyUxG?@jx!WMWn40l&JU{~vCSFh|`W|#5}>qeu& z9b&(;bku`k2APco`5Y^kS)f}1$Itli|tLock0>51$*O$FR7UTTC)Qf9(t9H zsuC669~VMb5uwXFm)ONLgy7Tx77{?ZxrwgE=O}6=_RS1Z&ROMP_;NTp8iFgys{Nij z1PA+dO`Jo-N2XMYvqDyFHZSCSzAb$x%*Ka^{>7d1?5v7Dp1b5p*}iRAXIK^*`VXvn z--nh5Pr8v|3t9$wgRO>Md!EszUSZl(j_yd7rdr?y$VtufA#eGzEDh(WuQ5g&tn@(T zk=?Q~SCz?QvUCl`RkYAi_;iz(nda0q{OZ)Sh{Wl?U~|xjm$!PoXO^xxy4DF-ne9|q zAVSbj|1w;zf^h|^TE?Hiy9xkH9}14bGax1zQ-`om6fi)$1!IW+KC2e^r!0LQ5z{%# zR_y}_9oW;c+wJf6Wr$*qYfj7f41BqhcJCngtN5YiTT{NsaASjrmb=@3RjcryvCC*_ z53NwBgwDinb+Yp|4`T7__mKN;6B91|8mq+*XIKYfk{X$OF zY;TCEKDIq~e{6R;%^STQ0f8wI& z%l&Wtd>VH`bD=MY)hGo^qQKq=i$eTFbfYfyuae0ztIFj39%D~L^z{LV3|F2bvY;;# z*=ds!(?I9H;>B7qW6py0fjW`?d&Y(b`)C3u;P?B!Y`T4oM4aa8Tyo>Rb_HTBg_fq_ z*x_jUU)iZAIIi3K+&@pQi+K`4kQWQen$MU`fSAR#TX?#UZAgp@Q3K? z7h?Tr1xkaJaKT||OcJu=T-Gv65F(7l$`de5&#ZO2ArR4A$)3dE$At5Wb%9P0&a2d) zN!C4Ve*ubw;k;sdmJnF`oV|VT?)DdtAKm}zsR*P!YCpKU{nY;L&er35)+v@oHZfq$ zUl2&Z*tH?1ZBGXx%uF)(pD$`nE^0v-B3e|TT@Z&90y93hv;?9SEP?pe^%%=r!Wg4F z1gDd8DSQd}4(FTnnUp58Y&p8{f5Po1fzE0o;D~<65}Ow{OjAw~zlyLgv??K0A0_vu z^=`0F`5PmI{!5RY?#p~2D$i-D3DSH6q=nY+82Qw!t87NXh+!!^EKe!Z*JoR;!ke6J zw^2{U+*)O6Q%_P)X74Mzy1w!}NnQvhnMD5pm$iWu)Z?4J&zL;sf@>Wdg!PoeZ|dYB z!_w`)fzyREn3vQ85MK#vLO|!`HpPrKSWnL8y<e7J=8lP1T@UteGNe$2dFagsAFrtbdp-y6E$u zklX=z{4CGsxtJl56;5(NUgMo4GJYub#w}HX#LdpDxC{B4hzm(=bQ)eNw3c(1^Oju6 zH%9k-+%VuJssL78F1{sjOYAr$y?IlGSgKbXejWmrGM;Bxhu2GGCO zy=AG|nGWweY5?K%6pf}W2`_EI2RKL-c}td%v7_Sy>GAm3d3MH`;vIY#3=1MHLn)Vis_Wovh=Vr?q4PX9e&o!F%D7EVbAn!&#|28 ziN1h5NRJ7gqd}#ML3t$J5FW1d%VCT7B3m**96B%lGS4NPRE!vt0RbzL;;yn_8htPI z{I?P`Xc-h5nYJ213Yfd(q_-&_l+vnZ8`A>E z0CPz7?^TaY@d;IrNMC-QP}uTw>~0E^6ieBvh(d7_gMKOcOvzk6=}Ezf8y`d85WUn$ ziKS4b`jEa<8I2V3VAI4gPDq1=o-6k6`}b3}rFgGaAoWHA-B)ZFHPosL3u0`2DB57V z1n-g}B)wQ{PT=8FI)GRs+)aGIzbnB^>31ojI^yk$5^;LP{$Gfnc{h6etZ1?py#Wbw zyzs^p2a4aC>SIwH%d|{nUSFTY#AZ} zy9F#(Lnu|}p$;&Xn^1-8IH%)Ou|!+@(fNMiT%)iLEzRuGHvm!E=|QjwOD zNVoXZ@! zBd83N?w`|A!k`|`rN#{KcYr8p;TtF&4XOnAaIF8G`jzk&pUmv> znkTg`RpUt86irjqFHOtnkm1J&anXr5Vc-n8IN^-L&t_y)Ve&#l&rUDg zESf~g3Xa*c_eVW<04ewVp*&4D z`eCpjAyOmpleS>ZflMCcyzR6aOPY?_unO5#va};h`4v_rKf-SCGc2CvFPrH)xsto^ z{?j^!_TU}3bx0EG^}Ajj@YUE-{DlKPX4m}79d58%yqfe~jt_Z8AsJ`cO6(2;xZtni zjR+xC%_Kz4oR{MtM~EdGz7Zi5AoUb`_<$`exH9F zn}Q@cOW0|mLf*n#XhqM6QBmn=_{b|rVolN4kg!O+ii8ir>nu~zxl%206mIcjp;{)J zi_4%L&o7@3%ZaI?duT7VkQe1u2)fdXdnVBCI|8=1Ytsu;La}+Ns)06P z{r6^f0FghDc4D>+^Wp1Jo3G#^5m98~Qez?(i3KjU|IGTO&9JUHt^ zi(O|h#rcAQ@6Cbp+MO6j&eRZQ$T;eIFTe`;{((=#574JSwfEq)q8Iki^^Qi9m&$Mi z3)s;qym<@drgi~|?^a@6?_v(*HPM7>9?$Ov=AKq)e^S~b)fMIK@`io8|)5d*f6*>TcRD;-#6$(L+Gz&ijr71oWCT<6BPGj!mTu^}dkz13!Cm=+ed~EUBGe3o` zNW$m8D^1~9k`Q$cs0HyjqkB>cuf+1zC02?7Ff0C?m7npdQpup(=uKj z;j3(-3(_!G;Iz{$P=(S3(I`MZD1JM`;RsVY;YBn~-IlcaMeI*3geTH8*sx%a^3QBx zX(w8$fR#9x1xMf+F{ASaVYY-634&vgy`p6(Xwm9K+W<5wTM&f_@X~-wf`)*-pNf{x zPb9G@_tOYed7$`PT(gk@%KNGm)#$}}PH0FZKyUyB1%ZUT;PTaF6CB6L4K_-`KwWcW zN0Dwq1bfi=J5yA0qj>=93<|a)cuT^9d)%CVvLglikA#Gld<2}iaU7lowATW3X= zK55f2F+cDFnOxEf76Z-J3sOr?&^Wtg1}MC%d~Poesml@_N@o9!C` z+VN7)L0M*i4n3tg@*4eXo*nd22C%E*hI6htJDSW$tr4+v= zUsC})n*-*D*zFQ2n)nQT>lGxHVvb>IqDM2l;bFB*CO`PHd(%4XT`9O5R^nYtYu0g zDcf0BbYBlchG84l4O<@u%osKx55v&C6hpD?WlsZ&6&QwYK>x*n?P=J#yrf7`qNIE@ zb%G{HDBgQ{&bjBFuY1l(3FI%E@S6nrp9pyvSg9G7P4~fItzp+2wlq+-Z`U<<00-A>)z&JF>SM=(?+Su6U}@%-mKl&w=T2wk$t~x3Dtl_Zm|I%REfur*vw4Ah z^C}pwm=G>4nIkZ!aB)!_zF1@gppYW*R2=__nPE>>;#ncH;7QWl^3y4SG*?OKSi~Q= z3`s{Ul?4yufi{c@h!S)5k`VnI#YHspV%GWLnFP0R&-=J|Umt%v z-fjlMc7bx@_@kH^<%R%hk>i4ZuiOL-fGFeZ3>;KSwQ5x>+4`2IY^sL3u9=Enwn!Xe zUG_8!R53NRtkjI^ZTw}`8fHmTbOVhW|MpbKYRr(skOUFWXIx-EUcw(SIesaY7rrIL zGAEf8dTLdISp}g~G*dGWUX%06wyDe&#w>LUjYvhWYKFS0wMuedP6M4T4^q%C0t$6xY!Wp= zLo|3C&$VzonQS%H$zG;9iui|as829MJ-d1Yrh|rBu}uv0$_nGJvLMUw<2W3O2;0!S z_RE~i2zOtOyV|oXnPV(DzIqIfc47%SN7QKKN=_zk7U2CjJO`7(g-(2^p)3mpF*qv33g=TX z@)?LE7mgvf8ptEXs%s^^qN9{Kx;sR?O~VIqcmWOtdCCF=H#9hNmhsXtsNLB5*-mWb zpY_9wkAtbl!b}~9F^8$m=)dp+GWXdo%sm>x+$7;!hre}Xt?QfRGfygOyU6PUv-IC_ z$N~vUE9#tb7Vk13s3@ZND7I<+b1mX2s8lHKRk6gJK&!5!op8jG>HD)NkY~FH0NYK%+fw7#l`y+h{wEAN6$;bBqr4fV4RHy?hswm_M6Y zJ>12l%KM@1iu4uV9RwUoI&``6eZs@_Fpt6*{=K6ybmQm6g`n(2y6?x(&&-4sn&qLG zGCP@?2lp`_B;fg8nwiHWRTuO4jehizTg}Jis(zSh;epl3POf6_sox$o`tSB(^v`62 ziu;Mt4@M>|j&9nwy(e`awPpG*fsO7LU9t~`Svmo0BU(k@Nhb;6S%xjy%~0mjNi5HR zcY{+F(n%}}>@)pNhHOsa*ZwzJBjvz z({(In$w{v+_I=CNp4{?+gXbB=);8=tIGJ4E&WBH9vQ4vDO$~bhoZ3IKht=Scj-qo~{~*BA?M`97)N~N8lR?%d zuB)YuuA58VI~*XAo(uLE zg@JZg)89Y-JNk*{I#e?39jGVWy8ySUvO~Hp^uiim+ z8D^DtY${9;@GI2#KEU5OpYVN@CI>6}4(H ze9eK^q^f#bGiSF|Yt~WQ*=^mvg{m^!T*63v9ly~nWgU9~Zz+amZ`aI?065MRG$0ms zN#nRApI$tj?D~l&lD`$*xAz74J)uj*kzhpeCTE9E)W4vK$*Jn#(iGbG?Yc>MoiEp$J2lyrKXM zU-C2Z!55ia?q~v+=*eC2Y3ONam_d?;?!*TpfVeBh;K*GOJ4!qQ*=8)I zr|yc_K^cfmC+@_@1Fw!rB|QQ?C}hvY7v>n`>FDI+8XxB96Y46&5eC&Pr3^DhU(YclB*@b#G{iMnA5}RQ H0~Z4TH$)-y literal 42 xcmY#Yj-9B*CC|mn!6?L;om!Mw9G_T_$;HRPB*aymnVy$eQd*Q+%*DXP007Mb3JU-L diff --git a/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index b/tensorflow-core/tensorflow-core-api/src/test/resources/saved_model_using_python/model/variables/variables.index index 8be702a36ed1281cf9a59b61769c71f7508c9468..ed8ff96c1d63be14e79cf4594419f29bd0d9ecb7 100644 GIT binary patch delta 147 zcmbQhc$krgfq{*KQHqI!kwJ!cB9B6j0*gkq>4#XI-wev!Wr;