From 7202ea324aaa51df60beb076a29eef5542249639 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 20:24:52 -0800 Subject: [PATCH 01/19] move output finding methods to graph, make public Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Graph.java | 60 +++++++++++++++++++ .../src/main/java/org/tensorflow/Session.java | 55 ++++------------- 2 files changed, 73 insertions(+), 42 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index f2717f263eb..91497a9fb26 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -126,6 +126,66 @@ public GraphOperation operation(String name) { } } + + /** + * Returns the operation (node in the Graph) with the provided name. + * + *

Or throws an {@code IllegalArgumentException} if no such operation exists in the Graph. + * + * @param name name of the operation to look for + * @return operation in the graph with this name + * @see #operation(String) + */ + public GraphOperation operationOrError(String name) { + GraphOperation op = operation(name); + if (op == null) { + throw new IllegalArgumentException("No Operation named [" + name + "] in the Graph"); + } + return op; + } + + /** + * Returns the {@code index}-th output of {@code operation}. + * Throws {@code IllegalArgumentException} if the operation is not found, or does not have an output at {@code index}. + * + * @param operation The operation to get the output of. + * @param index The index of the output to get. + * @return The {@code index}-th output of {@code operation}. + */ + public Output getOutput(String operation, int index){ + GraphOperation graphOp = operationOrError(operation); + if(index < 0 || index >= graphOp.numOutputs()){ + throw new IllegalArgumentException("Index out of bounds for operation " + operation + + ". Operation has " + graphOp.numOutputs() + " outputs"); + } + + return graphOp.output(index); + } + + /** + * Returns the output specified by {@code output}. + * Will try to parse the output index from {@code output}. + * I.e. {@code "scope/op:2"} will get the 2nd (0-indexed) output of {@code scope/op}. + * Otherwise, will return the 0th output. + * + * @param output The operation to get the output of, with the index optionally specified by colon. + * @return The output specified by {@code output}. + */ + @SuppressWarnings("rawtypes") + public Output getOutput(String output) { + int colon = output.lastIndexOf(':'); + if (colon == -1 || colon == output.length() - 1) { + return new Output(operationOrError(output), 0); + } + try { + String op = output.substring(0, colon); + int index = Integer.parseInt(output.substring(colon + 1)); + return new Output(operationOrError(op), index); + } catch (NumberFormatException e) { + return new Output(operationOrError(output), 0); + } + } + /** * Iterator over all the {@link Operation}s in the graph. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index e9d517a6548..c63f4d2ce73 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -159,7 +159,7 @@ public final class Runner { * @return this session runner */ public Runner feed(String operation, Tensor t) { - return feed(parseOutput(operation), t); + return feed(graph.getOutput(operation), t); } /** @@ -174,11 +174,9 @@ public Runner feed(String operation, Tensor t) { * @return this session runner */ public Runner feed(String operation, int index, Tensor t) { - Operation op = operationByName(operation); - if (op != null) { - inputs.add(op.output(index)); - inputTensors.add(t); - } + Operation op = graph.operationOrError(operation); + inputs.add(op.output(index)); + inputTensors.add(t); return this; } @@ -208,7 +206,7 @@ public Runner feed(Operand operand, Tensor t) { * @return this session runner */ public Runner fetch(String operation) { - return fetch(parseOutput(operation)); + return fetch(graph.getOutput(operation)); } /** @@ -221,10 +219,8 @@ public Runner fetch(String operation) { * @return this session runner */ public Runner fetch(String operation, int index) { - Operation op = operationByName(operation); - if (op != null) { - outputs.add(op.output(index)); - } + Operation op = graph.operationOrError(operation); + outputs.add(op.output(index)); return this; } @@ -257,10 +253,8 @@ public Runner fetch(Operand operand) { * @return this session runner */ public Runner addTarget(String operation) { - GraphOperation op = operationByName(operation); - if (op != null) { - targets.add(op); - } + GraphOperation op = graph.operationOrError(operation); + targets.add(op); return this; } @@ -427,33 +421,10 @@ public void close() { } } - private GraphOperation operationByName(String opName) { - GraphOperation op = graph.operation(opName); - if (op == null) { - throw new IllegalArgumentException("No Operation named [" + opName + "] in the Graph"); - } - return op; - } - - @SuppressWarnings("rawtypes") - private Output parseOutput(String opName) { - int colon = opName.lastIndexOf(':'); - if (colon == -1 || colon == opName.length() - 1) { - return new Output(operationByName(opName), 0); - } - try { - String op = opName.substring(0, colon); - int index = Integer.parseInt(opName.substring(colon + 1)); - return new Output(operationByName(op), index); - } catch (NumberFormatException e) { - return new Output(operationByName(opName), 0); - } - } - - private final ArrayList> inputs = new ArrayList<>(); - private final ArrayList inputTensors = new ArrayList<>(); - private final ArrayList> outputs = new ArrayList<>(); - private final ArrayList targets = new ArrayList<>(); + private ArrayList> inputs = new ArrayList<>(); + private ArrayList inputTensors = new ArrayList<>(); + private ArrayList> outputs = new ArrayList<>(); + private ArrayList targets = new ArrayList<>(); private RunOptions runOptions = null; } From 78384d7197fd32c46fd934aaee36833a8e6abbcb Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 20:25:20 -0800 Subject: [PATCH 02/19] add session result class, use in run() Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 100 ++++++++++++++++-- 1 file changed, 90 insertions(+), 10 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index c63f4d2ce73..d522c4d1208 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -15,7 +15,17 @@ package org.tensorflow; +import static org.tensorflow.Graph.resolveOutputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; + import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -33,14 +43,9 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; - -import java.util.ArrayList; -import java.util.List; import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; - -import static org.tensorflow.Graph.resolveOutputs; -import static org.tensorflow.internal.c_api.global.tensorflow.*; +import org.tensorflow.types.family.TType; /** * Driver for {@link Graph} execution. @@ -302,11 +307,86 @@ public Runner setOptions(RunOptions options) { return this; } + public final class Result implements AutoCloseable{ + private final List> results; + private final List> fetches; + private final LinkedHashMap, Integer> indexMap; + + public Result(List> results, List> fetches) { + this.results = new ArrayList<>(results); + this.fetches = new ArrayList<>(fetches); + indexMap = new LinkedHashMap<>(); + for(int i = 0 ; i < fetches.size() ; i++){ + indexMap.put(fetches.get(i), i); + } + } + + /** + * Get the result tensors. + */ + public List> getResults() { + return results; + } + + /** + * Get the outputs that were fetched. + */ + public List> getFetches() { + return fetches; + } + + /** + * Get the result at {@code index}. + */ + public Tensor get(int index){ + return results.get(index); + } + + /** + * Get the result for {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + @SuppressWarnings("unchecked") + public Tensor get(Output output){ + if(!indexMap.containsKey(output)) + throw new IllegalArgumentException("Did not fetch an output for " + output); + return (Tensor) results.get(indexMap.get(output)); + } + + /** + * Get the result for {@code operand} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + public Tensor get(Operand operand){ + return get(operand.asOutput()); + } + + /** + * Get the result for the {@code index}-th output of {@code operation} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + public Tensor get(String operation, int index){ + return get(graph.getOutput(operation, index)); + } + + + /** + * Get the result for the output specified by {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + public Tensor get(String output){ + return get(graph.getOutput(output)); + } + + @Override + public void close() throws Exception { + for(Tensor t : results){ + t.close(); + } + } + } + /** * Execute the graph fragments necessary to compute all requested fetches. * *

WARNING: The caller assumes ownership of all returned {@link Tensor Tensors}, i.e., - * the caller must call {@link Tensor#close} on all elements of the returned list to free up + * the caller must call {@link Tensor#close} on all returned tensors or {@link Result#close()} to free up * resources. * *

TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it @@ -317,10 +397,10 @@ public Runner setOptions(RunOptions options) { *

TODO(andrewmyers): It would also be good if whatever is returned here made it easier to * extract output tensors in a type-safe way. * - * @return list of resulting tensors fetched by this session runner + * @return a {@link Result} containing tensors fetched by this session runner */ - public List run() { - return runHelper(false).outputs; + public Result run() { + return new Result(runHelper(false).outputs, outputs); } /** From e4cf8560c871032087f78f9483e0321e42db7150 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 20:27:30 -0800 Subject: [PATCH 03/19] make Result implement Iterable Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index d522c4d1208..fa8552b4462 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -24,8 +24,11 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; +import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; +import java.util.Spliterator; +import java.util.function.Consumer; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -307,7 +310,7 @@ public Runner setOptions(RunOptions options) { return this; } - public final class Result implements AutoCloseable{ + public final class Result implements AutoCloseable, Iterable>{ private final List> results; private final List> fetches; private final LinkedHashMap, Integer> indexMap; @@ -380,6 +383,21 @@ public void close() throws Exception { t.close(); } } + + @Override + public Iterator> iterator() { + return results.iterator(); + } + + @Override + public void forEach(Consumer> action) { + results.forEach(action); + } + + @Override + public Spliterator> spliterator() { + return results.spliterator(); + } } /** From 786e3fcbd6eefcad3da7d92b650def86c2aeb140 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 21:07:30 -0800 Subject: [PATCH 04/19] fix the only direct usage Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/ConcreteFunction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 71dc0f7cefc..feda13f9277 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -180,7 +180,7 @@ public Map call(Map arguments) Map outputToNode = signatureDef.getOutputsMap(); outputToNode.values().forEach(t -> runner.fetch(t.getName())); - List resultTensors = runner.run(); + List resultTensors = runner.run().getResults(); try { ListIterator resultTensorIter = resultTensors.listIterator(); Map returnMap = new HashMap(); From 4500b56a6f7bd015d7deb94677768f2c28186759 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 21:13:54 -0800 Subject: [PATCH 05/19] better docs Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index fa8552b4462..316a918ee3f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -212,6 +212,7 @@ public Runner feed(Operand operand, Tensor t) { * the {@code SignatureDef} protocol buffer messages that are included in {@link * SavedModelBundle#metaGraphDef()}. * @return this session runner + * @see Graph#getOutput(String, int) */ public Runner fetch(String operation) { return fetch(graph.getOutput(operation)); @@ -225,6 +226,7 @@ public Runner fetch(String operation) { * * @param operation the string name of the operation * @return this session runner + * @see Graph#getOutput(String, int) */ public Runner fetch(String operation, int index) { Operation op = graph.operationOrError(operation); @@ -310,12 +312,17 @@ public Runner setOptions(RunOptions options) { return this; } + /** + * The result of a run in a session. Contains the fetched tensors and the outputs that were fetched. + *

+ * Closing a {@code Result} object will close all of the tensors contained by it. + */ public final class Result implements AutoCloseable, Iterable>{ private final List> results; private final List> fetches; private final LinkedHashMap, Integer> indexMap; - public Result(List> results, List> fetches) { + private Result(List> results, List> fetches) { this.results = new ArrayList<>(results); this.fetches = new ArrayList<>(fetches); indexMap = new LinkedHashMap<>(); @@ -377,9 +384,12 @@ public Tensor get(String output){ return get(graph.getOutput(output)); } + /** + * Close all of the tensors contained by this {@code Result}. + */ @Override public void close() throws Exception { - for(Tensor t : results){ + for(Tensor t : this){ t.close(); } } From 666f434c1564fb2b153dafb58d6e52d00226a07e Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 21:15:26 -0800 Subject: [PATCH 06/19] even better docs Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Graph.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java index 91497a9fb26..a176ffeb823 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Graph.java @@ -129,8 +129,8 @@ public GraphOperation operation(String name) { /** * Returns the operation (node in the Graph) with the provided name. - * - *

Or throws an {@code IllegalArgumentException} if no such operation exists in the Graph. + *

+ * Or throws an {@code IllegalArgumentException} if no such operation exists in the Graph. * * @param name name of the operation to look for * @return operation in the graph with this name From 91b63acab8d6486af760884144e1b34b3318d61c Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 7 Dec 2020 21:17:32 -0800 Subject: [PATCH 07/19] remove fixed todos Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 316a918ee3f..903336b7b9a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -417,14 +417,6 @@ public Spliterator> spliterator() { * the caller must call {@link Tensor#close} on all returned tensors or {@link Result#close()} to free up * resources. * - *

TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it - * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in - * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a - * {@code Map}? - * - *

TODO(andrewmyers): It would also be good if whatever is returned here made it easier to - * extract output tensors in a type-safe way. - * * @return a {@link Result} containing tensors fetched by this session runner */ public Result run() { From 6754d4a815c70b15d854df2b8f4cbe685fae40ec Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 01:19:06 -0800 Subject: [PATCH 08/19] Move Result out of Runner Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 203 +++++++++--------- 1 file changed, 105 insertions(+), 98 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 903336b7b9a..0293d9dc0f4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -143,6 +143,111 @@ public void close() { } } + /** + * The result of a run in a session. Contains the fetched tensors and the outputs that were fetched. + *

+ * Closing a {@code Result} object will close all of the tensors contained by it. + */ + public final class Result implements AutoCloseable, Iterable>{ + private final List> results; + private final List> fetches; + private final LinkedHashMap, Integer> indexMap; + + private Result(List> results, List> fetches) { + this.results = new ArrayList<>(results); + this.fetches = new ArrayList<>(fetches); + indexMap = new LinkedHashMap<>(); + for(int i = 0 ; i < fetches.size() ; i++){ + indexMap.put(fetches.get(i), i); + } + } + + /** + * Get the result tensors. + */ + public List> getResults() { + return results; + } + + /** + * Get the outputs that were fetched. + */ + public List> getFetches() { + return fetches; + } + + /** + * Get the result at {@code index}. + */ + public Tensor get(int index){ + return results.get(index); + } + + /** + * Get the result for {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + @SuppressWarnings("unchecked") + public Tensor get(Output output){ + if(!indexMap.containsKey(output)) + throw new IllegalArgumentException("Did not fetch an output for " + output); + return (Tensor) results.get(indexMap.get(output)); + } + + /** + * Get the result for {@code operand} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + public Tensor get(Operand operand){ + return get(operand.asOutput()); + } + + /** + * Get the result for the {@code index}-th output of {@code operation} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + public Tensor get(String operation, int index){ + return get(graph.getOutput(operation, index)); + } + + + /** + * Get the result for the output specified by {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. + */ + public Tensor get(String output){ + return get(graph.getOutput(output)); + } + + /** + * Close all of the tensors contained by this {@code Result}. + */ + @Override + public void close() { + for(Tensor t : this){ + t.close(); + } + } + + @Override + public Iterator> iterator() { + return results.iterator(); + } + + @Override + public void forEach(Consumer> action) { + results.forEach(action); + } + + @Override + public Spliterator> spliterator() { + return results.spliterator(); + } + + /** + * Return the number of tensors contained by this Result. + */ + public int size() { + return getResults().size(); + } + } + /** * Run {@link Operation}s and evaluate {@link Tensor Tensors}. * @@ -312,104 +417,6 @@ public Runner setOptions(RunOptions options) { return this; } - /** - * The result of a run in a session. Contains the fetched tensors and the outputs that were fetched. - *

- * Closing a {@code Result} object will close all of the tensors contained by it. - */ - public final class Result implements AutoCloseable, Iterable>{ - private final List> results; - private final List> fetches; - private final LinkedHashMap, Integer> indexMap; - - private Result(List> results, List> fetches) { - this.results = new ArrayList<>(results); - this.fetches = new ArrayList<>(fetches); - indexMap = new LinkedHashMap<>(); - for(int i = 0 ; i < fetches.size() ; i++){ - indexMap.put(fetches.get(i), i); - } - } - - /** - * Get the result tensors. - */ - public List> getResults() { - return results; - } - - /** - * Get the outputs that were fetched. - */ - public List> getFetches() { - return fetches; - } - - /** - * Get the result at {@code index}. - */ - public Tensor get(int index){ - return results.get(index); - } - - /** - * Get the result for {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. - */ - @SuppressWarnings("unchecked") - public Tensor get(Output output){ - if(!indexMap.containsKey(output)) - throw new IllegalArgumentException("Did not fetch an output for " + output); - return (Tensor) results.get(indexMap.get(output)); - } - - /** - * Get the result for {@code operand} or throw an {@code IllegalArgumentException} if it wasn't fetched. - */ - public Tensor get(Operand operand){ - return get(operand.asOutput()); - } - - /** - * Get the result for the {@code index}-th output of {@code operation} or throw an {@code IllegalArgumentException} if it wasn't fetched. - */ - public Tensor get(String operation, int index){ - return get(graph.getOutput(operation, index)); - } - - - /** - * Get the result for the output specified by {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. - */ - public Tensor get(String output){ - return get(graph.getOutput(output)); - } - - /** - * Close all of the tensors contained by this {@code Result}. - */ - @Override - public void close() throws Exception { - for(Tensor t : this){ - t.close(); - } - } - - @Override - public Iterator> iterator() { - return results.iterator(); - } - - @Override - public void forEach(Consumer> action) { - results.forEach(action); - } - - @Override - public Spliterator> spliterator() { - return results.spliterator(); - } - } - /** * Execute the graph fragments necessary to compute all requested fetches. * From c96ec0a190cc06eb9db0b29dbe543ef473d62ddf Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 01:22:10 -0800 Subject: [PATCH 09/19] fix more tests Signed-off-by: Ryan Nett --- .../java/org/tensorflow/DeviceSpecTest.java | 15 ++--- .../test/java/org/tensorflow/GraphTest.java | 61 +++++++++---------- .../test/java/org/tensorflow/SessionTest.java | 9 +-- .../org/tensorflow/op/core/ConstantTest.java | 15 ++--- .../org/tensorflow/op/core/GradientsTest.java | 18 +++--- .../org/tensorflow/op/core/ZerosTest.java | 4 +- 6 files changed, 52 insertions(+), 70 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index e4340da3275..88e77a4022a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -53,8 +53,7 @@ public void withDeviceMethod() { .abs(aOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { + try (Session.Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } @@ -85,8 +84,7 @@ public void withEmptyDeviceSpec() { .abs(aOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { + try (Session.Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } @@ -131,8 +129,7 @@ public void withTwoScopes() { .mul(absOps, bOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { + try (Session.Result t = session.runner().fetch(mulOps).run()) { assertEquals(10, ((TInt32)t.get(0)).getInt()); } } @@ -179,8 +176,7 @@ public void withIncorrectDeviceSpec() { .mul(absOps, bOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { + try (Session.Result t = session.runner().fetch(mulOps).run()) { fail(); } catch (TFInvalidArgumentException e) { // ok @@ -212,8 +208,7 @@ public void withDeviceSpecInScope() { .abs(aOps) .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { + try (Session.Result t = session.runner().fetch(absOps).run()) { assertEquals(1, ((TInt32)t.get(0)).getInt()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index d8ffc1a475b..0fad33302d5 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -32,7 +32,9 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.Graph}. */ +/** + * Unit tests for {@link org.tensorflow.Graph}. + */ public class GraphTest { @Test @@ -146,7 +148,7 @@ public void addGradientsToGraph() { Output y0 = tf.math.square(x1).y(); Output y1 = tf.math.square(y0).y(); Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum(); - + Output[] grads0 = g.addGradients(y1, toArray(x1)); assertNotNull(grads0); assertEquals(1, grads0.length); @@ -157,18 +159,17 @@ public void addGradientsToGraph() { assertEquals(2, grads1.length); assertEquals(DataType.DT_FLOAT, grads1[0].dataType()); assertEquals(DataType.DT_FLOAT, grads1[1].dataType()); - - try (TFloat32 c1 = TFloat32.scalarOf(3.0f); - TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = new AutoCloseableList<>( - s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run())) { - + + try (Tensor c1 = TFloat32.scalarOf(3.0f); + Tensor c2 = TFloat32.scalarOf(2.0f); + Session.Result outputs = s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run()) { + assertEquals(3, outputs.size()); assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); assertEquals(6.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); @@ -212,7 +213,7 @@ public void addGradientsWithInitialValuesToGraph() { Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); - + Output[] grad0 = g.addGradients(y1, toArray(y0)); assertNotNull(grad0); assertEquals(1, grad0.length); @@ -268,18 +269,18 @@ public void buildWhileLoopSingleInput() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output input = tf.placeholder(TInt32.class).output(); + Output input = tf.placeholder(TInt32.class).output(); @SuppressWarnings("unchecked") Output[] loopOutputs = g.whileLoop( toArray(input), (condGraph, condInputs, condOutputs) -> { Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); + condOutputs[0] = tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); }, (bodyGraph, bodyInputs, bodyOutputs) -> { Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); }, "test_loop"); @@ -300,8 +301,8 @@ public void buildWhileLoopMultipleInputs() { Session s = new Session(g)) { Ops tf = Ops.create(g); - Output input1 = tf.placeholder(TInt32.class).output(); - Output input2 = tf.placeholder(TInt32.class).output(); + Output input1 = tf.placeholder(TInt32.class).output(); + Output input2 = tf.placeholder(TInt32.class).output(); Output[] inputs = toArray(input1, input2); @SuppressWarnings("unchecked") @@ -309,25 +310,23 @@ public void buildWhileLoopMultipleInputs() { inputs, (condGraph, condInputs, condOutputs) -> { Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); + condOutputs[0] = tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); }, (bodyGraph, bodyInputs, bodyOutputs) -> { Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); - bodyOutputs[1] = tfb.math.square((Output)bodyInputs[1]).y(); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); + bodyOutputs[1] = tfb.math.square((Output) bodyInputs[1]).y(); }, "test_loop"); try (TInt32 c1 = TInt32.scalarOf(2); TInt32 c2 = TInt32.scalarOf(5); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() - .feed(input1, c1) - .feed(input2, c2) - .fetch(loopOutputs[0]) - .fetch(loopOutputs[1]) - .run())) { + Session.Result outputs = s.runner() + .feed(input1, c1) + .feed(input2, c2) + .fetch(loopOutputs[0]) + .fetch(loopOutputs[1]) + .run()) { assertEquals(2, outputs.size()); assertEquals(16, ((TInt32)outputs.get(0)).getInt()); // ((2^2)^2) assertEquals(625, ((TInt32)outputs.get(1)).getInt()); // ((5^2)^2) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index b1928bff51c..79e8720a856 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -49,8 +49,7 @@ public void runUsingOperationNames() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { + Session.Result outputs = s.runner().feed("X", x).fetch("Y").run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); } @@ -66,8 +65,7 @@ public void runUsingOperationHandles() { Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { + Session.Result outputs = s.runner().feed(feed, x).fetch(fetch).run()) { assertEquals(1, outputs.size()); assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); } @@ -131,8 +129,7 @@ public void runMultipleOutputs() { Ops tf = Ops.create(g); tf.withName("c1").constant(2718); tf.withName("c2").constant(31415); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); + Session.Result outputs = s.runner().fetch("c2").fetch("c1").run(); assertEquals(2, outputs.size()); assertEquals(31415, ((TInt32)outputs.get(0)).getInt()); assertEquals(2718, ((TInt32)outputs.get(1)).getInt()); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 5dd6903d913..b296681c72c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -60,8 +60,7 @@ public void createInts() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -79,8 +78,7 @@ public void createFloats() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -98,8 +96,7 @@ public void createDoubles() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -117,8 +114,7 @@ public void createLongs() { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } @@ -136,8 +132,7 @@ public void createStrings() throws IOException { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { + try (Session.Result t = sess.runner().fetch(op1).fetch(op2).run()) { assertEquals(array, t.get(0)); assertEquals(array, t.get(1)); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 80150b64bb6..59bf659ae48 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -21,7 +21,6 @@ import java.util.Arrays; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; @@ -48,9 +47,8 @@ public void createGradients() { assertEquals(2, grads.dy().size()); try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { + Session.Result outputs = + sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run()) { assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); @@ -74,9 +72,8 @@ public void createGradientsWithSum() { assertNotNull(grads.dy()); assertEquals(1, grads.dy().size()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + try (Tensor c = TFloat32.scalarOf(3.0f); + Session.Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) { assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } @@ -100,10 +97,9 @@ public void createGradientsWithInitialValues() { assertNotNull(grads1.dy()); assertEquals(1, grads1.dy().size()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { + try (Tensor c = TFloat32.scalarOf(3.0f); + Session.Result outputs = + sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) { assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index 4121baf3af1..8bbe98629c8 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -131,8 +131,8 @@ public void operationsComposingZerosAreCorrectlyNamed() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); - List results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); + Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.DTYPE); + Session.Result results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } } From 64b6f2e148c7fba61d38ef467f42b0a0f8242352 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 14:13:09 -0800 Subject: [PATCH 10/19] fix more tests that didn't show up the first time Signed-off-by: Ryan Nett --- .../java/org/tensorflow/framework/data/DatasetIteratorTest.java | 2 +- .../test/java/org/tensorflow/framework/data/MapDatasetTest.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 882a64ba54d..c96fe0b68e9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -53,7 +53,7 @@ public void testGraphIteration() { int batches = 0; while (true) { try { - List outputs = session.runner().fetch(x).fetch(y).run(); + Session.Result outputs = session.runner().fetch(x).fetch(y).run(); try (TInt32 xBatch = (TInt32)outputs.get(0); TInt32 yBatch = (TInt32)outputs.get(1)) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index 5f203427563..5e63db43716 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -78,7 +78,7 @@ public void testGraphIteration() { int batches = 0; while (true) { try { - List outputs = session.runner().fetch(X).fetch(y).run(); + Session.Result outputs = session.runner().fetch(X).fetch(y).run(); try (TInt32 XBatch = (TInt32)outputs.get(0); TInt32 yBatch = (TInt32)outputs.get(1)) { From 11934126ddeb417d7b39e9c103fa6bba940ec5a6 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 19:13:19 -0800 Subject: [PATCH 11/19] remove AutoCloseableList Signed-off-by: Ryan Nett --- .../org/tensorflow/AutoCloseableList.java | 27 ------------------- .../test/java/org/tensorflow/SessionTest.java | 7 +++-- .../org/tensorflow/op/core/ConstantTest.java | 1 - 3 files changed, 5 insertions(+), 30 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java deleted file mode 100644 index 330a40bae6b..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.tensorflow; - -import java.util.ArrayList; -import java.util.Collection; - -public final class AutoCloseableList extends ArrayList - implements AutoCloseable { - - public AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 79e8720a856..243b8a3fab1 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -24,6 +24,7 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; @@ -111,13 +112,15 @@ public void runWithMetadata() { .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); + List outputs = result.outputs; assertEquals(1, outputs.size()); assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); // Sanity check on metadata assertNotNull(result.metadata); assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); - outputs.close(); + for(Tensor output : outputs) { + output.close(); + } } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index b296681c72c..8955f0df5fe 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -20,7 +20,6 @@ import java.io.IOException; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.EagerSession; import org.tensorflow.Graph; import org.tensorflow.Session; From e95ae47f077ab9a5435a19b06988e175df807078 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 19:19:25 -0800 Subject: [PATCH 12/19] Add input size checking and closed checking Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 0293d9dc0f4..54a527fbb60 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -153,7 +153,15 @@ public final class Result implements AutoCloseable, Iterable>{ private final List> fetches; private final LinkedHashMap, Integer> indexMap; + private boolean closed = false; + private Result(List> results, List> fetches) { + + if(results.size() != fetches.size()){ + throw new IllegalArgumentException("Expected the same number of fetches and values, got " + fetches.size() + + " fetches and " + results.size() + " values."); + } + this.results = new ArrayList<>(results); this.fetches = new ArrayList<>(fetches); indexMap = new LinkedHashMap<>(); @@ -162,10 +170,17 @@ private Result(List> results, List> fetches) { } } + private void requireOpen(){ + if(closed) { + throw new IllegalStateException("Result has been closed, can not access it."); + } + } + /** * Get the result tensors. */ public List> getResults() { + requireOpen(); return results; } @@ -176,10 +191,18 @@ public List> getFetches() { return fetches; } + /** + * @return Whether the result has been closed. + */ + public boolean isClosed() { + return closed; + } + /** * Get the result at {@code index}. */ public Tensor get(int index){ + requireOpen(); return results.get(index); } @@ -188,6 +211,7 @@ public Tensor get(int index){ */ @SuppressWarnings("unchecked") public Tensor get(Output output){ + requireOpen(); if(!indexMap.containsKey(output)) throw new IllegalArgumentException("Did not fetch an output for " + output); return (Tensor) results.get(indexMap.get(output)); @@ -197,6 +221,7 @@ public Tensor get(Output output){ * Get the result for {@code operand} or throw an {@code IllegalArgumentException} if it wasn't fetched. */ public Tensor get(Operand operand){ + requireOpen(); return get(operand.asOutput()); } @@ -204,6 +229,7 @@ public Tensor get(Operand operand){ * Get the result for the {@code index}-th output of {@code operation} or throw an {@code IllegalArgumentException} if it wasn't fetched. */ public Tensor get(String operation, int index){ + requireOpen(); return get(graph.getOutput(operation, index)); } @@ -212,6 +238,7 @@ public Tensor get(String operation, int index){ * Get the result for the output specified by {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. */ public Tensor get(String output){ + requireOpen(); return get(graph.getOutput(output)); } @@ -220,23 +247,28 @@ public Tensor get(String output){ */ @Override public void close() { + requireOpen(); for(Tensor t : this){ t.close(); } + closed = true; } @Override public Iterator> iterator() { + requireOpen(); return results.iterator(); } @Override public void forEach(Consumer> action) { + requireOpen(); results.forEach(action); } @Override public Spliterator> spliterator() { + requireOpen(); return results.spliterator(); } From 701f7f065da76e52dc68ae12020b9a948276d3d4 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 19:26:31 -0800 Subject: [PATCH 13/19] add contains method, change map structure and add getter Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 55 ++++++++++++++++--- 1 file changed, 48 insertions(+), 7 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 54a527fbb60..bcc9a8ea6dc 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -27,6 +27,7 @@ import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; import java.util.Spliterator; import java.util.function.Consumer; import org.bytedeco.javacpp.BytePointer; @@ -151,7 +152,7 @@ public void close() { public final class Result implements AutoCloseable, Iterable>{ private final List> results; private final List> fetches; - private final LinkedHashMap, Integer> indexMap; + private final LinkedHashMap, Tensor> outputMap; private boolean closed = false; @@ -164,9 +165,9 @@ private Result(List> results, List> fetches) { this.results = new ArrayList<>(results); this.fetches = new ArrayList<>(fetches); - indexMap = new LinkedHashMap<>(); + outputMap = new LinkedHashMap<>(); for(int i = 0 ; i < fetches.size() ; i++){ - indexMap.put(fetches.get(i), i); + outputMap.put(fetches.get(i), results.get(i)); } } @@ -181,14 +182,21 @@ private void requireOpen(){ */ public List> getResults() { requireOpen(); - return results; + return new ArrayList<>(results); } /** * Get the outputs that were fetched. */ public List> getFetches() { - return fetches; + return new ArrayList<>(fetches); + } + + /** + * Get a map of the fetched outputs to their results. + */ + public Map, Tensor> getOutputMap(){ + return new LinkedHashMap<>(outputMap); } /** @@ -212,9 +220,9 @@ public Tensor get(int index){ @SuppressWarnings("unchecked") public Tensor get(Output output){ requireOpen(); - if(!indexMap.containsKey(output)) + if(!outputMap.containsKey(output)) throw new IllegalArgumentException("Did not fetch an output for " + output); - return (Tensor) results.get(indexMap.get(output)); + return (Tensor) outputMap.get(output); } /** @@ -242,6 +250,39 @@ public Tensor get(String output){ return get(graph.getOutput(output)); } + /** + * Returns {@code true} if {@code output} was fetched as part of this {@code Result}. + */ + public boolean contains(Output output){ + requireOpen(); + return outputMap.containsKey(output); + } + + /** + * Returns {@code true} if {@code operand} was fetched as part of this {@code Result}. + */ + public boolean contains(Operand operand){ + requireOpen(); + return contains(operand.asOutput()); + } + + /** + * Returns {@code true} if the {@code index}-th output of {@code operation} was fetched as part of this {@code Result}. + */ + public boolean contains(String operation, int index){ + requireOpen(); + return contains(graph.getOutput(operation, index)); + } + + + /** + * Returns {@code true} the output specified by {@code output} was fetched as part of this {@code Result} + */ + public boolean contains(String output){ + requireOpen(); + return contains(graph.getOutput(output)); + } + /** * Close all of the tensors contained by this {@code Result}. */ From 0a771b3d6446cf2226318124815209c54a22d64d Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 19:36:41 -0800 Subject: [PATCH 14/19] add isClosed to tensor, use to prevent double close Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index bcc9a8ea6dc..d706c6f04fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -284,13 +284,15 @@ public boolean contains(String output){ } /** - * Close all of the tensors contained by this {@code Result}. + * Close any open tensors contained by this {@code Result}. */ @Override public void close() { requireOpen(); for(Tensor t : this){ - t.close(); + if(!t.isClosed()) { + t.close(); + } } closed = true; } From 598da6292dd43d2174980b1f82c44961ee56d9f7 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 8 Dec 2020 19:43:40 -0800 Subject: [PATCH 15/19] fold Run into Result Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 30 ++++++++++++++----- .../test/java/org/tensorflow/SessionTest.java | 13 ++++---- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index d706c6f04fa..789ca000b0a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -154,15 +154,25 @@ public final class Result implements AutoCloseable, Iterable>{ private final List> fetches; private final LinkedHashMap, Tensor> outputMap; + /** + * Metadata about the run. + * + *

A RunMetadata + * protocol buffer. + */ + private final RunMetadata metadata; + private boolean closed = false; - private Result(List> results, List> fetches) { + private Result(List> results, List> fetches, RunMetadata metadata) { if(results.size() != fetches.size()){ throw new IllegalArgumentException("Expected the same number of fetches and values, got " + fetches.size() + " fetches and " + results.size() + " values."); } + this.metadata = metadata; this.results = new ArrayList<>(results); this.fetches = new ArrayList<>(fetches); outputMap = new LinkedHashMap<>(); @@ -199,6 +209,13 @@ public Map, Tensor> getOutputMap(){ return new LinkedHashMap<>(outputMap); } + /** + * Get the run metadata. May be null if not requested. + */ + public RunMetadata getMetadata() { + return metadata; + } + /** * @return Whether the result has been closed. */ @@ -502,7 +519,7 @@ public Runner setOptions(RunOptions options) { * @return a {@link Result} containing tensors fetched by this session runner */ public Result run() { - return new Result(runHelper(false).outputs, outputs); + return runHelper(false); } /** @@ -515,11 +532,11 @@ public Result run() { * * @return list of resulting tensors fetched by this session runner, with execution metadata */ - public Run runAndFetchMetadata() { + public Result runAndFetchMetadata() { return runHelper(true); } - private Run runHelper(boolean wantMetadata) { + private Result runHelper(boolean wantMetadata) { TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()]; TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()]; int[] inputOpIndices = new int[inputs.size()]; @@ -574,10 +591,7 @@ private Run runHelper(boolean wantMetadata) { } finally { runRef.close(); } - Run ret = new Run(); - ret.outputs = outputs; - ret.metadata = metadata; - return ret; + return new Result(outputs, this.outputs, metadata); } private class Reference implements AutoCloseable { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index 243b8a3fab1..d48093dee8f 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -106,21 +106,18 @@ public void runWithMetadata() { Ops tf = Ops.create(g); transpose_A_times_X(tf, new int[][] {{2}, {3}}); try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { - Session.Run result = s.runner() + Session.Result result = s.runner() .feed("X", x) .fetch("Y") .setOptions(fullTraceRunOptions()) .runAndFetchMetadata(); // Sanity check on outputs. - List outputs = result.outputs; - assertEquals(1, outputs.size()); + assertEquals(1, result.size()); assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); // Sanity check on metadata - assertNotNull(result.metadata); - assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); - for(Tensor output : outputs) { - output.close(); - } + assertNotNull(result.getMetadata()); + assertTrue(result.getMetadata().hasStepStats(), result.getMetadata().toString()); + result.close(); } } } From f67d6e2c02d6945bcec7a69af75f1e4906bbbd9f Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 9 Dec 2020 16:56:42 -0800 Subject: [PATCH 16/19] Use Collections.unmodifiable* Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 789ca000b0a..4beebb7ec80 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -24,6 +24,7 @@ import com.google.protobuf.InvalidProtocolBufferException; import java.util.ArrayList; +import java.util.Collections; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; @@ -173,8 +174,8 @@ private Result(List> results, List> fetches, RunMetadata met } this.metadata = metadata; - this.results = new ArrayList<>(results); - this.fetches = new ArrayList<>(fetches); + this.results = results; + this.fetches = fetches; outputMap = new LinkedHashMap<>(); for(int i = 0 ; i < fetches.size() ; i++){ outputMap.put(fetches.get(i), results.get(i)); @@ -192,21 +193,21 @@ private void requireOpen(){ */ public List> getResults() { requireOpen(); - return new ArrayList<>(results); + return Collections.unmodifiableList(results); } /** * Get the outputs that were fetched. */ public List> getFetches() { - return new ArrayList<>(fetches); + return Collections.unmodifiableList(fetches); } /** * Get a map of the fetched outputs to their results. */ public Map, Tensor> getOutputMap(){ - return new LinkedHashMap<>(outputMap); + return Collections.unmodifiableMap(outputMap); } /** @@ -591,7 +592,7 @@ private Result runHelper(boolean wantMetadata) { } finally { runRef.close(); } - return new Result(outputs, this.outputs, metadata); + return new Result(outputs, new ArrayList<>(this.outputs), metadata); } private class Reference implements AutoCloseable { From 22f0a05ea389a6a2256c71839e076f888766c3a8 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 27 Dec 2020 18:59:15 -0800 Subject: [PATCH 17/19] rebase fixes Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/RawTensor.java | 12 +++++++ .../src/main/java/org/tensorflow/Session.java | 32 +++++++++---------- .../src/main/java/org/tensorflow/Tensor.java | 5 +++ .../org/tensorflow/types/family/TType.java | 5 +++ .../test/java/org/tensorflow/GraphTest.java | 4 +-- .../test/java/org/tensorflow/SessionTest.java | 2 +- .../org/tensorflow/op/core/GradientsTest.java | 4 +-- .../org/tensorflow/op/core/ZerosTest.java | 2 +- 8 files changed, 44 insertions(+), 22 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index c332fd7f1d1..dfc61ea12c2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -65,7 +65,18 @@ public RawTensor asRawTensor() { @Override public void close() { + if(closed) { + throw new IllegalStateException("Tensor has already been closed"); + } tensorScope.close(); + closed = true; + } + + /** + * @return {@code true} if this tensor has been closed; + */ + public boolean isClosed() { + return closed; } /** @@ -222,6 +233,7 @@ private static long[] shape(TF_Tensor handle) { } private PointerScope tensorScope; + private boolean closed = false; private TF_Tensor tensorHandle; private final TensorTypeInfo typeInfo; private final Shape shape; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 4beebb7ec80..56f5a1564e1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -150,10 +150,10 @@ public void close() { *

* Closing a {@code Result} object will close all of the tensors contained by it. */ - public final class Result implements AutoCloseable, Iterable>{ - private final List> results; + public final class Result implements AutoCloseable, Iterable{ + private final List results; private final List> fetches; - private final LinkedHashMap, Tensor> outputMap; + private final LinkedHashMap, Tensor> outputMap; /** * Metadata about the run. @@ -166,7 +166,7 @@ public final class Result implements AutoCloseable, Iterable>{ private boolean closed = false; - private Result(List> results, List> fetches, RunMetadata metadata) { + private Result(List results, List> fetches, RunMetadata metadata) { if(results.size() != fetches.size()){ throw new IllegalArgumentException("Expected the same number of fetches and values, got " + fetches.size() @@ -191,7 +191,7 @@ private void requireOpen(){ /** * Get the result tensors. */ - public List> getResults() { + public List getResults() { requireOpen(); return Collections.unmodifiableList(results); } @@ -206,7 +206,7 @@ public List> getFetches() { /** * Get a map of the fetched outputs to their results. */ - public Map, Tensor> getOutputMap(){ + public Map, Tensor> getOutputMap(){ return Collections.unmodifiableMap(outputMap); } @@ -227,7 +227,7 @@ public boolean isClosed() { /** * Get the result at {@code index}. */ - public Tensor get(int index){ + public Tensor get(int index){ requireOpen(); return results.get(index); } @@ -236,17 +236,17 @@ public Tensor get(int index){ * Get the result for {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. */ @SuppressWarnings("unchecked") - public Tensor get(Output output){ + public T get(Output output){ requireOpen(); if(!outputMap.containsKey(output)) throw new IllegalArgumentException("Did not fetch an output for " + output); - return (Tensor) outputMap.get(output); + return (T) outputMap.get(output); } /** * Get the result for {@code operand} or throw an {@code IllegalArgumentException} if it wasn't fetched. */ - public Tensor get(Operand operand){ + public T get(Operand operand){ requireOpen(); return get(operand.asOutput()); } @@ -254,7 +254,7 @@ public Tensor get(Operand operand){ /** * Get the result for the {@code index}-th output of {@code operation} or throw an {@code IllegalArgumentException} if it wasn't fetched. */ - public Tensor get(String operation, int index){ + public Tensor get(String operation, int index){ requireOpen(); return get(graph.getOutput(operation, index)); } @@ -263,7 +263,7 @@ public Tensor get(String operation, int index){ /** * Get the result for the output specified by {@code output} or throw an {@code IllegalArgumentException} if it wasn't fetched. */ - public Tensor get(String output){ + public Tensor get(String output){ requireOpen(); return get(graph.getOutput(output)); } @@ -307,7 +307,7 @@ public boolean contains(String output){ @Override public void close() { requireOpen(); - for(Tensor t : this){ + for(Tensor t : this){ if(!t.isClosed()) { t.close(); } @@ -316,19 +316,19 @@ public void close() { } @Override - public Iterator> iterator() { + public Iterator iterator() { requireOpen(); return results.iterator(); } @Override - public void forEach(Consumer> action) { + public void forEach(Consumer action) { requireOpen(); results.forEach(action); } @Override - public Spliterator> spliterator() { + public Spliterator spliterator() { requireOpen(); return results.spliterator(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index fc1275229bf..5294d902685 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -212,4 +212,9 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData */ @Override void close(); + + /** + * @return {@code true} if this tensor has been closed. + */ + boolean isClosed(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index 2fc423b914e..0545b5a794d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -80,4 +80,9 @@ default long numBytes() { default void close() { asRawTensor().close(); } + + @Override + default boolean isClosed(){ + return asRawTensor().isClosed(); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index 0fad33302d5..32f8cb4d18b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -160,8 +160,8 @@ public void addGradientsToGraph() { assertEquals(DataType.DT_FLOAT, grads1[0].dataType()); assertEquals(DataType.DT_FLOAT, grads1[1].dataType()); - try (Tensor c1 = TFloat32.scalarOf(3.0f); - Tensor c2 = TFloat32.scalarOf(2.0f); + try (TFloat32 c1 = TFloat32.scalarOf(3.0f); + TFloat32 c2 = TFloat32.scalarOf(2.0f); Session.Result outputs = s.runner() .feed(x1, c1) .feed(x2, c2) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index d48093dee8f..c1ec11f89f8 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -113,7 +113,7 @@ public void runWithMetadata() { .runAndFetchMetadata(); // Sanity check on outputs. assertEquals(1, result.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); + assertEquals(31, ((TInt32)result.get(0)).getInt(0, 0)); // Sanity check on metadata assertNotNull(result.getMetadata()); assertTrue(result.getMetadata().hasStepStats(), result.getMetadata().toString()); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 59bf659ae48..1a65ca90d34 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -72,7 +72,7 @@ public void createGradientsWithSum() { assertNotNull(grads.dy()); assertEquals(1, grads.dy().size()); - try (Tensor c = TFloat32.scalarOf(3.0f); + try (TFloat32 c = TFloat32.scalarOf(3.0f); Session.Result outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run()) { assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); @@ -97,7 +97,7 @@ public void createGradientsWithInitialValues() { assertNotNull(grads1.dy()); assertEquals(1, grads1.dy().size()); - try (Tensor c = TFloat32.scalarOf(3.0f); + try (TFloat32 c = TFloat32.scalarOf(3.0f); Session.Result outputs = sess.runner().feed(x, c).fetch(grads1.dy(0)).run()) { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index 8bbe98629c8..ef83f0117b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -131,7 +131,7 @@ public void operationsComposingZerosAreCorrectlyNamed() { Session sess = new Session(g)) { Scope scope = new Scope(g); long[] shape = {2, 2}; - Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.DTYPE); + Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); Session.Result results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); } } From 58ff9350002352cf16f7cc68eefa893190cceb9e Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 27 Dec 2020 19:02:52 -0800 Subject: [PATCH 18/19] docs update Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 56f5a1564e1..b2d34076292 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -410,7 +410,7 @@ public Runner feed(Operand operand, Tensor t) { * the {@code SignatureDef} protocol buffer messages that are included in {@link * SavedModelBundle#metaGraphDef()}. * @return this session runner - * @see Graph#getOutput(String, int) + * @see Graph#getOutput(String) */ public Runner fetch(String operation) { return fetch(graph.getOutput(operation)); @@ -459,6 +459,7 @@ public Runner fetch(Operand operand) { * * @param operation the string name of the operation to execute * @return this session runner + * @see Graph#operationOrError(String) */ public Runner addTarget(String operation) { GraphOperation op = graph.operationOrError(operation); From 3dc6defe2bdbcb9d484b68f013cc87b976afac33 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Tue, 29 Dec 2020 17:57:08 -0800 Subject: [PATCH 19/19] not sure why this order is swapped Signed-off-by: Ryan Nett --- .../src/gen/annotations/org/tensorflow/op/Ops.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index d6e69085324..5a8a9bace8a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -347,10 +347,10 @@ public final class Ops { public final SignalOps signal; - public final QuantizationOps quantization; - public final TrainOps train; + public final QuantizationOps quantization; + private final Scope scope; private Ops(Scope scope) { @@ -372,8 +372,8 @@ private Ops(Scope scope) { math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - quantization = new QuantizationOps(this); train = new TrainOps(this); + quantization = new QuantizationOps(this); } /**