diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index d6e69085324..5a8a9bace8a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -347,10 +347,10 @@ public final class Ops { public final SignalOps signal; - public final QuantizationOps quantization; - public final TrainOps train; + public final QuantizationOps quantization; + private final Scope scope; private Ops(Scope scope) { @@ -372,8 +372,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - quantization = new QuantizationOps(this); train = new TrainOps(this); + quantization = new QuantizationOps(this); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 71dc0f7cefc..feda13f9277 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -180,7 +180,7 @@ public Map call(Map arguments) Map outputToNode = signatureDef.getOutputsMap(); outputToNode.values().forEach(t -> runner.fetch(t.getName())); - List resultTensors = runner.run(); + List resultTensors = runner.run().getResults(); try { ListIterator resultTensorIter = resultTensors.listIterator(); Map returnMap = new HashMap(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index f2717f263eb..a176ffeb823 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -126,6 +126,66 @@ public GraphOperation operation(String name) { } } + + /** + * Returns the operation (node in the Graph) with the provided name. + *

+ * Or throws an {@code IllegalArgumentException} if no such operation exists in the Graph. + * + * @param name name of the operation to look for + * @return operation in the graph with this name + * @see #operation(String) + */ + public GraphOperation operationOrError(String name) { + GraphOperation op = operation(name); + if (op == null) { + throw new IllegalArgumentException("No Operation named [" + name + "] in the Graph"); + } + return op; + } + + /** + * Returns the {@code index}-th output of {@code operation}. + * Throws {@code IllegalArgumentException} if the operation is not found, or does not have an output at {@code index}. + * + * @param operation The operation to get the output of. + * @param index The index of the output to get. + * @return The {@code index}-th output of {@code operation}. + */ + public Output getOutput(String operation, int index){ + GraphOperation graphOp = operationOrError(operation); + if(index < 0 || index >= graphOp.numOutputs()){ + throw new IllegalArgumentException("Index out of bounds for operation " + operation + + ". Operation has " + graphOp.numOutputs() + " outputs"); + } + + return graphOp.output(index); + } + + /** + * Returns the output specified by {@code output}. + * Will try to parse the output index from {@code output}. + * I.e. {@code "scope/op:2"} will get the 2nd (0-indexed) output of {@code scope/op}. + * Otherwise, will return the 0th output. + * + * @param output The operation to get the output of, with the index optionally specified by colon. + * @return The output specified by {@code output}. + */ + @SuppressWarnings("rawtypes") + public Output getOutput(String output) { + int colon = output.lastIndexOf(':'); + if (colon == -1 || colon == output.length() - 1) { + return new Output(operationOrError(output), 0); + } + try { + String op = output.substring(0, colon); + int index = Integer.parseInt(output.substring(colon + 1)); + return new Output(operationOrError(op), index); + } catch (NumberFormatException e) { + return new Output(operationOrError(output), 0); + } + } + /** * Iterator over all the {@link Operation}s in the graph. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index c332fd7f1d1..dfc61ea12c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -65,7 +65,18 @@ public RawTensor asRawTensor() { @Override public void close() { + if(closed) { + throw new IllegalStateException("Tensor has already been closed"); + } tensorScope.close(); + closed = true; + } + + /** + * @return {@code true} if this tensor has been closed; + */ + public boolean isClosed() { + return closed; } /** @@ -222,6 +233,7 @@ private static long[] shape(TF_Tensor handle) { } private PointerScope tensorScope; + private boolean closed = false; private TF_Tensor tensorHandle; private final TensorTypeInfo typeInfo; private final Shape shape; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index e9d517a6548..b2d34076292 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -15,7 +15,22 @@ package org.tensorflow; +import static org.tensorflow.Graph.resolveOutputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; + import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Spliterator; +import java.util.function.Consumer; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -33,14 +48,9 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; - -import java.util.ArrayList; -import java.util.List; import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; - -import static org.tensorflow.Graph.resolveOutputs; -import static org.tensorflow.internal.c_api.global.tensorflow.*; +import org.tensorflow.types.family.TType; /** * Driver for {@link Graph} execution. @@ -135,6 +145,202 @@ public void close() { } } + /** + * The result of a run in a session. Contains the fetched tensors and the outputs that were fetched. + *

+ * Closing a {@code Result} object will close all of the tensors contained by it. + */ + public final class Result implements AutoCloseable, Iterable{ + private final List results; + private final List> fetches; + private final LinkedHashMap, Tensor> outputMap; + + /** + * Metadata about the run. + * + *

A RunMetadata + * protocol buffer. + */ + private final RunMetadata metadata; + + private boolean closed = false; + + private Result(List results, List> fetches, RunMetadata metadata) { + + if(results.size() != fetches.size()){ + throw new IllegalArgumentException("Expected the same number of fetches and values, got " + fetches.size() + + " fetches and " + results.size() + " values."); + } + + this.metadata = metadata; + this.results = results; + this.fetches = fetches; + outputMap = new LinkedHashMap<>(); + for(int i = 0 ; i < fetches.size() ; i++){ + outputMap.put(fetches.get(i), results.get(i)); + } + } + + private void requireOpen(){ + if(closed) { + throw new IllegalStateException("Result has been closed, can not access it."); + } + } + + /** + * Get the result tensors. + */ + public List getResults() { + requireOpen(); + return Collections.unmodifiableList(results); + } + + /** + * Get the outputs that were fetched. + */ + public List> getFetches() { + return Collections.unmodifiableList(fetches); + } + + /** + * Get a map of the fetched outputs to their results. + */ + public Map, Tensor> getOutputMap(){ + return Collections.unmodifiableMap(outputMap); + } + + /** + * Get the run metadata. May be null if not requested. + */ + public RunMetadata getMetadata() { + return metadata; + } + + /** + * @return Whether the result has been closed. + */ + public boolean isClosed() { + return closed; + } + + /** + * Get the result at {@code index}. + */ + public Tensor get(int index){ + requireOpen(); + return results.get(index); + } + + /** + * Get the result for {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + @SuppressWarnings("unchecked") + public T get(Output output){ + requireOpen(); + if(!outputMap.containsKey(output)) + throw new IllegalArgumentException("Did not fetch an output for " + output); + return (T) outputMap.get(output); + } + + /** + * Get the result for {@code operand} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + public T get(Operand operand){ + requireOpen(); + return get(operand.asOutput()); + } + + /** + * Get the result for the {@code index}-th output of {@code operation} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + public Tensor get(String operation, int index){ + requireOpen(); + return get(graph.getOutput(operation, index)); + } + + + /** + * Get the result for the output specified by {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + public Tensor get(String output){ + requireOpen(); + return get(graph.getOutput(output)); + } + + /** + * Returns {@code true} if {@code output} was fetched as part of this {@code Result}. + */ + public boolean contains(Output output){ + requireOpen(); + return outputMap.containsKey(output); + } + + /** + * Returns {@code true} if {@code operand} was fetched as part of this {@code Result}. + */ + public boolean contains(Operand operand){ + requireOpen(); + return contains(operand.asOutput()); + } + + /** + * Returns {@code true} if the {@code index}-th output of {@code operation} was fetched as part of this {@code Result}. + */ + public boolean contains(String operation, int index){ + requireOpen(); + return contains(graph.getOutput(operation, index)); + } + + + /** + * Returns {@code true} the output specified by {@code output} was fetched as part of this {@code Result} + */ + public boolean contains(String output){ + requireOpen(); + return contains(graph.getOutput(output)); + } + + /** + * Close any open tensors contained by this {@code Result}. + */ + @Override + public void close() { + requireOpen(); + for(Tensor t : this){ + if(!t.isClosed()) { + t.close(); + } + } + closed = true; + } + + @Override + public Iterator iterator() { + requireOpen(); + return results.iterator(); + } + + @Override + public void forEach(Consumer action) { + requireOpen(); + results.forEach(action); + } + + @Override + public Spliterator spliterator() { + requireOpen(); + return results.spliterator(); + } + + /** + * Return the number of tensors contained by this Result. + */ + public int size() { + return getResults().size(); + } + } + /** * Run {@link Operation}s and evaluate {@link Tensor Tensors}. * @@ -159,7 +365,7 @@ public final class Runner { * @return this session runner */ public Runner feed(String operation, Tensor t) { - return feed(parseOutput(operation), t); + return feed(graph.getOutput(operation), t); } /** @@ -174,11 +380,9 @@ public Runner feed(String operation, Tensor t) { * @return this session runner */ public Runner feed(String operation, int index, Tensor t) { - Operation op = operationByName(operation); - if (op != null) { - inputs.add(op.output(index)); - inputTensors.add(t); - } + Operation op = graph.operationOrError(operation); + inputs.add(op.output(index)); + inputTensors.add(t); return this; } @@ -206,9 +410,10 @@ public Runner feed(Operand operand, Tensor t) { * the {@code SignatureDef} protocol buffer messages that are included in {@link * SavedModelBundle#metaGraphDef()}. * @return this session runner + * @see Graph#getOutput(String) */ public Runner fetch(String operation) { - return fetch(parseOutput(operation)); + return fetch(graph.getOutput(operation)); } /** @@ -219,12 +424,11 @@ public Runner fetch(String operation) { * * @param operation the string name of the operation * @return this session runner + * @see Graph#getOutput(String, int) */ public Runner fetch(String operation, int index) { - Operation op = operationByName(operation); - if (op != null) { - outputs.add(op.output(index)); - } + Operation op = graph.operationOrError(operation); + outputs.add(op.output(index)); return this; } @@ -255,12 +459,11 @@ public Runner fetch(Operand operand) { * * @param operation the string name of the operation to execute * @return this session runner + * @see Graph#operationOrError(String) */ public Runner addTarget(String operation) { - GraphOperation op = operationByName(operation); - if (op != null) { - targets.add(op); - } + GraphOperation op = graph.operationOrError(operation); + targets.add(op); return this; } @@ -312,21 +515,13 @@ public Runner setOptions(RunOptions options) { * Execute the graph fragments necessary to compute all requested fetches. * *

WARNING: The caller assumes ownership of all returned {@link Tensor Tensors}, i.e., - * the caller must call {@link Tensor#close} on all elements of the returned list to free up + * the caller must call {@link Tensor#close} on all returned tensors or {@link Result#close()} to free up * resources. * - *

TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it - * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in - * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a - * {@code Map}? - * - *

TODO(andrewmyers): It would also be good if whatever is returned here made it easier to - * extract output tensors in a type-safe way. - * - * @return list of resulting tensors fetched by this session runner + * @return a {@link Result} containing tensors fetched by this session runner */ - public List run() { - return runHelper(false).outputs; + public Result run() { + return runHelper(false); } /** @@ -339,11 +534,11 @@ public List run() { * * @return list of resulting tensors fetched by this session runner, with execution metadata */ - public Run runAndFetchMetadata() { + public Result runAndFetchMetadata() { return runHelper(true); } - private Run runHelper(boolean wantMetadata) { + private Result runHelper(boolean wantMetadata) { TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()]; TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()]; int[] inputOpIndices = new int[inputs.size()]; @@ -398,10 +593,7 @@ private Run runHelper(boolean wantMetadata) { } finally { runRef.close(); } - Run ret = new Run(); - ret.outputs = outputs; - ret.metadata = metadata; - return ret; + return new Result(outputs, new ArrayList<>(this.outputs), metadata); } private class Reference implements AutoCloseable { @@ -427,33 +619,10 @@ public void close() { } } - private GraphOperation operationByName(String opName) { - GraphOperation op = graph.operation(opName); - if (op == null) { - throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph"); - } - return op; - } - - @SuppressWarnings("rawtypes") - private Output parseOutput(String opName) { - int colon = opName.lastIndexOf(':'); - if (colon == -1 || colon == opName.length() - 1) { - return new Output(operationByName(opName), 0); - } - try { - String op = opName.substring(0, colon); - int index = Integer.parseInt(opName.substring(colon + 1)); - return new Output(operationByName(op), index); - } catch (NumberFormatException e) { - return new Output(operationByName(opName), 0); - } - } - - private final ArrayList> inputs = new ArrayList<>(); - private final ArrayList inputTensors = new ArrayList<>(); - private final ArrayList> outputs = new ArrayList<>(); - private final ArrayList targets = new ArrayList<>(); + private ArrayList> inputs = new ArrayList<>(); + private ArrayList inputTensors = new ArrayList<>(); + private ArrayList> outputs = new ArrayList<>(); + private ArrayList targets = new ArrayList<>(); private RunOptions runOptions = null; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index fc1275229bf..5294d902685 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -212,4 +212,9 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData */ @Override void close(); + + /** + * @return {@code true} if this tensor has been closed. + */ + boolean isClosed(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index 2fc423b914e..0545b5a794d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -80,4 +80,9 @@ default long numBytes() { default void close() { asRawTensor().close(); } + + @Override + default boolean isClosed(){ + return asRawTensor().isClosed(); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java deleted file mode 100644 index 330a40bae6b..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.tensorflow; - -import java.util.ArrayList; -import java.util.Collection; - -public final class AutoCloseableList extends ArrayList - implements AutoCloseable { - - public AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index e4340da3275..88e77a4022a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -53,8 +53,7 @@ public void withDeviceMethod() { .abs(aOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { + try (Session.Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } @@ -85,8 +84,7 @@ public void withEmptyDeviceSpec() { .abs(aOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { + try (Session.Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } @@ -131,8 +129,7 @@ public void withTwoScopes() { .mul(absOps, bOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { + try (Session.Result t = session.runner().fetch(mulOps).run()) { assertEquals(10, ((TInt32)t.get(0)).getInt()); } } @@ -179,8 +176,7 @@ public void withIncorrectDeviceSpec() { .mul(absOps, bOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { + try (Session.Result t = session.runner().fetch(mulOps).run()) { fail(); } catch (TFInvalidArgumentException e) { // ok @@ -212,8 +208,7 @@ public void withDeviceSpecInScope() { .abs(aOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { + try (Session.Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index d8ffc1a475b..32f8cb4d18b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -32,7 +32,9 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.Graph}. */ +/** + * Unit tests for {@link org.tensorflow.Graph}. + */ public class GraphTest { @Test @@ -146,7 +148,7 @@ public void addGradientsToGraph() { Output y0 = tf.math.square(x1).y(); Output y1 = tf.math.square(y0).y(); Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum(); - + Output[] grads0 = g.addGradients(y1, toArray(x1)); assertNotNull(grads0); assertEquals(1, grads0.length); @@ -157,18 +159,17 @@ public void addGradientsToGraph() { assertEquals(2, grads1.length); assertEquals(DataType.DT_FLOAT, grads1[0].dataType()); assertEquals(DataType.DT_FLOAT, grads1[1].dataType()); - + try (TFloat32 c1 = TFloat32.scalarOf(3.0f); TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = new AutoCloseableList<>( - s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run())) { - + Session.Result outputs = s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run()) { + assertEquals(3, outputs.size()); assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); assertEquals(6.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); @@ -212,7 +213,7 @@ public void addGradientsWithInitialValuesToGraph() { Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); - + Output[] grad0 = g.addGradients(y1, toArray(y0)); assertNotNull(grad0); assertEquals(1, grad0.length); @@ -268,18 +269,18 @@ public void buildWhileLoopSingleInput() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output input = tf.placeholder(TInt32.class).output(); + Output input = tf.placeholder(TInt32.class).output(); @SuppressWarnings("unchecked") Output[] loopOutputs = g.whileLoop( toArray(input), (condGraph, condInputs, condOutputs) -> { Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); + condOutputs[0] = tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); }, (bodyGraph, bodyInputs, bodyOutputs) -> { Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); }, "test_loop"); @@ -300,8 +301,8 @@ public void buildWhileLoopMultipleInputs() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output input1 = tf.placeholder(TInt32.class).output(); - Output input2 = tf.placeholder(TInt32.class).output(); + Output input1 = tf.placeholder(TInt32.class).output(); + Output input2 = tf.placeholder(TInt32.class).output(); Output[] inputs = toArray(input1, input2); @SuppressWarnings("unchecked") @@ -309,25 +310,23 @@ public void buildWhileLoopMultipleInputs() { inputs, (condGraph, condInputs, condOutputs) -> { Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); + condOutputs[0] = tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); }, (bodyGraph, bodyInputs, bodyOutputs) -> { Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); - bodyOutputs[1] = tfb.math.square((Output)bodyInputs[1]).y(); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); + bodyOutputs[1] = tfb.math.square((Output) bodyInputs[1]).y(); }, "test_loop"); try (TInt32 c1 = TInt32.scalarOf(2); TInt32 c2 = TInt32.scalarOf(5); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() - .feed(input1, c1) - .feed(input2, c2) - .fetch(loopOutputs[0]) - .fetch(loopOutputs[1]) - .run())) { + Session.Result outputs = s.runner() + .feed(input1, c1) + .feed(input2, c2) + .fetch(loopOutputs[0]) + .fetch(loopOutputs[1]) + .run()) { assertEquals(2, outputs.size()); assertEquals(16, ((TInt32)outputs.get(0)).getInt()); // ((2^2)^2) assertEquals(625, ((TInt32)outputs.get(1)).getInt()); // ((5^2)^2) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index b1928bff51c..c1ec11f89f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -24,6 +24,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; @@ -49,8 +50,7 @@ public void runUsingOperationNames() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { + Session.Result outputs = s.runner().feed("X", x).fetch("Y").run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); } @@ -66,8 +66,7 @@ public void runUsingOperationHandles() { Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { + Session.Result outputs = s.runner().feed(feed, x).fetch(fetch).run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); } @@ -107,19 +106,18 @@ public void runWithMetadata() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { - Session.Run result = s.runner() + Session.Result result = s.runner() .feed("X", x) .fetch("Y") .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); - assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); + assertEquals(1, result.size()); + assertEquals(31, ((TInt32)result.get(0)).getInt(0, 0)); // Sanity check on metadata - assertNotNull(result.metadata); - assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); - outputs.close(); + assertNotNull(result.getMetadata()); + assertTrue(result.getMetadata().hasStepStats(), result.getMetadata().toString()); + result.close(); } } } @@ -131,8 +129,7 @@ public void runMultipleOutputs() { Ops tf = Ops.create(g); tf.withName("c1").constant(2718); tf.withName("c2").constant(31415); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); + Session.Result outputs = s.runner().fetch("c2").fetch("c1").run(); assertEquals(2, outputs.size()); assertEquals(31415, ((TInt32)outputs.get(0)).getInt()); assertEquals(2718, ((TInt32)outputs.get(1)).getInt()); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 5dd6903d913..8955f0df5fe 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -20,7 +20,6 @@ import java.io.IOException; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.EagerSession; import org.tensorflow.Graph; import org.tensorflow.Session; @@ -60,8 +59,7 @@ public void createInts() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -79,8 +77,7 @@ public void createFloats() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -98,8 +95,7 @@ public void createDoubles() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -117,8 +113,7 @@ public void createLongs() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -136,8 +131,7 @@ public void createStrings() throws IOException { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 80150b64bb6..1a65ca90d34 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -21,7 +21,6 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; @@ -48,9 +47,8 @@ public void createGradients() { assertEquals(2, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { + Session.Result outputs = + sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run()) { assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); @@ -75,8 +73,7 @@ public void createGradientsWithSum() { assertEquals(1, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + Session.Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) { assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } @@ -101,9 +98,8 @@ public void createGradientsWithInitialValues() { assertEquals(1, grads1.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { + Session.Result outputs = + sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) { assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index 4121baf3af1..ef83f0117b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -132,7 +132,7 @@ public void operationsComposingZerosAreCorrectlyNamed() { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); - List results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); + Session.Result results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 882a64ba54d..c96fe0b68e9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -53,7 +53,7 @@ public void testGraphIteration() { int batches = 0; while (true) { try { - List outputs = session.runner().fetch(x).fetch(y).run(); + Session.Result outputs = session.runner().fetch(x).fetch(y).run(); try (TInt32 xBatch = (TInt32)outputs.get(0); TInt32 yBatch = (TInt32)outputs.get(1)) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index 5f203427563..5e63db43716 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -78,7 +78,7 @@ public void testGraphIteration() { int batches = 0; while (true) { try { - List outputs = session.runner().fetch(X).fetch(y).run(); + Session.Result outputs = session.runner().fetch(X).fetch(y).run(); try (TInt32 XBatch = (TInt32)outputs.get(0); TInt32 yBatch = (TInt32)outputs.get(1)) {