From 998677d9bad8a55ac1dbc2e050aec40e08a0bd8e Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 8 Jan 2021 17:20:24 -0800 Subject: [PATCH 01/35] Start of TensorScope Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/RawTensor.java | 59 ++++++++++---- .../src/main/java/org/tensorflow/Tensor.java | 19 +++++ .../main/java/org/tensorflow/TensorScope.java | 80 +++++++++++++++++++ 3 files changed, 143 insertions(+), 15 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index c332fd7f1d1..84238144ef2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -34,9 +34,9 @@ * A tensor which memory has not been mapped to a data space directly accessible from the JVM. * *

A raw tensor is a minimalist representation of a tensor allocated in native memory by the - * TensorFlow runtime library and it controls its lifetime within the current process. The data - * is represented by a flat {@link ByteDataBuffer buffer of bytes}, until it is mapped in a - * n-dimensional typed space by a {@link TType typed tensor}.

+ * TensorFlow runtime library and it controls its lifetime within the current process. The data is represented by a flat + * {@link ByteDataBuffer buffer of bytes}, until it is mapped in a n-dimensional typed space by a {@link TType typed + * tensor}.

* *

Instances of a RawTensor are not thread-safe and their resource must be released * by calling {@link #close()} explicitly or implicitly via try-with-resources.

@@ -65,7 +65,30 @@ public RawTensor asRawTensor() { @Override public void close() { - tensorScope.close(); + if (!closed) { + tensorScope.close(); + closed = true; + } + } + + @Override + public boolean isClosed() { + return closed; + } + + @Override + public void detach() { + TensorScope.detach(this); + } + + @Override + public boolean attachToCurrentScope() { + TensorScope currentScope = TensorScope.getInnerScope(); + if (currentScope != null) { + currentScope.attach(this); + return true; + } + return false; } /** @@ -93,19 +116,19 @@ public String toString() { * Allocates a new tensor in native memory of the given type, shape and size. * *

The size of the tensor must be at least large enough to contain all scalars for the - * given type and shape. More memory can also be allocated to store also metadata within the - * tensor itself, e.g. a lookup table in a string tensor. + * given type and shape. More memory can also be allocated to store also metadata within the tensor itself, e.g. a + * lookup table in a string tensor. * * @param type tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor, or -1 to compute the size from the shape * @return allocated tensor - * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to - * store the tensor data - * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given - * {@code type} are of variable length (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to store the tensor + * data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code type} are of + * variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension() + * unknown} * @throws IllegalStateException if tensor failed to be allocated */ static RawTensor allocate(Class type, Shape shape, long size) { @@ -147,9 +170,9 @@ static RawTensor fromHandle(TF_Tensor handle) { TensorTypeInfo typeInfo = TensorTypeRegistry.find(DataType.forNumber(dtype(handle))); RawTensor t = new RawTensor(typeInfo, Shape.of(shape(handle))); try (PointerScope scope = new PointerScope()) { - scope.attach(handle); - t.tensorHandle = handle; - t.tensorScope = scope.extend(); + scope.attach(handle); + t.tensorHandle = handle; + t.tensorScope = scope.extend(); } return t; } @@ -168,6 +191,7 @@ static RawTensor fromHandle(TF_Tensor handle, EagerSession session) { /** * Returns the native handle to this tensor + * * @throws IllegalStateException if tensor has been closed */ TF_Tensor nativeHandle() { @@ -219,9 +243,14 @@ private static long[] shape(TF_Tensor handle) { RawTensor(TensorTypeInfo typeInfo, Shape shape) { this.typeInfo = typeInfo; this.shape = shape; + TensorScope currentScope = TensorScope.getInnerScope(); + if (currentScope != null) { + currentScope.attach(this); + } } private PointerScope tensorScope; + private boolean closed; private TF_Tensor tensorHandle; private final TensorTypeInfo typeInfo; private final Shape shape; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index fc1275229bf..7eaedb76dbf 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -212,4 +212,23 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData */ @Override void close(); + + /** + * Get whether this tensor has been closed. + */ + boolean isClosed(); + + /** + * Detach this tensor from any scopes managing it. It must be manually closed or attached to another scope. + */ + default void detach(){ + asRawTensor().detach(); + } + + /** + * Attach this tensor to the current scope. No-ops and returns false if there is no current scope. + */ + default boolean attachToCurrentScope(){ + return asRawTensor().attachToCurrentScope(); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java new file mode 100644 index 00000000000..04898cdd1ef --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -0,0 +1,80 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.Set; +import org.bytedeco.javacpp.PointerScope; +import org.tensorflow.ndarray.NdArray; + +public class TensorScope implements AutoCloseable{ + + static final ThreadLocal> scopeStack = ThreadLocal.withInitial(ArrayDeque::new); + + /** Returns {@code scopeStack.get().peek()}, the last opened scope not yet closed. */ + public static TensorScope getInnerScope() { + return scopeStack.get().peek(); + } + + /** Returns {@code scopeStack.get().iterator()}, all scopes not yet closed. */ + public static Iterator getScopeIterator() { + return scopeStack.get().iterator(); + } + + /** + * Detaches the given tensor from any scopes managing it, requiring it to be manually closed. + */ + public static void detach(Tensor t){ + RawTensor raw = t.asRawTensor(); + getScopeIterator().forEachRemaining(scope -> scope.detachTensor(raw)); + } + + public TensorScope(){ + scopeStack.get().push(this); + } + + /** + * Attach a tensor to this scope. This happens automatically to tensors that are created in the scope. + */ + public void attach(Tensor t){ + tensors.add(t.asRawTensor()); + } + + + /** + * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + */ + public void attach(Tensor... tensors){ + for(Tensor t : tensors){ + attach(t); + } + } + + private void detachTensor(Tensor t){ + tensors.remove(t.asRawTensor()); + } + + @Override + public void close() throws Exception { + tensors.forEach(Tensor::close); + } + + private final Set tensors = new LinkedHashSet<>(); +} From 591d44d39d52329268c5a190360ba66364114718 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 13 Jan 2021 16:00:06 -0800 Subject: [PATCH 02/35] Finish TensorScope, add test Signed-off-by: Ryan Nett --- .../java/org/tensorflow/EagerOperation.java | 2 +- .../main/java/org/tensorflow/RawTensor.java | 26 ++-- .../src/main/java/org/tensorflow/Tensor.java | 13 +- .../main/java/org/tensorflow/TensorScope.java | 124 ++++++++++++++---- .../org/tensorflow/types/family/TType.java | 20 +++ .../java/org/tensorflow/TensorScopeTest.java | 98 ++++++++++++++ 6 files changed, 235 insertions(+), 48 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 9f87fd8b95e..07691734db4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -166,7 +166,7 @@ private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession TF_Status status = TF_Status.newStatus(); TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(); status.throwExceptionIfNotOK(); - return RawTensor.fromHandle(tensor, session).asTypedTensor(); + return RawTensor.fromHandle(tensor).asTypedTensor(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index 84238144ef2..b864ccb64c1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -81,9 +81,14 @@ public void detach() { TensorScope.detach(this); } + @Override + public boolean isAttached() { + return attached; + } + @Override public boolean attachToCurrentScope() { - TensorScope currentScope = TensorScope.getInnerScope(); + TensorScope currentScope = TensorScope.getCurrentScope(); if (currentScope != null) { currentScope.attach(this); return true; @@ -177,18 +182,6 @@ static RawTensor fromHandle(TF_Tensor handle) { return t; } - /** - * Create an eager Tensor object from a handle to the C TF_Tensor object. - * - *

Takes ownership of the handle. - */ - static RawTensor fromHandle(TF_Tensor handle, EagerSession session) { - RawTensor t = fromHandle(handle); - session.attach(handle); - t.tensorScope.detach(handle); - return t; - } - /** * Returns the native handle to this tensor * @@ -243,14 +236,15 @@ private static long[] shape(TF_Tensor handle) { RawTensor(TensorTypeInfo typeInfo, Shape shape) { this.typeInfo = typeInfo; this.shape = shape; - TensorScope currentScope = TensorScope.getInnerScope(); - if (currentScope != null) { - currentScope.attach(this); + TensorScope scope = TensorScope.getCurrentScope(); + if (scope != null) { + scope.attach(this); } } private PointerScope tensorScope; private boolean closed; + boolean attached = false; private TF_Tensor tensorHandle; private final TensorTypeInfo typeInfo; private final Shape shape; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 7eaedb76dbf..c38a00626ad 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -221,14 +221,15 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData /** * Detach this tensor from any scopes managing it. It must be manually closed or attached to another scope. */ - default void detach(){ - asRawTensor().detach(); - } + void detach(); + + /** + * Returns true if this tensor is attached to a {@link TensorScope}. + */ + boolean isAttached(); /** * Attach this tensor to the current scope. No-ops and returns false if there is no current scope. */ - default boolean attachToCurrentScope(){ - return asRawTensor().attachToCurrentScope(); - } + boolean attachToCurrentScope(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 04898cdd1ef..b56482af2fb 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -17,64 +17,138 @@ package org.tensorflow; import java.util.ArrayDeque; +import java.util.Collections; import java.util.Deque; -import java.util.Iterator; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.Set; -import org.bytedeco.javacpp.PointerScope; -import org.tensorflow.ndarray.NdArray; +import java.util.WeakHashMap; + + +/** + * A scope that can be used to manage tensor resources. Any tensors created between a scope's creation and calling + * {@code close()}, that haven't been detached, are guaranteed to be closed with the scope (even if they are created in + * a sub-scope). Tensors may be manually closed earlier without issue. + *

+ * Tensors are automatically tracked on creation. A tensor can me manually added to a scope with {@link + * TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}, or by passing them to {@link + * TensorScope#TensorScope(Tensor...)}. The tensor will then be closed when the first of it's managing scopes closes. + *

+ * {@link Tensor#detach()} detaches the tensor from all scopes, requiring the user to close it manually or attach it to + * another scope. + *

+ * Note that scope management is thread local, except for detach, which will detach even from scopes on other threads. + */ +public class TensorScope implements AutoCloseable { -public class TensorScope implements AutoCloseable{ + private static final Set allScopes = Collections.newSetFromMap(new WeakHashMap<>()); - static final ThreadLocal> scopeStack = ThreadLocal.withInitial(ArrayDeque::new); + private static final ThreadLocal> scopeStack = ThreadLocal.withInitial(ArrayDeque::new); - /** Returns {@code scopeStack.get().peek()}, the last opened scope not yet closed. */ - public static TensorScope getInnerScope() { + /** + * Returns {@code scopeStack.get().peek()}, the last opened scope not yet closed on this thread. + */ + static TensorScope getCurrentScope() { return scopeStack.get().peek(); } - /** Returns {@code scopeStack.get().iterator()}, all scopes not yet closed. */ - public static Iterator getScopeIterator() { - return scopeStack.get().iterator(); - } - /** * Detaches the given tensor from any scopes managing it, requiring it to be manually closed. */ - public static void detach(Tensor t){ + public static void detach(Tensor t) { RawTensor raw = t.asRawTensor(); - getScopeIterator().forEachRemaining(scope -> scope.detachTensor(raw)); + synchronized (TensorScope.class) { + allScopes.forEach(x -> x.detachTensor(raw)); + } + raw.attached = false; + } + + /** + * Create a new tensor scope with the given thread locality. + */ + public TensorScope() { + localScopeStack = scopeStack.get(); + + synchronized (TensorScope.class) { + allScopes.add(this); + } + localScopeStack.push(this); } - public TensorScope(){ - scopeStack.get().push(this); + /** + * Create a new tensor scope with the given thread locality, and attach the given tensors. + */ + public TensorScope(Tensor... tensors) { + this(); + attach(tensors); } /** * Attach a tensor to this scope. This happens automatically to tensors that are created in the scope. */ - public void attach(Tensor t){ - tensors.add(t.asRawTensor()); + public void attach(Tensor t) { + RawTensor rt = t.asRawTensor(); + rt.attached = true; + tensors.add(rt); } /** * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. */ - public void attach(Tensor... tensors){ - for(Tensor t : tensors){ - attach(t); + public void attach(Tensor... tensors) { + if (tensors != null) { + for (Tensor t : tensors) { + attach(t); + } } } - private void detachTensor(Tensor t){ + private void detachTensor(Tensor t) { tensors.remove(t.asRawTensor()); } - @Override - public void close() throws Exception { + private void closeScope() { tensors.forEach(Tensor::close); + + synchronized (TensorScope.class) { + allScopes.remove(this); + } + + closed = true; + } + + /** + * Closes this scope and its tensors, and any inner scopes. + */ + @Override + public void close() { + if (closed) { + return; + } + + if (!localScopeStack.contains(this)) { + throw new IllegalStateException("This scope is not on the scope stack, but was not closed." + + " This should not be possible."); + } + + while (true) { + TensorScope ts = localScopeStack.removeLast(); + ts.closeScope(); + if (ts == this) { + return; + } + } + } + + /** + * Gets whether the scope is closed. + */ + public boolean isClosed() { + return closed; } - private final Set tensors = new LinkedHashSet<>(); + private boolean closed = false; + private final Set tensors = new HashSet<>(); + private final Deque localScopeStack; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index 2fc423b914e..a2bffe29575 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -80,4 +80,24 @@ default long numBytes() { default void close() { asRawTensor().close(); } + + @Override + default boolean isClosed(){ + return asRawTensor().isClosed(); + } + + @Override + default void detach(){ + asRawTensor().detach(); + } + + @Override + default boolean isAttached(){ + return asRawTensor().isAttached(); + } + + @Override + default boolean attachToCurrentScope(){ + return asRawTensor().attachToCurrentScope(); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java new file mode 100644 index 00000000000..52cbb73564c --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -0,0 +1,98 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.TFloat32; + +/** + * Unit tests for {@link TensorScope} + */ +public class TensorScopeTest { + + private static TFloat32 makeTensor(long size) { + return TFloat32.tensorOf(Shape.of(size), x -> { + for (long i = 0; i < size; i++) { + x.setFloat(0, i); + } + }); + } + + @Test + public void testBasicScope() { + TensorScope scope = new TensorScope(); + + TFloat32 tensor = makeTensor(10); + TFloat32 detachTensor = makeTensor(10); + detachTensor.detach(); + + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + assertFalse(detachTensor.isAttached()); + assertFalse(detachTensor.isClosed()); + + scope.close(); + + assertTrue(tensor.isClosed()); + assertTrue(scope.isClosed()); + assertFalse(detachTensor.isClosed()); + } + + @Test + public void testNestedScope() { + TensorScope outerScope = new TensorScope(); + TensorScope scope = new TensorScope(); + + TFloat32 tensor = makeTensor(10); + TFloat32 detachTensor = makeTensor(10); + detachTensor.detach(); + + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + assertFalse(detachTensor.isAttached()); + assertFalse(detachTensor.isClosed()); + + outerScope.close(); + + assertTrue(tensor.isClosed()); + assertTrue(scope.isClosed()); + assertTrue(outerScope.isClosed()); + assertFalse(detachTensor.isClosed()); + } + + @Test + public void testAttach(){ + TensorScope firstScope = new TensorScope(); + TFloat32 tensor = makeTensor(10); + TensorScope secondScope = new TensorScope(tensor); + + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + secondScope.close(); + + assertTrue(tensor.isClosed()); + } + + +} From 34e1429893d6e876fd262ba5afbe4972b6d2c81d Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 13 Jan 2021 16:03:26 -0800 Subject: [PATCH 03/35] Javadoc updates Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/TensorScope.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index b56482af2fb..b98ca80c75d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -65,6 +65,7 @@ public static void detach(Tensor t) { /** * Create a new tensor scope with the given thread locality. + * @see TensorScope */ public TensorScope() { localScopeStack = scopeStack.get(); @@ -77,6 +78,7 @@ public TensorScope() { /** * Create a new tensor scope with the given thread locality, and attach the given tensors. + * @see TensorScope */ public TensorScope(Tensor... tensors) { this(); From 700f0f8ce23cf74185428f2ad140cf188fb21e34 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 13 Jan 2021 16:07:00 -0800 Subject: [PATCH 04/35] Add a TensorScope to EagerSession to replicate pointer attachment Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/EagerSession.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 8e7465388a8..192b1437a41 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -361,6 +361,7 @@ void detach(Pointer... resources) { private final WeakPointerScope nativeResources; private TFE_Context nativeHandle; + private final TensorScope tensorScope = new TensorScope(); private EagerSession(Options options) { this.nativeResources = new WeakPointerScope(); @@ -374,6 +375,7 @@ private void checkSession() { } private synchronized void doClose() { + tensorScope.close(); if (nativeHandle != null && !nativeHandle.isNull()) { nativeResources.close(); delete(nativeHandle); From 8231407c1ae5ab64cd3b2ae4ebb15fd7a0598909 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 13 Jan 2021 16:18:41 -0800 Subject: [PATCH 05/35] Make auto-attach optional, use TensorScope in eager session to close tensors. Signed-off-by: Ryan Nett --- .../java/org/tensorflow/EagerOperation.java | 5 +- .../java/org/tensorflow/EagerSession.java | 6 ++- .../main/java/org/tensorflow/RawTensor.java | 12 +---- .../src/main/java/org/tensorflow/Tensor.java | 3 +- .../main/java/org/tensorflow/TensorScope.java | 54 +++++++++++++++---- 5 files changed, 56 insertions(+), 24 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 07691734db4..7302ffaa4d9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -30,6 +30,7 @@ import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.proto.framework.DataType; +import org.tensorflow.types.family.TType; /** * Implementation of an {@link Operation} executed eagerly. @@ -166,7 +167,9 @@ private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession TF_Status status = TF_Status.newStatus(); TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(); status.throwExceptionIfNotOK(); - return RawTensor.fromHandle(tensor).asTypedTensor(); + TType typedTensor = RawTensor.fromHandle(tensor).asTypedTensor(); + session.attachTensor(typedTensor); + return typedTensor; } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index 192b1437a41..b363d40d94e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -357,11 +357,15 @@ void detach(Pointer... resources) { } } + void attachTensor(Tensor tensor){ + tensorScope.attach(tensor); + } + private static volatile EagerSession defaultSession = null; private final WeakPointerScope nativeResources; private TFE_Context nativeHandle; - private final TensorScope tensorScope = new TensorScope(); + private final TensorScope tensorScope = new TensorScope(false); private EagerSession(Options options) { this.nativeResources = new WeakPointerScope(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index b864ccb64c1..0dafad93581 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -88,12 +88,7 @@ public boolean isAttached() { @Override public boolean attachToCurrentScope() { - TensorScope currentScope = TensorScope.getCurrentScope(); - if (currentScope != null) { - currentScope.attach(this); - return true; - } - return false; + return TensorScope.autoAttach(this); } /** @@ -236,10 +231,7 @@ private static long[] shape(TF_Tensor handle) { RawTensor(TensorTypeInfo typeInfo, Shape shape) { this.typeInfo = typeInfo; this.shape = shape; - TensorScope scope = TensorScope.getCurrentScope(); - if (scope != null) { - scope.attach(this); - } + TensorScope.autoAttach(this); } private PointerScope tensorScope; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index c38a00626ad..3930e2cca91 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -229,7 +229,8 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData boolean isAttached(); /** - * Attach this tensor to the current scope. No-ops and returns false if there is no current scope. + * Attach this tensor to the most recent scope that accepts automatic attachment. + * No-ops and returns false if there is no scope that does so. */ boolean attachToCurrentScope(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index b98ca80c75d..2c3cd45050f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -20,17 +20,18 @@ import java.util.Collections; import java.util.Deque; import java.util.HashSet; +import java.util.Iterator; import java.util.LinkedHashSet; import java.util.Set; import java.util.WeakHashMap; /** - * A scope that can be used to manage tensor resources. Any tensors created between a scope's creation and calling + * A scope that can be used to manage tensor resources. If auto-attach is used, any tensors created between a scope's creation and calling * {@code close()}, that haven't been detached, are guaranteed to be closed with the scope (even if they are created in * a sub-scope). Tensors may be manually closed earlier without issue. *

- * Tensors are automatically tracked on creation. A tensor can me manually added to a scope with {@link + * When auto-attach is true, tensors are automatically tracked on creation. A tensor can me manually added to a scope with {@link * TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}, or by passing them to {@link * TensorScope#TensorScope(Tensor...)}. The tensor will then be closed when the first of it's managing scopes closes. *

@@ -46,10 +47,19 @@ public class TensorScope implements AutoCloseable { private static final ThreadLocal> scopeStack = ThreadLocal.withInitial(ArrayDeque::new); /** - * Returns {@code scopeStack.get().peek()}, the last opened scope not yet closed on this thread. + * Attach the tensor to the most recent scope that accepts automatic attachment. + * @return true if attached. */ - static TensorScope getCurrentScope() { - return scopeStack.get().peek(); + static boolean autoAttach(Tensor tensor) { + Iterator iterator = scopeStack.get().descendingIterator(); + while (iterator.hasNext()) { + TensorScope scope = iterator.next(); + if (scope.autoAttach) { + scope.attach(tensor); + return true; + } + } + return false; } /** @@ -64,27 +74,48 @@ public static void detach(Tensor t) { } /** - * Create a new tensor scope with the given thread locality. + * Create a new tensor scope. If {@code autoAttach} is false, will not automatically manage tensors. + * * @see TensorScope */ - public TensorScope() { - localScopeStack = scopeStack.get(); + public TensorScope(boolean autoAttach) { + this.autoAttach = autoAttach; synchronized (TensorScope.class) { allScopes.add(this); } + + localScopeStack = scopeStack.get(); localScopeStack.push(this); } /** - * Create a new tensor scope with the given thread locality, and attach the given tensors. + * Create a new tensor scope that automatically manages tensors. + */ + public TensorScope() { + this(true); + } + + /** + * Create a new tensor, and attach the given tensors. If {@code autoAttach} is false, will not automatically manage + * tensors. + * * @see TensorScope */ - public TensorScope(Tensor... tensors) { - this(); + public TensorScope(boolean autoAttach, Tensor... tensors) { + this(autoAttach); attach(tensors); } + /** + * Create a new tensor scope that automatically manages tensors, and attach the given tensors. + * + * @see TensorScope + */ + public TensorScope(Tensor... tensors) { + this(true, tensors); + } + /** * Attach a tensor to this scope. This happens automatically to tensors that are created in the scope. */ @@ -150,6 +181,7 @@ public boolean isClosed() { return closed; } + private final boolean autoAttach; private boolean closed = false; private final Set tensors = new HashSet<>(); private final Deque localScopeStack; From a9982a491547f04096ad3bfb952c4ebfde7eefb3 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 13 Jan 2021 16:42:51 -0800 Subject: [PATCH 06/35] Test for non-auto-attach scope Signed-off-by: Ryan Nett --- .../java/org/tensorflow/TensorScopeTest.java | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index 52cbb73564c..d9bb8dd3f43 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -55,6 +55,7 @@ public void testBasicScope() { assertTrue(tensor.isClosed()); assertTrue(scope.isClosed()); assertFalse(detachTensor.isClosed()); + detachTensor.close(); } @Test @@ -78,6 +79,7 @@ public void testNestedScope() { assertTrue(scope.isClosed()); assertTrue(outerScope.isClosed()); assertFalse(detachTensor.isClosed()); + detachTensor.close(); } @Test @@ -94,5 +96,24 @@ public void testAttach(){ assertTrue(tensor.isClosed()); } + @Test + public void testNoAutoAttach(){ + TensorScope scope = new TensorScope(false); + TFloat32 tensor = makeTensor(10); + assertFalse(tensor.isAttached()); + + TFloat32 detachTensor = makeTensor(10); + assertFalse(detachTensor.isAttached()); + + scope.attach(detachTensor); + assertTrue(detachTensor.isAttached()); + + detachTensor.detach(); + assertFalse(detachTensor.isAttached()); + + tensor.close(); + detachTensor.close(); + } + } From d30d7b10ee7d4106fb3079a5cffc493dc09634b1 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 13 Jan 2021 16:43:58 -0800 Subject: [PATCH 07/35] cleanup scopes Signed-off-by: Ryan Nett --- .../src/test/java/org/tensorflow/TensorScopeTest.java | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index d9bb8dd3f43..a60b5a1851f 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -94,6 +94,7 @@ public void testAttach(){ secondScope.close(); assertTrue(tensor.isClosed()); + firstScope.close(); } @Test @@ -113,6 +114,7 @@ public void testNoAutoAttach(){ tensor.close(); detachTensor.close(); + scope.close(); } From 8a7e8be73a4d4aefce781a88625dbae3a121da20 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Wed, 13 Jan 2021 16:56:39 -0800 Subject: [PATCH 08/35] HasTensors abstraction for resource management of multiple tensors Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/HasTensors.java | 63 +++++++++++++++ .../src/main/java/org/tensorflow/Tensor.java | 2 +- .../main/java/org/tensorflow/TensorScope.java | 77 ++++++++++++------- .../java/org/tensorflow/TensorScopeTest.java | 2 +- 4 files changed, 115 insertions(+), 29 deletions(-) create mode 100644 tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java new file mode 100644 index 00000000000..f6cd0f199f1 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java @@ -0,0 +1,63 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +/** + * An interface representing a collection or group of tensors. Provides methods for resource management. + */ +public interface HasTensors extends AutoCloseable { + + /** + * Get the tensors held by this object. + */ + Iterable tensors(); + + /** + * Detach these tensors from any scopes managing them. They must be manually closed or attached to another scope. + * @see Tensor#detach() + */ + default void detach(){ + tensors().forEach(Tensor::detach); + } + + /** + * Attach all of these tensors to the most recent scope that accepts automatic attachment. + * No-ops and returns false if there is no scope that does so. + * @see Tensor#detach() + */ + default boolean attachToCurrentScope(){ + if(!TensorScope.hasAutoScope()){ + return false; + } + tensors().forEach(Tensor::attachToCurrentScope); + return true; + } + + /** + * Release resources associated with these tensors. + * + *

WARNING:This must be invoked for all tensors that were not been produced by an eager + * operation or memory will be leaked. May be done automatically via {@link TensorScope}. + * + *

The Tensor objects are no longer usable after {@code close} returns. + * @see Tensor#close() + */ + @Override + default void close(){ + tensors().forEach(Tensor::close); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 3930e2cca91..834c0c819fd 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -206,7 +206,7 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData * Release resources associated with the Tensor. * *

WARNING:This must be invoked for all tensors that were not been produced by an eager - * operation or memory will be leaked. + * operation or memory will be leaked. May be done automatically via {@link TensorScope}. * *

The Tensor object is no longer usable after {@code close} returns. */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 2c3cd45050f..6fffe0600a9 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -27,13 +27,12 @@ /** - * A scope that can be used to manage tensor resources. If auto-attach is used, any tensors created between a scope's creation and calling - * {@code close()}, that haven't been detached, are guaranteed to be closed with the scope (even if they are created in - * a sub-scope). Tensors may be manually closed earlier without issue. + * A scope that can be used to manage tensor resources. If auto-attach is used, any tensors created between a scope's + * creation and calling {@code close()}, that haven't been detached, are guaranteed to be closed with the scope (even if + * they are created in a sub-scope). Tensors may be manually closed earlier without issue. *

- * When auto-attach is true, tensors are automatically tracked on creation. A tensor can me manually added to a scope with {@link - * TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}, or by passing them to {@link - * TensorScope#TensorScope(Tensor...)}. The tensor will then be closed when the first of it's managing scopes closes. + * When auto-attach is true, tensors are automatically tracked on creation. A tensor can me manually added to a scope + * with {@link TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}. The tensor will then be closed when the first of it's managing scopes closes. *

* {@link Tensor#detach()} detaches the tensor from all scopes, requiring the user to close it manually or attach it to * another scope. @@ -48,6 +47,7 @@ public class TensorScope implements AutoCloseable { /** * Attach the tensor to the most recent scope that accepts automatic attachment. + * * @return true if attached. */ static boolean autoAttach(Tensor tensor) { @@ -62,6 +62,20 @@ static boolean autoAttach(Tensor tensor) { return false; } + /** + * Return true if there is a scope that accepts auto attachment on the stack. + */ + public static boolean hasAutoScope() { + Iterator iterator = scopeStack.get().descendingIterator(); + while (iterator.hasNext()) { + TensorScope scope = iterator.next(); + if (scope.autoAttach) { + return true; + } + } + return false; + } + /** * Detaches the given tensor from any scopes managing it, requiring it to be manually closed. */ @@ -97,44 +111,53 @@ public TensorScope() { } /** - * Create a new tensor, and attach the given tensors. If {@code autoAttach} is false, will not automatically manage - * tensors. - * - * @see TensorScope + * Attach a tensor to this scope. This happens automatically to tensors that are created in the scope. + * @return this */ - public TensorScope(boolean autoAttach, Tensor... tensors) { - this(autoAttach); - attach(tensors); + public TensorScope attach(Tensor t) { + RawTensor rt = t.asRawTensor(); + rt.attached = true; + tensors.add(rt); + + return this; } /** - * Create a new tensor scope that automatically manages tensors, and attach the given tensors. - * - * @see TensorScope + * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * @return this */ - public TensorScope(Tensor... tensors) { - this(true, tensors); + public TensorScope attach(Tensor... tensors) { + if (tensors != null) { + for (Tensor t : tensors) { + attach(t); + } + } + + return this; } /** - * Attach a tensor to this scope. This happens automatically to tensors that are created in the scope. + * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * @return this */ - public void attach(Tensor t) { - RawTensor rt = t.asRawTensor(); - rt.attached = true; - tensors.add(rt); - } + public TensorScope attach(HasTensors tensors) { + tensors.tensors().forEach(this::attach); + return this; + } /** * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * @return this */ - public void attach(Tensor... tensors) { + public TensorScope attach(HasTensors... tensors) { if (tensors != null) { - for (Tensor t : tensors) { - attach(t); + for (HasTensors ht : tensors) { + attach(ht); } } + + return this; } private void detachTensor(Tensor t) { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index a60b5a1851f..74f4f6d19f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -86,7 +86,7 @@ public void testNestedScope() { public void testAttach(){ TensorScope firstScope = new TensorScope(); TFloat32 tensor = makeTensor(10); - TensorScope secondScope = new TensorScope(tensor); + TensorScope secondScope = new TensorScope().attach(tensor); assertTrue(tensor.isAttached()); assertFalse(tensor.isClosed()); From 38b98f04cfc9e0fcd1240869eec4be805277ae6a Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 14 Jan 2021 22:39:32 -0800 Subject: [PATCH 09/35] Iterable attach methods Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/TensorScope.java | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 6fffe0600a9..d6590a01f11 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -160,6 +160,30 @@ public TensorScope attach(HasTensors... tensors) { return this; } + /** + * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * @return this + */ + public TensorScope attach(Iterable tensors) { + tensors.forEach(this::attach); + + return this; + } + + /** + * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * @return this + */ + public TensorScope attach(Iterable... tensors) { + if (tensors != null) { + for (Iterable ht : tensors) { + attach(ht); + } + } + + return this; + } + private void detachTensor(Tensor t) { tensors.remove(t.asRawTensor()); } From bfe6aa538d46fbd3bd2dea1e64f8113bc8789082 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 15 Jan 2021 17:59:53 -0800 Subject: [PATCH 10/35] refactor hierarchy, add release to parent methods Signed-off-by: Ryan Nett --- .../java/org/tensorflow/EagerOperation.java | 4 +- .../java/org/tensorflow/EagerSession.java | 6 - .../main/java/org/tensorflow/HasTensors.java | 32 +++- .../main/java/org/tensorflow/RawTensor.java | 30 ++- .../src/main/java/org/tensorflow/Tensor.java | 13 +- .../main/java/org/tensorflow/TensorScope.java | 174 ++++++++++-------- .../org/tensorflow/types/family/TType.java | 9 +- .../java/org/tensorflow/TensorScopeTest.java | 38 ++-- 8 files changed, 179 insertions(+), 127 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 7302ffaa4d9..4e9394b7df0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -167,9 +167,7 @@ private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession TF_Status status = TF_Status.newStatus(); TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(); status.throwExceptionIfNotOK(); - TType typedTensor = RawTensor.fromHandle(tensor).asTypedTensor(); - session.attachTensor(typedTensor); - return typedTensor; + return RawTensor.fromHandle(tensor).asTypedTensor(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java index b363d40d94e..8e7465388a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerSession.java @@ -357,15 +357,10 @@ void detach(Pointer... resources) { } } - void attachTensor(Tensor tensor){ - tensorScope.attach(tensor); - } - private static volatile EagerSession defaultSession = null; private final WeakPointerScope nativeResources; private TFE_Context nativeHandle; - private final TensorScope tensorScope = new TensorScope(false); private EagerSession(Options options) { this.nativeResources = new WeakPointerScope(); @@ -379,7 +374,6 @@ private void checkSession() { } private synchronized void doClose() { - tensorScope.close(); if (nativeHandle != null && !nativeHandle.isNull()) { nativeResources.close(); delete(nativeHandle); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java index f6cd0f199f1..ae0661e0da7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java @@ -28,23 +28,34 @@ public interface HasTensors extends AutoCloseable { /** * Detach these tensors from any scopes managing them. They must be manually closed or attached to another scope. + * * @see Tensor#detach() */ - default void detach(){ + default void detach() { tensors().forEach(Tensor::detach); } /** - * Attach all of these tensors to the most recent scope that accepts automatic attachment. - * No-ops and returns false if there is no scope that does so. - * @see Tensor#detach() + * Attach all of these tensors to the most recent scope. + * + * @throws IllegalStateException if there is no active scope. + * @see Tensor#attachToCurrentScope() */ - default boolean attachToCurrentScope(){ - if(!TensorScope.hasAutoScope()){ - return false; + default void attachToCurrentScope() { + TensorScope scope = TensorScope.currentScope(); + if (scope == null) { + throw new IllegalStateException("Can't attach to current scope: no active tensor scopes."); } - tensors().forEach(Tensor::attachToCurrentScope); - return true; + + tensors().forEach(scope::attach); + } + + /** + * Attach these tensors to the parent of their current scope, removing it from it's current scope. + * @throws IllegalStateException if any tensors do not have a scope, or their scope does not have a parent. + */ + default void attachToParent(){ + tensors().forEach(Tensor::attachToParent); } /** @@ -54,10 +65,11 @@ default boolean attachToCurrentScope(){ * operation or memory will be leaked. May be done automatically via {@link TensorScope}. * *

The Tensor objects are no longer usable after {@code close} returns. + * * @see Tensor#close() */ @Override - default void close(){ + default void close() { tensors().forEach(Tensor::close); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index 0dafad93581..f334a9d4369 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -83,12 +83,28 @@ public void detach() { @Override public boolean isAttached() { - return attached; + return scope != null; } @Override - public boolean attachToCurrentScope() { - return TensorScope.autoAttach(this); + public synchronized void attachToParent() { + if(scope == null){ + throw new IllegalStateException("Can't attach to parent: no scope."); + } + if(scope.parent == null){ + throw new IllegalStateException("Can't attach to parent: scope does not have a parent."); + } + + scope.parent.attach(this); + } + + @Override + public void attachToCurrentScope() { + TensorScope scope = TensorScope.currentScope(); + if(scope == null){ + throw new IllegalStateException("Can't attach to current scope: no active tensor scopes."); + } + scope.attach(this); } /** @@ -231,12 +247,16 @@ private static long[] shape(TF_Tensor handle) { RawTensor(TensorTypeInfo typeInfo, Shape shape) { this.typeInfo = typeInfo; this.shape = shape; - TensorScope.autoAttach(this); + + TensorScope currentScope = TensorScope.currentScope(); + if(currentScope != null) { + this.scope = currentScope.attach(this); + } } private PointerScope tensorScope; private boolean closed; - boolean attached = false; + TensorScope scope; private TF_Tensor tensorHandle; private final TensorTypeInfo typeInfo; private final Shape shape; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 834c0c819fd..7009f03355b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -223,14 +223,21 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData */ void detach(); + /** + * Attach this tensor to the parent of it's current scope, removing it from it's current scope. + * @throws IllegalStateException if it does not have a scope, or it's scope does not have a parent. + */ + void attachToParent(); + /** * Returns true if this tensor is attached to a {@link TensorScope}. */ boolean isAttached(); /** - * Attach this tensor to the most recent scope that accepts automatic attachment. - * No-ops and returns false if there is no scope that does so. + * Attach this tensor to the most recent scope. + * + * @throws IllegalStateException if there are no active scopes */ - boolean attachToCurrentScope(); + void attachToCurrentScope(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index d6590a01f11..81a41097cb4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -24,6 +24,7 @@ import java.util.LinkedHashSet; import java.util.Set; import java.util.WeakHashMap; +import java.util.concurrent.ConcurrentHashMap; /** @@ -32,7 +33,8 @@ * they are created in a sub-scope). Tensors may be manually closed earlier without issue. *

* When auto-attach is true, tensors are automatically tracked on creation. A tensor can me manually added to a scope - * with {@link TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}. The tensor will then be closed when the first of it's managing scopes closes. + * with {@link TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}. The tensor will then be closed + * when the first of it's managing scopes closes. *

* {@link Tensor#detach()} detaches the tensor from all scopes, requiring the user to close it manually or attach it to * another scope. @@ -41,50 +43,32 @@ */ public class TensorScope implements AutoCloseable { - private static final Set allScopes = Collections.newSetFromMap(new WeakHashMap<>()); + private static final InheritableThreadLocal currentScope = new InheritableThreadLocal<>(); - private static final ThreadLocal> scopeStack = ThreadLocal.withInitial(ArrayDeque::new); + public static TensorScope currentScope() { + TensorScope scope = currentScope.get(); - /** - * Attach the tensor to the most recent scope that accepts automatic attachment. - * - * @return true if attached. - */ - static boolean autoAttach(Tensor tensor) { - Iterator iterator = scopeStack.get().descendingIterator(); - while (iterator.hasNext()) { - TensorScope scope = iterator.next(); - if (scope.autoAttach) { - scope.attach(tensor); - return true; - } + if (scope == null || !scope.closed) { + return scope; } - return false; - } - /** - * Return true if there is a scope that accepts auto attachment on the stack. - */ - public static boolean hasAutoScope() { - Iterator iterator = scopeStack.get().descendingIterator(); - while (iterator.hasNext()) { - TensorScope scope = iterator.next(); - if (scope.autoAttach) { - return true; - } + // scope could be closed in another thread, in which case this thread's currentScope wouldn't be updated + while (scope != null && scope.closed) { + scope = scope.parent; } - return false; + currentScope.set(scope); + return scope; } - /** - * Detaches the given tensor from any scopes managing it, requiring it to be manually closed. - */ - public static void detach(Tensor t) { - RawTensor raw = t.asRawTensor(); - synchronized (TensorScope.class) { - allScopes.forEach(x -> x.detachTensor(raw)); + public static void detach(Tensor tensor) { + // ensure that I'm not attaching or detaching at the same time in different threads + RawTensor rt = tensor.asRawTensor(); + synchronized (rt) { + if (rt.scope != null) { + rt.scope.tensors.remove(rt); + rt.scope = null; + } } - raw.attached = false; } /** @@ -92,38 +76,41 @@ public static void detach(Tensor t) { * * @see TensorScope */ - public TensorScope(boolean autoAttach) { - this.autoAttach = autoAttach; + public TensorScope() { + this.parent = currentScope(); + currentScope.set(this); - synchronized (TensorScope.class) { - allScopes.add(this); + if (this.parent != null) { + synchronized (this.parent) { + this.parent.children.add(this); + } } - - localScopeStack = scopeStack.get(); - localScopeStack.push(this); - } - - /** - * Create a new tensor scope that automatically manages tensors. - */ - public TensorScope() { - this(true); } /** * Attach a tensor to this scope. This happens automatically to tensors that are created in the scope. + * * @return this */ - public TensorScope attach(Tensor t) { - RawTensor rt = t.asRawTensor(); - rt.attached = true; - tensors.add(rt); + public synchronized TensorScope attach(Tensor tensor) { + if (this.closed) { + throw new IllegalStateException("Scope has been closed, can not attach new tensor."); + } + + RawTensor rt = tensor.asRawTensor(); + // ensure that I'm not attaching or detaching at the same time in different threads + synchronized (rt) { + detach(tensor); + rt.scope = this; + tensors.add(rt); + } return this; } /** * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * * @return this */ public TensorScope attach(Tensor... tensors) { @@ -138,6 +125,7 @@ public TensorScope attach(Tensor... tensors) { /** * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * * @return this */ public TensorScope attach(HasTensors tensors) { @@ -148,6 +136,7 @@ public TensorScope attach(HasTensors tensors) { /** * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * * @return this */ public TensorScope attach(HasTensors... tensors) { @@ -162,6 +151,7 @@ public TensorScope attach(HasTensors... tensors) { /** * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * * @return this */ public TensorScope attach(Iterable tensors) { @@ -172,9 +162,11 @@ public TensorScope attach(Iterable tensors) { /** * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. + * * @return this */ - public TensorScope attach(Iterable... tensors) { + @SafeVarargs + public final TensorScope attach(Iterable... tensors) { if (tensors != null) { for (Iterable ht : tensors) { attach(ht); @@ -184,52 +176,76 @@ public TensorScope attach(Iterable... tensors) { return this; } - private void detachTensor(Tensor t) { - tensors.remove(t.asRawTensor()); - } + /** + * Closes this scope and its tensors, and any inner scopes. + */ + @Override + public synchronized void close() { + if (closed) { + return; + } - private void closeScope() { + children.forEach(TensorScope::close); tensors.forEach(Tensor::close); - synchronized (TensorScope.class) { - allScopes.remove(this); + closed = true; + + parent.children.remove(this); + + if (currentScope() == this) { + currentScope.set(this.parent); } + } - closed = true; + /** + * Release the tensors and child scopes of this scope to it's parent, without closing them. + * + * @throws IllegalStateException if this scope has no parent. + */ + public synchronized void releaseToParent() { + release(true); } /** - * Closes this scope and its tensors, and any inner scopes. + * Release the tensors and child scopes of this scope without closing them, to it's parent if it has one. + * + * @param requireParent Whether to require a parent scope to release resources to. + * @throws IllegalStateException if this scope has no parent, but {@code requireParent} is true. */ - @Override - public void close() { + public synchronized void release(boolean requireParent) { if (closed) { return; } - if (!localScopeStack.contains(this)) { - throw new IllegalStateException("This scope is not on the scope stack, but was not closed." - + " This should not be possible."); + if (this.parent == null && requireParent) { + throw new IllegalStateException("Can't release to parent: scope does not have parent."); } - while (true) { - TensorScope ts = localScopeStack.removeLast(); - ts.closeScope(); - if (ts == this) { - return; - } + if (this.parent != null) { + TensorScope newParent = this.parent; + newParent.children.addAll(children); + children.forEach(x -> x.parent = newParent); + tensors.forEach(newParent::attach); + } else { + children.forEach(x -> x.parent = null); + tensors.forEach(TensorScope::detach); } + + children.clear(); + tensors.clear(); + + close(); } /** * Gets whether the scope is closed. */ - public boolean isClosed() { + public synchronized boolean isClosed() { return closed; } - private final boolean autoAttach; private boolean closed = false; - private final Set tensors = new HashSet<>(); - private final Deque localScopeStack; + private final Set tensors = ConcurrentHashMap.newKeySet(); + TensorScope parent; + private final Set children = ConcurrentHashMap.newKeySet(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index a2bffe29575..b860a86be5a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -91,13 +91,18 @@ default void detach(){ asRawTensor().detach(); } + @Override + default void attachToParent(){ + asRawTensor().attachToParent(); + } + @Override default boolean isAttached(){ return asRawTensor().isAttached(); } @Override - default boolean attachToCurrentScope(){ - return asRawTensor().attachToCurrentScope(); + default void attachToCurrentScope(){ + asRawTensor().attachToCurrentScope(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index 74f4f6d19f7..76a83954d41 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -97,25 +97,25 @@ public void testAttach(){ firstScope.close(); } - @Test - public void testNoAutoAttach(){ - TensorScope scope = new TensorScope(false); - TFloat32 tensor = makeTensor(10); - assertFalse(tensor.isAttached()); - - TFloat32 detachTensor = makeTensor(10); - assertFalse(detachTensor.isAttached()); - - scope.attach(detachTensor); - assertTrue(detachTensor.isAttached()); - - detachTensor.detach(); - assertFalse(detachTensor.isAttached()); - - tensor.close(); - detachTensor.close(); - scope.close(); - } +// @Test +// public void testNoAutoAttach(){ +// TensorScope scope = new TensorScope(false); +// TFloat32 tensor = makeTensor(10); +// assertFalse(tensor.isAttached()); +// +// TFloat32 detachTensor = makeTensor(10); +// assertFalse(detachTensor.isAttached()); +// +// scope.attach(detachTensor); +// assertTrue(detachTensor.isAttached()); +// +// detachTensor.detach(); +// assertFalse(detachTensor.isAttached()); +// +// tensor.close(); +// detachTensor.close(); +// scope.close(); +// } } From 393b5dacbda3dcafd18bd5c919526cecaf7b8795 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 15 Jan 2021 18:20:39 -0800 Subject: [PATCH 11/35] fix NPE Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/TensorScope.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 81a41097cb4..ec521c39d72 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -190,7 +190,9 @@ public synchronized void close() { closed = true; - parent.children.remove(this); + if(parent != null) { + parent.children.remove(this); + } if (currentScope() == this) { currentScope.set(this.parent); From 7a3f365842873f78ef49f1e07dbe0fd2737e8019 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 15 Jan 2021 18:27:07 -0800 Subject: [PATCH 12/35] Javadoc updates Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Tensor.java | 1 + .../main/java/org/tensorflow/TensorScope.java | 22 ++++++++++++------- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 7009f03355b..b6715a2e764 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -39,6 +39,7 @@ * doSomethingWith(t); * } * } + *

This can be done automatically using {@link TensorScope}. *

Instances of a Tensor are not thread-safe. */ public interface Tensor extends Shaped, AutoCloseable { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index ec521c39d72..9d6f2acb5ed 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -28,18 +28,21 @@ /** - * A scope that can be used to manage tensor resources. If auto-attach is used, any tensors created between a scope's - * creation and calling {@code close()}, that haven't been detached, are guaranteed to be closed with the scope (even if - * they are created in a sub-scope). Tensors may be manually closed earlier without issue. + * A scope that can be used to manage tensor resources. Any tensors created between a scope's + * creation and calling {@code close()} that haven't been detached or attached to a different scope are guaranteed to + * be closed with the scope (even if they are created in a sub-scope). Tensors may be manually closed earlier without + * issue. *

- * When auto-attach is true, tensors are automatically tracked on creation. A tensor can me manually added to a scope - * with {@link TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}. The tensor will then be closed - * when the first of it's managing scopes closes. + * Tensors are automatically tracked on creation. A tensor can me manually added to a scope + * with {@link TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}. A tensor may only have one scope: + * if it currently has a scope when {@code attach} is called, it is removed from it's original scope. *

- * {@link Tensor#detach()} detaches the tensor from all scopes, requiring the user to close it manually or attach it to + * {@link Tensor#detach()} detaches the tensor from it's scope, requiring the user to close it manually or attach it to * another scope. *

- * Note that scope management is thread local, except for detach, which will detach even from scopes on other threads. + * Note that scope management is mostly thread local. The current scope hierarchy will be inherited by new threads, + * and closing a scope will close it's children regardless of which threads they are on, but the active scope is + * thread local. */ public class TensorScope implements AutoCloseable { @@ -211,6 +214,9 @@ public synchronized void releaseToParent() { /** * Release the tensors and child scopes of this scope without closing them, to it's parent if it has one. * + *

WARNING: this method may release resources without assigning them to another scope if + * {@code requireParent} is false. {@link #releaseToParent()} should be used instead wherever possible. + * * @param requireParent Whether to require a parent scope to release resources to. * @throws IllegalStateException if this scope has no parent, but {@code requireParent} is true. */ From cd6f3d2e7d272baf20388bb8fca84d23d3dc66ed Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 15 Jan 2021 18:40:22 -0800 Subject: [PATCH 13/35] New tests, remove eager session tensor closing test Signed-off-by: Ryan Nett --- .../java/org/tensorflow/TensorScopeTest.java | 80 ++++++++++++++----- .../test/java/org/tensorflow/TensorTest.java | 17 ---- 2 files changed, 61 insertions(+), 36 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index 76a83954d41..b36f7f3e60e 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -97,25 +97,67 @@ public void testAttach(){ firstScope.close(); } -// @Test -// public void testNoAutoAttach(){ -// TensorScope scope = new TensorScope(false); -// TFloat32 tensor = makeTensor(10); -// assertFalse(tensor.isAttached()); -// -// TFloat32 detachTensor = makeTensor(10); -// assertFalse(detachTensor.isAttached()); -// -// scope.attach(detachTensor); -// assertTrue(detachTensor.isAttached()); -// -// detachTensor.detach(); -// assertFalse(detachTensor.isAttached()); -// -// tensor.close(); -// detachTensor.close(); -// scope.close(); -// } + @Test + public void testUpwardsAttach(){ + TensorScope firstScope = new TensorScope(); + TFloat32 tensor = makeTensor(10); + TensorScope secondScope = new TensorScope().attach(tensor); + + firstScope.close(); + + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + secondScope.close(); + + assertTrue(tensor.isClosed()); + } + + @Test + public void testReleaseToParentScope() { + TensorScope outerScope = new TensorScope(); + TensorScope scope = new TensorScope(); + + TFloat32 tensor = makeTensor(10); + + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + scope.releaseToParent(); + + assertTrue(scope.isClosed()); + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + outerScope.close(); + + assertTrue(tensor.isClosed()); + assertTrue(outerScope.isClosed()); + } + + @Test + public void testAttachToParentScope() { + TensorScope outerScope = new TensorScope(); + TensorScope scope = new TensorScope(); + + TFloat32 tensor = makeTensor(10); + + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + tensor.attachToParent(); + + scope.close(); + + assertTrue(scope.isClosed()); + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + outerScope.close(); + + assertTrue(tensor.isClosed()); + assertTrue(outerScope.isClosed()); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java index 9415a986222..6b9cb202b97 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java @@ -487,23 +487,6 @@ public void useAfterClose() { } } - @Test - public void eagerTensorIsReleasedAfterSessionIsClosed() { - TInt32 sum; - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); - sum = tf.math.add(tf.constant(10), tf.constant(20)).asTensor(); - sum.asRawTensor().nativeHandle(); // does not throw - assertEquals(30, sum.getInt()); - } - try { - sum.asRawTensor().nativeHandle(); - fail("Tensor native handle should have been closed by ending eager session"); - } catch (IllegalStateException e) { - // as expected - } - } - @Test public void fromHandle() { // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been From 6bc6fceda01739217927216fddab65a0f3e26ce2 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 15 Jan 2021 18:46:41 -0800 Subject: [PATCH 14/35] remove incorrect test Signed-off-by: Ryan Nett --- .../java/org/tensorflow/TensorScopeTest.java | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index b36f7f3e60e..22ca77147a2 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -97,22 +97,6 @@ public void testAttach(){ firstScope.close(); } - @Test - public void testUpwardsAttach(){ - TensorScope firstScope = new TensorScope(); - TFloat32 tensor = makeTensor(10); - TensorScope secondScope = new TensorScope().attach(tensor); - - firstScope.close(); - - assertTrue(tensor.isAttached()); - assertFalse(tensor.isClosed()); - - secondScope.close(); - - assertTrue(tensor.isClosed()); - } - @Test public void testReleaseToParentScope() { TensorScope outerScope = new TensorScope(); From e2cd3665aa8171de8d1ba91022915d66ef7d8c50 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 15 Jan 2021 18:48:53 -0800 Subject: [PATCH 15/35] clarify threading docs Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/TensorScope.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 9d6f2acb5ed..ab3da511ceb 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -40,9 +40,8 @@ * {@link Tensor#detach()} detaches the tensor from it's scope, requiring the user to close it manually or attach it to * another scope. *

- * Note that scope management is mostly thread local. The current scope hierarchy will be inherited by new threads, - * and closing a scope will close it's children regardless of which threads they are on, but the active scope is - * thread local. + * Scopes will be inherited at thread creation, but further scope creation on different threads will be independent, + * other than having the same parent. Closing a scope will close it's children regardless of which threads they are on. */ public class TensorScope implements AutoCloseable { From f3ff90abdc196468f5165ec9f55822a785f8d36f Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 15 Jan 2021 18:50:29 -0800 Subject: [PATCH 16/35] grammar Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/HasTensors.java | 2 +- .../src/main/java/org/tensorflow/Tensor.java | 4 ++-- .../src/main/java/org/tensorflow/TensorScope.java | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java index ae0661e0da7..ccb2f242fb2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java @@ -51,7 +51,7 @@ default void attachToCurrentScope() { } /** - * Attach these tensors to the parent of their current scope, removing it from it's current scope. + * Attach these tensors to the parent of their current scope, removing it from its current scope. * @throws IllegalStateException if any tensors do not have a scope, or their scope does not have a parent. */ default void attachToParent(){ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index b6715a2e764..bd56146fa03 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -225,8 +225,8 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData void detach(); /** - * Attach this tensor to the parent of it's current scope, removing it from it's current scope. - * @throws IllegalStateException if it does not have a scope, or it's scope does not have a parent. + * Attach this tensor to the parent of it's current scope, removing it from its current scope. + * @throws IllegalStateException if it does not have a scope, or its scope does not have a parent. */ void attachToParent(); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index ab3da511ceb..cb8afe50af4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -35,7 +35,7 @@ *

* Tensors are automatically tracked on creation. A tensor can me manually added to a scope * with {@link TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}. A tensor may only have one scope: - * if it currently has a scope when {@code attach} is called, it is removed from it's original scope. + * if it currently has a scope when {@code attach} is called, it is removed from its original scope. *

* {@link Tensor#detach()} detaches the tensor from it's scope, requiring the user to close it manually or attach it to * another scope. From 37e03746c7758b754e4e232a6b390337b61b7d51 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 15 Jan 2021 18:51:06 -0800 Subject: [PATCH 17/35] formatting Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/HasTensors.java | 3 +- .../main/java/org/tensorflow/RawTensor.java | 8 +-- .../src/main/java/org/tensorflow/Tensor.java | 70 +++++++++---------- .../main/java/org/tensorflow/TensorScope.java | 15 ++-- .../java/org/tensorflow/TensorScopeTest.java | 2 +- 5 files changed, 48 insertions(+), 50 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java index ccb2f242fb2..455afdcfb31 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java @@ -52,9 +52,10 @@ default void attachToCurrentScope() { /** * Attach these tensors to the parent of their current scope, removing it from its current scope. + * * @throws IllegalStateException if any tensors do not have a scope, or their scope does not have a parent. */ - default void attachToParent(){ + default void attachToParent() { tensors().forEach(Tensor::attachToParent); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index f334a9d4369..f754349c715 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -88,10 +88,10 @@ public boolean isAttached() { @Override public synchronized void attachToParent() { - if(scope == null){ + if (scope == null) { throw new IllegalStateException("Can't attach to parent: no scope."); } - if(scope.parent == null){ + if (scope.parent == null) { throw new IllegalStateException("Can't attach to parent: scope does not have a parent."); } @@ -101,7 +101,7 @@ public synchronized void attachToParent() { @Override public void attachToCurrentScope() { TensorScope scope = TensorScope.currentScope(); - if(scope == null){ + if (scope == null) { throw new IllegalStateException("Can't attach to current scope: no active tensor scopes."); } scope.attach(this); @@ -249,7 +249,7 @@ private static long[] shape(TF_Tensor handle) { this.shape = shape; TensorScope currentScope = TensorScope.currentScope(); - if(currentScope != null) { + if (currentScope != null) { this.scope = currentScope.attach(this); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index bd56146fa03..0eb35ec3614 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -26,13 +26,13 @@ * A statically typed multi-dimensional array. * *

There are two categories of tensors in TensorFlow Java: {@link TType typed tensors} and - * {@link RawTensor raw tensors}. The former maps the tensor native memory to an - * n-dimensional typed data space, allowing direct I/O operations from the JVM, while the latter - * is only a reference to a native tensor allowing basic operations and flat data access.

+ * {@link RawTensor raw tensors}. The former maps the tensor native memory to an n-dimensional typed data space, + * allowing direct I/O operations from the JVM, while the latter is only a reference to a native tensor allowing basic + * operations and flat data access.

* *

WARNING: Resources consumed by the Tensor object must be explicitly freed by - * invoking the {@link #close()} method when the object is no longer needed. For example, using a - * try-with-resources block: + * invoking the {@link #close()} method when the object is no longer needed. For example, using a try-with-resources + * block: * *

{@code
  * try (Tensor t = Tensor.of(...)) {
@@ -54,10 +54,9 @@ public interface Tensor extends Shaped, AutoCloseable {
    * @param type the tensor type class
    * @param shape shape of the tensor
    * @return an allocated but uninitialized tensor
-   * @throws IllegalArgumentException if elements of the given {@code type} are of variable length
-   *                                  (e.g. strings)
-   * @throws IllegalArgumentException if {@code shape} is totally or partially
-   *                                  {@link Shape#hasUnknownDimension() unknown}
+   * @throws IllegalArgumentException if elements of the given {@code type} are of variable length (e.g. strings)
+   * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension()
+   * unknown}
    * @throws IllegalStateException if tensor failed to be allocated
    */
   static  T of(Class type, Shape shape) {
@@ -68,27 +67,27 @@ static  T of(Class type, Shape shape) {
    * Allocates a tensor of a given datatype, shape and size.
    *
    * 

This method is identical to {@link #of(Class, Shape)}, except that the final size of the - * tensor can be explicitly set instead of computing it from the datatype and shape, which could be - * larger than the actual space required to store the data but not smaller. + * tensor can be explicitly set instead of computing it from the datatype and shape, which could be larger than the + * actual space required to store the data but not smaller. * * @param the tensor type * @param type the tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor or -1 to compute the size from the shape * @return an allocated but uninitialized tensor - * @see #of(Class, Shape) - * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to - * store the tensor data - * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given - * {@code type} are of variable length (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to store the tensor + * data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code type} are of + * variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension() + * unknown} * @throws IllegalStateException if tensor failed to be allocated + * @see #of(Class, Shape) */ static T of(Class type, Shape shape, long size) { RawTensor tensor = RawTensor.allocate(type, shape, size); try { - return (T)tensor.asTypedTensor(); + return (T) tensor.asTypedTensor(); } catch (Exception e) { tensor.close(); throw e; @@ -99,8 +98,8 @@ static T of(Class type, Shape shape, long size) { * Allocates and initialize a tensor of a given datatype and shape. * *

The amount of memory to allocate is derived from the datatype and the shape of the tensor. - * Tensor data is initialized by calling the {@code dataInitializer}, which receives in argument - * the value returned by {@link #data()} on the allocated tensor. For example: + * Tensor data is initialized by calling the {@code dataInitializer}, which receives in argument the value returned by + * {@link #data()} on the allocated tensor. For example: * *

{@code
    * FloatNdArray data = ...
@@ -117,10 +116,9 @@ static  T of(Class type, Shape shape, long size) {
    * @param shape shape of the tensor
    * @param dataInitializer method receiving accessor to the allocated tensor data for initialization
    * @return an allocated and initialized tensor
-   * @throws IllegalArgumentException if elements of the given {@code type} are of variable length
-   *                                  (e.g. strings)
-   * @throws IllegalArgumentException if {@code shape} is totally or partially
-   *                                  {@link Shape#hasUnknownDimension() unknown}
+   * @throws IllegalArgumentException if elements of the given {@code type} are of variable length (e.g. strings)
+   * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension()
+   * unknown}
    * @throws IllegalStateException if tensor failed to be allocated
    */
   static  T of(Class type, Shape shape, Consumer dataInitializer) {
@@ -142,14 +140,14 @@ static  T of(Class type, Shape shape, Consumer dataInitia
    * @param size size in bytes of the tensor or -1 to compute the size from the shape
    * @param dataInitializer method receiving accessor to the allocated tensor data for initialization
    * @return an allocated and initialized tensor
-   * @see #of(Class, Shape, long, Consumer)
-   * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to
-   *                                  store the tensor data
-   * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given
-   *                                  {@code type} are of variable length (e.g. strings)
-   * @throws IllegalArgumentException if {@code shape} is totally or partially
-   *                                  {@link Shape#hasUnknownDimension() unknown}
+   * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to store the tensor
+   * data
+   * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code type} are of
+   * variable length (e.g. strings)
+   * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension()
+   * unknown}
    * @throws IllegalStateException if tensor failed to be allocated
+   * @see #of(Class, Shape, long, Consumer)
    */
   static  T of(Class type, Shape shape, long size, Consumer dataInitializer) {
     T tensor = of(type, shape, size);
@@ -172,10 +170,9 @@ static  T of(Class type, Shape shape, long size, Consumer
    * @param type the tensor type class
    * @param shape the tensor shape.
    * @param rawData a buffer containing the tensor raw data.
-   * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor
-   *                                  data
-   * @throws IllegalArgumentException if {@code shape} is totally or partially
-   *                                  {@link Shape#hasUnknownDimension() unknown}
+   * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor data
+   * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension()
+   * unknown}
    * @throws IllegalStateException if tensor failed to be allocated with the given parameters
    */
   static  T of(Class type, Shape shape, ByteDataBuffer rawData) {
@@ -226,6 +223,7 @@ static  T of(Class type, Shape shape, ByteDataBuffer rawData
 
   /**
    * Attach this tensor to the parent of it's current scope, removing it from its current scope.
+   *
    * @throws IllegalStateException if it does not have a scope, or its scope does not have a parent.
    */
   void attachToParent();
diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java
index cb8afe50af4..b389b891c6e 100644
--- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java
+++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java
@@ -28,14 +28,13 @@
 
 
 /**
- * A scope that can be used to manage tensor resources.  Any tensors created between a scope's
- * creation and calling {@code close()} that haven't been detached or attached to a different scope are guaranteed to
- * be closed with the scope (even if they are created in a sub-scope).  Tensors may be manually closed earlier without
- * issue.
+ * A scope that can be used to manage tensor resources.  Any tensors created between a scope's creation and calling
+ * {@code close()} that haven't been detached or attached to a different scope are guaranteed to be closed with the
+ * scope (even if they are created in a sub-scope).  Tensors may be manually closed earlier without issue.
  * 

- * Tensors are automatically tracked on creation. A tensor can me manually added to a scope - * with {@link TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}. A tensor may only have one scope: - * if it currently has a scope when {@code attach} is called, it is removed from its original scope. + * Tensors are automatically tracked on creation. A tensor can me manually added to a scope with {@link + * TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}. A tensor may only have one scope: if it + * currently has a scope when {@code attach} is called, it is removed from its original scope. *

* {@link Tensor#detach()} detaches the tensor from it's scope, requiring the user to close it manually or attach it to * another scope. @@ -192,7 +191,7 @@ public synchronized void close() { closed = true; - if(parent != null) { + if (parent != null) { parent.children.remove(this); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index 22ca77147a2..f1652233730 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -83,7 +83,7 @@ public void testNestedScope() { } @Test - public void testAttach(){ + public void testAttach() { TensorScope firstScope = new TensorScope(); TFloat32 tensor = makeTensor(10); TensorScope secondScope = new TensorScope().attach(tensor); From 08376347b9f2ace6aa5becd69810e5a7028a2b16 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 15 Jan 2021 18:52:41 -0800 Subject: [PATCH 18/35] Add note about different scopes Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/HasTensors.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java index 455afdcfb31..2c8fa17a6b4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java @@ -53,6 +53,10 @@ default void attachToCurrentScope() { /** * Attach these tensors to the parent of their current scope, removing it from its current scope. * + *

Note that if tensors have different scopes, each tensor will be attached to its scope's parent. + * {@link TensorScope#attach(HasTensors)} or {@link #attachToCurrentScope()} can be used to ensure all tensors have + * the same scope. + * * @throws IllegalStateException if any tensors do not have a scope, or their scope does not have a parent. */ default void attachToParent() { From 555b13bc9f0f23e2edfbbccd59298cf884b963da Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 16 Jan 2021 15:17:15 -0800 Subject: [PATCH 19/35] Add option to not require parent to Tensor and HasTensors Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/HasTensors.java | 20 +++++++++++++++- .../main/java/org/tensorflow/RawTensor.java | 10 +++++--- .../src/main/java/org/tensorflow/Tensor.java | 21 +++++++++++++++-- .../org/tensorflow/types/family/TType.java | 23 +++++++++---------- 4 files changed, 56 insertions(+), 18 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java index 2c8fa17a6b4..c24124cf757 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java @@ -50,6 +50,24 @@ default void attachToCurrentScope() { tensors().forEach(scope::attach); } + + /** + * Attach these tensors to the parents of their current scopes, removing them from their current scopes. + * + *

If {@code requireParent} is false, detaches each tensor if its scope does not have a parent. Otherwise, if + * {@code requireParent} is true and the scope does not have a parent, throws {@link IllegalStateException}. + * + *

WARNING: this method may release resources without assigning them to another scope if + * * {@code requireParent} is false. {@link #attachToParent()} should be used instead wherever possible. + * + * @param requireParent Whether to require a parent scope to release resources to. + * @throws IllegalStateException if the tensor does not have a scope, or if this scope has no parent, but {@code + * requireParent} is true + */ + default void attachToParent(boolean requireParent) { + tensors().forEach(x -> x.attachToParent(requireParent)); + } + /** * Attach these tensors to the parent of their current scope, removing it from its current scope. * @@ -60,7 +78,7 @@ default void attachToCurrentScope() { * @throws IllegalStateException if any tensors do not have a scope, or their scope does not have a parent. */ default void attachToParent() { - tensors().forEach(Tensor::attachToParent); + attachToParent(true); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index f754349c715..803207c548c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -87,15 +87,19 @@ public boolean isAttached() { } @Override - public synchronized void attachToParent() { + public synchronized void attachToParent(boolean requireParent) { if (scope == null) { throw new IllegalStateException("Can't attach to parent: no scope."); } - if (scope.parent == null) { + if (scope.parent == null && requireParent) { throw new IllegalStateException("Can't attach to parent: scope does not have a parent."); } - scope.parent.attach(this); + if (scope.parent != null) { + scope.parent.attach(this); + } else { + this.detach(); + } } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 0eb35ec3614..7efc7d32d9f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -224,9 +224,26 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData /** * Attach this tensor to the parent of it's current scope, removing it from its current scope. * - * @throws IllegalStateException if it does not have a scope, or its scope does not have a parent. + * @throws IllegalStateException if the tensor does not have a scope, or its scope does not have a parent. */ - void attachToParent(); + default void attachToParent() { + attachToParent(true); + } + + /** + * Attach this tensor to the parent of it's current scope, removing it from its current scope. + * + *

If {@code requireParent} is false, detaches the tensor if its scope does not have a parent. Otherwise, if + * {@code requireParent} is true and the scope does not have a parent, throws {@link IllegalStateException}. + * + *

WARNING: this method may release resources without assigning them to another scope if + * * {@code requireParent} is false. {@link #attachToParent()} should be used instead wherever possible. + * + * @param requireParent Whether to require a parent scope to release resources to. + * @throws IllegalStateException if the tensor does not have a scope, or if this scope has no parent, but {@code + * requireParent} is true + */ + void attachToParent(boolean requireParent); /** * Returns true if this tensor is attached to a {@link TensorScope}. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index b860a86be5a..712ba660cc2 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -27,10 +27,9 @@ * to a n-dimensional data space allowing direct I/O access from the JVM.

* *

Subinterfaces of {@code TType} are propagated as a generic parameter to various entities of - * TensorFlow to identify the type of the tensor they carry. For example, a - * {@link org.tensorflow.Operand Operand} is an operand which outputs a 32-bit floating - * point tensor. This parameter ensure type-compatibility between operands of a computation at - * compile-time. For example: + * TensorFlow to identify the type of the tensor they carry. For example, a {@link org.tensorflow.Operand + * Operand} is an operand which outputs a 32-bit floating point tensor. This parameter ensure + * type-compatibility between operands of a computation at compile-time. For example: * *

{@code
  * Ops tf = Ops.create();
@@ -44,8 +43,8 @@
  * }
* *

Even if all typed tensors implements somehow {@link org.tensorflow.ndarray.NdArray NdArray} - * to provide access to their data, {@code TType} deliberately does not extend directly from this - * interface, for the following reasons: + * to provide access to their data, {@code TType} deliberately does not extend directly from this interface, for the + * following reasons: *

    *
  • Implementing {@code NdArray} at this level could only expose boxed-type accessors, which * are less performant than their primitive equivalent, only exposed by subinterfaces of @@ -82,27 +81,27 @@ default void close() { } @Override - default boolean isClosed(){ + default boolean isClosed() { return asRawTensor().isClosed(); } @Override - default void detach(){ + default void detach() { asRawTensor().detach(); } @Override - default void attachToParent(){ - asRawTensor().attachToParent(); + default void attachToParent(boolean requireParent) { + asRawTensor().attachToParent(requireParent); } @Override - default boolean isAttached(){ + default boolean isAttached() { return asRawTensor().isAttached(); } @Override - default void attachToCurrentScope(){ + default void attachToCurrentScope() { asRawTensor().attachToCurrentScope(); } } From 4319a650e4599d29b59e9e90de865b5f7efc9671 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 18 Jan 2021 21:01:39 -0800 Subject: [PATCH 20/35] Adjust API to be more explicit, add release Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/HasTensors.java | 45 --- .../main/java/org/tensorflow/RawTensor.java | 32 +- .../src/main/java/org/tensorflow/Tensor.java | 33 +- .../main/java/org/tensorflow/TensorScope.java | 339 +++++++++++++----- .../org/tensorflow/types/family/TType.java | 15 - .../java/org/tensorflow/TensorScopeTest.java | 4 +- 6 files changed, 253 insertions(+), 215 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java index c24124cf757..5d8344d22a0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java @@ -35,51 +35,6 @@ default void detach() { tensors().forEach(Tensor::detach); } - /** - * Attach all of these tensors to the most recent scope. - * - * @throws IllegalStateException if there is no active scope. - * @see Tensor#attachToCurrentScope() - */ - default void attachToCurrentScope() { - TensorScope scope = TensorScope.currentScope(); - if (scope == null) { - throw new IllegalStateException("Can't attach to current scope: no active tensor scopes."); - } - - tensors().forEach(scope::attach); - } - - - /** - * Attach these tensors to the parents of their current scopes, removing them from their current scopes. - * - *

    If {@code requireParent} is false, detaches each tensor if its scope does not have a parent. Otherwise, if - * {@code requireParent} is true and the scope does not have a parent, throws {@link IllegalStateException}. - * - *

    WARNING: this method may release resources without assigning them to another scope if - * * {@code requireParent} is false. {@link #attachToParent()} should be used instead wherever possible. - * - * @param requireParent Whether to require a parent scope to release resources to. - * @throws IllegalStateException if the tensor does not have a scope, or if this scope has no parent, but {@code - * requireParent} is true - */ - default void attachToParent(boolean requireParent) { - tensors().forEach(x -> x.attachToParent(requireParent)); - } - - /** - * Attach these tensors to the parent of their current scope, removing it from its current scope. - * - *

    Note that if tensors have different scopes, each tensor will be attached to its scope's parent. - * {@link TensorScope#attach(HasTensors)} or {@link #attachToCurrentScope()} can be used to ensure all tensors have - * the same scope. - * - * @throws IllegalStateException if any tensors do not have a scope, or their scope does not have a parent. - */ - default void attachToParent() { - attachToParent(true); - } /** * Release resources associated with these tensors. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index 803207c548c..b5b21c8c46b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -76,41 +76,11 @@ public boolean isClosed() { return closed; } - @Override - public void detach() { - TensorScope.detach(this); - } - @Override public boolean isAttached() { return scope != null; } - @Override - public synchronized void attachToParent(boolean requireParent) { - if (scope == null) { - throw new IllegalStateException("Can't attach to parent: no scope."); - } - if (scope.parent == null && requireParent) { - throw new IllegalStateException("Can't attach to parent: scope does not have a parent."); - } - - if (scope.parent != null) { - scope.parent.attach(this); - } else { - this.detach(); - } - } - - @Override - public void attachToCurrentScope() { - TensorScope scope = TensorScope.currentScope(); - if (scope == null) { - throw new IllegalStateException("Can't attach to current scope: no active tensor scopes."); - } - scope.attach(this); - } - /** * Returns the raw data of this tensor as a buffer of bytes. * @@ -254,7 +224,7 @@ private static long[] shape(TF_Tensor handle) { TensorScope currentScope = TensorScope.currentScope(); if (currentScope != null) { - this.scope = currentScope.attach(this); + this.scope = currentScope.withAttached(this); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 7efc7d32d9f..2910349aa7a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -218,42 +218,15 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData /** * Detach this tensor from any scopes managing it. It must be manually closed or attached to another scope. - */ - void detach(); - - /** - * Attach this tensor to the parent of it's current scope, removing it from its current scope. * - * @throws IllegalStateException if the tensor does not have a scope, or its scope does not have a parent. + *

    Semantically, this makes the tensor everyone's responsibility: whoever uses it last needs to close it. */ - default void attachToParent() { - attachToParent(true); + default void detach() { + TensorScope.detach(this); } - /** - * Attach this tensor to the parent of it's current scope, removing it from its current scope. - * - *

    If {@code requireParent} is false, detaches the tensor if its scope does not have a parent. Otherwise, if - * {@code requireParent} is true and the scope does not have a parent, throws {@link IllegalStateException}. - * - *

    WARNING: this method may release resources without assigning them to another scope if - * * {@code requireParent} is false. {@link #attachToParent()} should be used instead wherever possible. - * - * @param requireParent Whether to require a parent scope to release resources to. - * @throws IllegalStateException if the tensor does not have a scope, or if this scope has no parent, but {@code - * requireParent} is true - */ - void attachToParent(boolean requireParent); - /** * Returns true if this tensor is attached to a {@link TensorScope}. */ boolean isAttached(); - - /** - * Attach this tensor to the most recent scope. - * - * @throws IllegalStateException if there are no active scopes - */ - void attachToCurrentScope(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index b389b891c6e..4218005307f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -16,14 +16,7 @@ */ package org.tensorflow; -import java.util.ArrayDeque; -import java.util.Collections; -import java.util.Deque; -import java.util.HashSet; -import java.util.Iterator; -import java.util.LinkedHashSet; import java.util.Set; -import java.util.WeakHashMap; import java.util.concurrent.ConcurrentHashMap; @@ -33,7 +26,7 @@ * scope (even if they are created in a sub-scope). Tensors may be manually closed earlier without issue. *

    * Tensors are automatically tracked on creation. A tensor can me manually added to a scope with {@link - * TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrentScope()}. A tensor may only have one scope: if it + * TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrent()}. A tensor may only have one scope: if it * currently has a scope when {@code attach} is called, it is removed from its original scope. *

    * {@link Tensor#detach()} detaches the tensor from it's scope, requiring the user to close it manually or attach it to @@ -61,17 +54,6 @@ public static TensorScope currentScope() { return scope; } - public static void detach(Tensor tensor) { - // ensure that I'm not attaching or detaching at the same time in different threads - RawTensor rt = tensor.asRawTensor(); - synchronized (rt) { - if (rt.scope != null) { - rt.scope.tensors.remove(rt); - rt.scope = null; - } - } - } - /** * Create a new tensor scope. If {@code autoAttach} is false, will not automatically manage tensors. * @@ -88,12 +70,154 @@ public TensorScope() { } } + /** + * Closes this scope and its tensors, and any inner scopes. + */ + @Override + public synchronized void close() { + if (closed) { + return; + } + + children.forEach(TensorScope::close); + tensors.forEach(Tensor::close); + + closed = true; + + if (parent != null) { + parent.children.remove(this); + } + + if (currentScope() == this) { + currentScope.set(this.parent); + } + } + + /** + * Release the tensors and child scopes of this scope to it's parent, without closing them. + *

    + * Semantically, calling this method makes all of the resources in this scope the parent's responsibility, as if this + * scope had never existed. + *

    + * This will close this scope, but does not close any of it's resources. + * + * @throws IllegalStateException if this scope has no parent. If this happens, + * * the scope is not closed and no resources are released. + */ + public synchronized void releaseToParent() { + release(true); + } + + /** + * Release the tensors and child scopes of this scope to it's parent, or detach them if this scope has no parent. + *

    + * Semantically, calling this method makes all of the resources in this scope the parent's responsibility, as if this + * scope had never existed. It can be used in a method to transfer control to the caller, leaving how the resources + * are managed up to the caller. + *

    + * This will close this scope, but does not close any of it's resources. + */ + public synchronized void release() { + release(false); + } + + /** + * Release the tensors and child scopes of this scope without closing them, to it's parent if it has one. + * + *

    WARNING: this method may release resources without assigning them to another scope if + * {@code requireParent} is false. {@link #releaseToParent()} should be used instead wherever possible. + * + * @param requireParent Whether to require a parent scope to release resources to. + * @throws IllegalStateException if this scope has no parent, but {@code requireParent} is true. If this happens, + * the scope is not closed and no resources are released. + */ + public synchronized void release(boolean requireParent) { + if (closed) { + return; + } + + if (this.parent == null && requireParent) { + throw new IllegalStateException("Can't release to parent: scope does not have parent."); + } + + if (this.parent != null) { + TensorScope newParent = this.parent; + newParent.children.addAll(children); + children.forEach(x -> x.parent = newParent); + tensors.forEach(newParent::attach); + } else { + children.forEach(x -> x.parent = null); + tensors.forEach(TensorScope::detach); + } + + children.clear(); + tensors.clear(); + + close(); + } + + public static T detach(T tensor) { + // ensure that I'm not attaching or detaching at the same time in different threads + RawTensor rt = tensor.asRawTensor(); + synchronized (rt) { + if (rt.scope != null) { + rt.scope.tensors.remove(rt); + rt.scope = null; + } + } + return tensor; + } + + /** + * @see #detach(Tensor) + */ + public static void detach(Tensor... tensors){ + for(Tensor t : tensors){ + detach(t); + } + } + + /** + * @see #detach(Tensor) + */ + public static T detach(T tensors){ + detach(tensors.tensors()); + return tensors; + } + + /** + * @see #detach(Tensor) + */ + public static void detach(HasTensors... tensors){ + for(HasTensors ht : tensors){ + detach(ht); + } + } + + /** + * @see #detach(Tensor) + */ + public static > T detach(T tensors){ + tensors.forEach(TensorScope::detach); + return tensors; + } + + /** + * @see #detach(Tensor) + */ + @SafeVarargs + public static void detach(Iterable... tensors){ + for(Iterable iterable : tensors){ + detach(iterable); + } + } + /** * Attach a tensor to this scope. This happens automatically to tensors that are created in the scope. * * @return this */ - public synchronized TensorScope attach(Tensor tensor) { + public synchronized T attach(T tensor) { if (this.closed) { throw new IllegalStateException("Scope has been closed, can not attach new tensor."); } @@ -106,141 +230,172 @@ public synchronized TensorScope attach(Tensor tensor) { tensors.add(rt); } - return this; + return tensor; } /** - * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. - * - * @return this + * @see #attach(Tensor) */ - public TensorScope attach(Tensor... tensors) { + public void attach(Tensor... tensors) { if (tensors != null) { for (Tensor t : tensors) { attach(t); } } - - return this; } /** - * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. - * - * @return this + * @see #attach(Tensor) */ - public TensorScope attach(HasTensors tensors) { - tensors.tensors().forEach(this::attach); - - return this; + public T attach(T tensors) { + attach(tensors.tensors()); + return tensors; } /** - * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. - * - * @return this + * @see #attach(Tensor) */ - public TensorScope attach(HasTensors... tensors) { + public void attach(HasTensors... tensors) { if (tensors != null) { for (HasTensors ht : tensors) { attach(ht); } } - - return this; } /** - * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. - * - * @return this + * @see #attach(Tensor) */ - public TensorScope attach(Iterable tensors) { + public > T attach(T tensors) { tensors.forEach(this::attach); - return this; + return tensors; } /** - * Attach tensors to this scope. This happens automatically to tensors that are created in the scope. - * - * @return this + * @see #attach(Tensor) */ @SafeVarargs - public final TensorScope attach(Iterable... tensors) { + public final void attach(Iterable... tensors) { if (tensors != null) { - for (Iterable ht : tensors) { + for (Iterable ht : tensors) { attach(ht); } } + } + /** + * @see #attach(Tensor) + */ + public TensorScope withAttached(Tensor... tensors){ + attach(tensors); return this; } /** - * Closes this scope and its tensors, and any inner scopes. + * @see #attach(Tensor) */ - @Override - public synchronized void close() { - if (closed) { - return; - } - - children.forEach(TensorScope::close); - tensors.forEach(Tensor::close); + public TensorScope withAttached(HasTensors... tensors){ + attach(tensors); + return this; + } - closed = true; + /** + * @see #attach(Tensor) + */ + public TensorScope withAttached(Iterable... tensors){ + attach(tensors); + return this; + } - if (parent != null) { - parent.children.remove(this); + /** + * Attach this tensor to the parent of this scope, removing it from its current scope, or detach it if there is + * no current scope or the current scope does not have a parent. + * + *

    Semantically, this makes the tensor's resources this scope's parent's responsibility. + * + * @param requireParent whether to require a parent scope to release resources to. + * @throws IllegalStateException if there is no current scope or the current scope does not have a parent, but {@code + * requireParent} is true. If this happens, the tensor's scope is not changed. + */ + public T release(T tensor, boolean requireParent){ + if (parent == null && requireParent) { + throw new IllegalStateException( + "Can't release to parent: not in a current scope, or the current scope does not have a parent."); } - if (currentScope() == this) { - currentScope.set(this.parent); + detach(tensor); + if (parent != null) { + parent.attach(tensor); } + return tensor; } + /** - * Release the tensors and child scopes of this scope to it's parent, without closing them. + * Attach this tensor to the parent of this scope, removing it from its current scope, or detach it if there is + * no current scope or the current scope does not have a parent. * - * @throws IllegalStateException if this scope has no parent. + *

    Semantically, this makes the tensor's resources this scope's parent's responsibility. */ - public synchronized void releaseToParent() { - release(true); + public T release(T tensor){ + return release(tensor, false); } /** - * Release the tensors and child scopes of this scope without closing them, to it's parent if it has one. - * - *

    WARNING: this method may release resources without assigning them to another scope if - * {@code requireParent} is false. {@link #releaseToParent()} should be used instead wherever possible. - * - * @param requireParent Whether to require a parent scope to release resources to. - * @throws IllegalStateException if this scope has no parent, but {@code requireParent} is true. + * @see #release(Tensor) */ - public synchronized void release(boolean requireParent) { - if (closed) { - return; + public void release(Tensor... tensors){ + for(Tensor t : tensors){ + release(t); } + } - if (this.parent == null && requireParent) { - throw new IllegalStateException("Can't release to parent: scope does not have parent."); - } + /** + * @see #release(Tensor) + */ + public T release(T tensors){ + release(tensors.tensors()); + return tensors; + } - if (this.parent != null) { - TensorScope newParent = this.parent; - newParent.children.addAll(children); - children.forEach(x -> x.parent = newParent); - tensors.forEach(newParent::attach); - } else { - children.forEach(x -> x.parent = null); - tensors.forEach(TensorScope::detach); + /** + * @see #release(Tensor) + */ + public void release(HasTensors... tensors){ + for(HasTensors ht : tensors){ + release(ht); } + } - children.clear(); - tensors.clear(); + /** + * @see #release(Tensor) + */ + public > T release(T tensors){ + tensors.forEach(this::release); + return tensors; + } - close(); + /** + * @see #release(Tensor) + */ + @SafeVarargs + public final void release(Iterable... tensors){ + for(Iterable iterable : tensors){ + release(iterable); + } + } + + /** + * Attach this tensor to the parent of this scope, removing it from its current scope. + * + *

    Semantically, this makes the tensor's resources this scope's parent's responsibility. + * + * @throws IllegalStateException if there is no current scope or the current scope does not have a parent, but {@code + * requireParent} is true. If this happens, the tensor's scope is not changed. + */ + public T releaseToParent(T tensor){ + return release(tensor, true); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index 712ba660cc2..afe10f685a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -85,23 +85,8 @@ default boolean isClosed() { return asRawTensor().isClosed(); } - @Override - default void detach() { - asRawTensor().detach(); - } - - @Override - default void attachToParent(boolean requireParent) { - asRawTensor().attachToParent(requireParent); - } - @Override default boolean isAttached() { return asRawTensor().isAttached(); } - - @Override - default void attachToCurrentScope() { - asRawTensor().attachToCurrentScope(); - } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index f1652233730..df63e16bc61 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -86,7 +86,7 @@ public void testNestedScope() { public void testAttach() { TensorScope firstScope = new TensorScope(); TFloat32 tensor = makeTensor(10); - TensorScope secondScope = new TensorScope().attach(tensor); + TensorScope secondScope = new TensorScope().withAttached(tensor); assertTrue(tensor.isAttached()); assertFalse(tensor.isClosed()); @@ -129,7 +129,7 @@ public void testAttachToParentScope() { assertTrue(tensor.isAttached()); assertFalse(tensor.isClosed()); - tensor.attachToParent(); + scope.release(tensor); scope.close(); From c13c8dbf94460a06e1a8e9ef8c0d35f91e3f9dd8 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 23 Jan 2021 14:29:15 -0800 Subject: [PATCH 21/35] Make constructor package private, use static methods. Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/TensorScope.java | 119 +++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 4218005307f..c45f108cc25 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -18,6 +18,9 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; /** @@ -54,12 +57,126 @@ public static TensorScope currentScope() { return scope; } + /** + * Runs {@code block}, then closes any tensors created during its execution. + *

    To release tensors, use {@link #withCleanup(Consumer)} or one of the {@code produceWithCleanup} methods. + */ + public static void withCleanup(Runnable block){ + try(TensorScope scope = new TensorScope()){ + block.run(); + } + } + + /** + * Runs {@code block}, then closes any tensors created during its execution (or attached to the scope). + *

    Tensors can be released using the passed scope. + */ + public static void withCleanup(Consumer block){ + try(TensorScope scope = new TensorScope()){ + block.accept(scope); + } + } + + /** + * Runs {@code block} and returns the result, then closes any tensors created during its execution. + *

    To release tensors, use {@link #withCleanup(Function)} or one of the {@code produceWithCleanup} methods. + */ + public static T withCleanup(Supplier block){ + try(TensorScope scope = new TensorScope()){ + return block.get(); + } + } + + /** + * Runs {@code block} and returns the result, then closes any tensors created during its execution (or attached to the scope). + *

    Tensors can be released using the passed scope. + */ + public static T withCleanup(Function block){ + try(TensorScope scope = new TensorScope()){ + return block.apply(scope); + } + } + + /** + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution. + *

    To release other tensors, use {@link #produceTensorWithCleanup(Function)}. + * + * @return the released result of {@code block} + */ + public static T produceTensorWithCleanup(Supplier block){ + try(TensorScope scope = new TensorScope()){ + return scope.release(block.get()); + } + } + + /** + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution (or attached to the scope). + *

    Tensors can be released using the passed scope. + * + * @return the released result of {@code block} + */ + public static T produceTensorWithCleanup(Function block){ + try(TensorScope scope = new TensorScope()){ + return scope.release(block.apply(scope)); + } + } + + + /** + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution. + *

    To release other tensors, use {@link #produceTensorWithCleanup(Function)}. + * + * @return the released result of {@code block} + */ + public static T produceHasTensorsWithCleanup(Supplier block){ + try(TensorScope scope = new TensorScope()){ + return scope.release(block.get()); + } + } + + /** + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution (or attached to the scope). + *

    Tensors can be released using the passed scope. + * + * @return the released result of {@code block} + */ + public static T produceHasTensorsWithCleanup(Function block){ + try(TensorScope scope = new TensorScope()){ + return scope.release(block.apply(scope)); + } + } + + + /** + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution. + *

    To release other tensors, use {@link #produceTensorWithCleanup(Function)}. + * + * @return the released result of {@code block} + */ + public static > T produceTensorsWithCleanup(Supplier block){ + try(TensorScope scope = new TensorScope()){ + return scope.release(block.get()); + } + } + + /** + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution (or attached to the scope). + *

    Tensors can be released using the passed scope. + * + * @return the released result of {@code block} + */ + public static > T produceTensorsWithCleanup(Function block){ + try(TensorScope scope = new TensorScope()){ + return scope.release(block.apply(scope)); + } + } + /** * Create a new tensor scope. If {@code autoAttach} is false, will not automatically manage tensors. * * @see TensorScope */ - public TensorScope() { + TensorScope() { this.parent = currentScope(); currentScope.set(this); From 7ebb44715795f0afe5650db7670ba2ff58816bdb Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sat, 23 Jan 2021 14:29:52 -0800 Subject: [PATCH 22/35] format Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/TensorScope.java | 125 +++++++++--------- 1 file changed, 66 insertions(+), 59 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index c45f108cc25..5a5f5922ee7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -29,8 +29,8 @@ * scope (even if they are created in a sub-scope). Tensors may be manually closed earlier without issue. *

    * Tensors are automatically tracked on creation. A tensor can me manually added to a scope with {@link - * TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrent()}. A tensor may only have one scope: if it - * currently has a scope when {@code attach} is called, it is removed from its original scope. + * TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrent()}. A tensor may only have one scope: if it currently + * has a scope when {@code attach} is called, it is removed from its original scope. *

    * {@link Tensor#detach()} detaches the tensor from it's scope, requiring the user to close it manually or attach it to * another scope. @@ -61,8 +61,8 @@ public static TensorScope currentScope() { * Runs {@code block}, then closes any tensors created during its execution. *

    To release tensors, use {@link #withCleanup(Consumer)} or one of the {@code produceWithCleanup} methods. */ - public static void withCleanup(Runnable block){ - try(TensorScope scope = new TensorScope()){ + public static void withCleanup(Runnable block) { + try (TensorScope scope = new TensorScope()) { block.run(); } } @@ -71,8 +71,8 @@ public static void withCleanup(Runnable block){ * Runs {@code block}, then closes any tensors created during its execution (or attached to the scope). *

    Tensors can be released using the passed scope. */ - public static void withCleanup(Consumer block){ - try(TensorScope scope = new TensorScope()){ + public static void withCleanup(Consumer block) { + try (TensorScope scope = new TensorScope()) { block.accept(scope); } } @@ -81,92 +81,99 @@ public static void withCleanup(Consumer block){ * Runs {@code block} and returns the result, then closes any tensors created during its execution. *

    To release tensors, use {@link #withCleanup(Function)} or one of the {@code produceWithCleanup} methods. */ - public static T withCleanup(Supplier block){ - try(TensorScope scope = new TensorScope()){ + public static T withCleanup(Supplier block) { + try (TensorScope scope = new TensorScope()) { return block.get(); } } /** - * Runs {@code block} and returns the result, then closes any tensors created during its execution (or attached to the scope). + * Runs {@code block} and returns the result, then closes any tensors created during its execution (or attached to the + * scope). *

    Tensors can be released using the passed scope. */ - public static T withCleanup(Function block){ - try(TensorScope scope = new TensorScope()){ + public static T withCleanup(Function block) { + try (TensorScope scope = new TensorScope()) { return block.apply(scope); } } /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution. + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its + * execution. *

    To release other tensors, use {@link #produceTensorWithCleanup(Function)}. * * @return the released result of {@code block} */ - public static T produceTensorWithCleanup(Supplier block){ - try(TensorScope scope = new TensorScope()){ + public static T produceTensorWithCleanup(Supplier block) { + try (TensorScope scope = new TensorScope()) { return scope.release(block.get()); } } /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution (or attached to the scope). + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its + * execution (or attached to the scope). *

    Tensors can be released using the passed scope. * * @return the released result of {@code block} */ - public static T produceTensorWithCleanup(Function block){ - try(TensorScope scope = new TensorScope()){ + public static T produceTensorWithCleanup(Function block) { + try (TensorScope scope = new TensorScope()) { return scope.release(block.apply(scope)); } } /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution. + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its + * execution. *

    To release other tensors, use {@link #produceTensorWithCleanup(Function)}. * * @return the released result of {@code block} */ - public static T produceHasTensorsWithCleanup(Supplier block){ - try(TensorScope scope = new TensorScope()){ + public static T produceHasTensorsWithCleanup(Supplier block) { + try (TensorScope scope = new TensorScope()) { return scope.release(block.get()); } } /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution (or attached to the scope). + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its + * execution (or attached to the scope). *

    Tensors can be released using the passed scope. * * @return the released result of {@code block} */ - public static T produceHasTensorsWithCleanup(Function block){ - try(TensorScope scope = new TensorScope()){ + public static T produceHasTensorsWithCleanup(Function block) { + try (TensorScope scope = new TensorScope()) { return scope.release(block.apply(scope)); } } /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution. + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its + * execution. *

    To release other tensors, use {@link #produceTensorWithCleanup(Function)}. * * @return the released result of {@code block} */ - public static > T produceTensorsWithCleanup(Supplier block){ - try(TensorScope scope = new TensorScope()){ + public static > T produceTensorsWithCleanup(Supplier block) { + try (TensorScope scope = new TensorScope()) { return scope.release(block.get()); } } /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its execution (or attached to the scope). + * Runs {@code block} and releases and returns the result, then closes any other tensors created during its + * execution (or attached to the scope). *

    Tensors can be released using the passed scope. * * @return the released result of {@code block} */ - public static > T produceTensorsWithCleanup(Function block){ - try(TensorScope scope = new TensorScope()){ + public static > T produceTensorsWithCleanup(Function block) { + try (TensorScope scope = new TensorScope()) { return scope.release(block.apply(scope)); } } @@ -218,8 +225,8 @@ public synchronized void close() { *

    * This will close this scope, but does not close any of it's resources. * - * @throws IllegalStateException if this scope has no parent. If this happens, - * * the scope is not closed and no resources are released. + * @throws IllegalStateException if this scope has no parent. If this happens, * the scope is not closed and no + * resources are released. */ public synchronized void releaseToParent() { release(true); @@ -245,8 +252,8 @@ public synchronized void release() { * {@code requireParent} is false. {@link #releaseToParent()} should be used instead wherever possible. * * @param requireParent Whether to require a parent scope to release resources to. - * @throws IllegalStateException if this scope has no parent, but {@code requireParent} is true. If this happens, - * the scope is not closed and no resources are released. + * @throws IllegalStateException if this scope has no parent, but {@code requireParent} is true. If this happens, the + * scope is not closed and no resources are released. */ public synchronized void release(boolean requireParent) { if (closed) { @@ -288,8 +295,8 @@ public static T detach(T tensor) { /** * @see #detach(Tensor) */ - public static void detach(Tensor... tensors){ - for(Tensor t : tensors){ + public static void detach(Tensor... tensors) { + for (Tensor t : tensors) { detach(t); } } @@ -297,7 +304,7 @@ public static void detach(Tensor... tensors){ /** * @see #detach(Tensor) */ - public static T detach(T tensors){ + public static T detach(T tensors) { detach(tensors.tensors()); return tensors; } @@ -305,8 +312,8 @@ public static T detach(T tensors){ /** * @see #detach(Tensor) */ - public static void detach(HasTensors... tensors){ - for(HasTensors ht : tensors){ + public static void detach(HasTensors... tensors) { + for (HasTensors ht : tensors) { detach(ht); } } @@ -314,7 +321,7 @@ public static void detach(HasTensors... tensors){ /** * @see #detach(Tensor) */ - public static > T detach(T tensors){ + public static > T detach(T tensors) { tensors.forEach(TensorScope::detach); return tensors; } @@ -323,8 +330,8 @@ public static > T detach(T tensors){ * @see #detach(Tensor) */ @SafeVarargs - public static void detach(Iterable... tensors){ - for(Iterable iterable : tensors){ + public static void detach(Iterable... tensors) { + for (Iterable iterable : tensors) { detach(iterable); } } @@ -404,7 +411,7 @@ public final void attach(Iterable... tensors) { /** * @see #attach(Tensor) */ - public TensorScope withAttached(Tensor... tensors){ + public TensorScope withAttached(Tensor... tensors) { attach(tensors); return this; } @@ -412,7 +419,7 @@ public TensorScope withAttached(Tensor... tensors){ /** * @see #attach(Tensor) */ - public TensorScope withAttached(HasTensors... tensors){ + public TensorScope withAttached(HasTensors... tensors) { attach(tensors); return this; } @@ -420,14 +427,14 @@ public TensorScope withAttached(HasTensors... tensors){ /** * @see #attach(Tensor) */ - public TensorScope withAttached(Iterable... tensors){ + public TensorScope withAttached(Iterable... tensors) { attach(tensors); return this; } /** - * Attach this tensor to the parent of this scope, removing it from its current scope, or detach it if there is - * no current scope or the current scope does not have a parent. + * Attach this tensor to the parent of this scope, removing it from its current scope, or detach it if there is no + * current scope or the current scope does not have a parent. * *

    Semantically, this makes the tensor's resources this scope's parent's responsibility. * @@ -435,7 +442,7 @@ public TensorScope withAttached(Iterable... tensors){ * @throws IllegalStateException if there is no current scope or the current scope does not have a parent, but {@code * requireParent} is true. If this happens, the tensor's scope is not changed. */ - public T release(T tensor, boolean requireParent){ + public T release(T tensor, boolean requireParent) { if (parent == null && requireParent) { throw new IllegalStateException( "Can't release to parent: not in a current scope, or the current scope does not have a parent."); @@ -450,20 +457,20 @@ public T release(T tensor, boolean requireParent){ /** - * Attach this tensor to the parent of this scope, removing it from its current scope, or detach it if there is - * no current scope or the current scope does not have a parent. + * Attach this tensor to the parent of this scope, removing it from its current scope, or detach it if there is no + * current scope or the current scope does not have a parent. * *

    Semantically, this makes the tensor's resources this scope's parent's responsibility. */ - public T release(T tensor){ + public T release(T tensor) { return release(tensor, false); } /** * @see #release(Tensor) */ - public void release(Tensor... tensors){ - for(Tensor t : tensors){ + public void release(Tensor... tensors) { + for (Tensor t : tensors) { release(t); } } @@ -471,7 +478,7 @@ public void release(Tensor... tensors){ /** * @see #release(Tensor) */ - public T release(T tensors){ + public T release(T tensors) { release(tensors.tensors()); return tensors; } @@ -479,8 +486,8 @@ public T release(T tensors){ /** * @see #release(Tensor) */ - public void release(HasTensors... tensors){ - for(HasTensors ht : tensors){ + public void release(HasTensors... tensors) { + for (HasTensors ht : tensors) { release(ht); } } @@ -488,7 +495,7 @@ public void release(HasTensors... tensors){ /** * @see #release(Tensor) */ - public > T release(T tensors){ + public > T release(T tensors) { tensors.forEach(this::release); return tensors; } @@ -497,8 +504,8 @@ public > T release(T tensors){ * @see #release(Tensor) */ @SafeVarargs - public final void release(Iterable... tensors){ - for(Iterable iterable : tensors){ + public final void release(Iterable... tensors) { + for (Iterable iterable : tensors) { release(iterable); } } @@ -511,7 +518,7 @@ public final void release(Iterable... tensors){ * @throws IllegalStateException if there is no current scope or the current scope does not have a parent, but {@code * requireParent} is true. If this happens, the tensor's scope is not changed. */ - public T releaseToParent(T tensor){ + public T releaseToParent(T tensor) { return release(tensor, true); } From 49ac26e1164a5298498d3d063700e1c478c497b3 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 24 Jan 2021 15:01:27 -0800 Subject: [PATCH 23/35] fixes Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/RawTensor.java | 14 +++--- .../{HasTensors.java => TensorContainer.java} | 2 +- .../main/java/org/tensorflow/TensorScope.java | 48 +++++++++---------- .../java/org/tensorflow/TensorScopeTest.java | 4 +- 4 files changed, 34 insertions(+), 34 deletions(-) rename tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/{HasTensors.java => TensorContainer.java} (96%) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index b5b21c8c46b..040bd6a91d4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -66,7 +66,7 @@ public RawTensor asRawTensor() { @Override public void close() { if (!closed) { - tensorScope.close(); + pointerScope.close(); closed = true; } } @@ -78,7 +78,7 @@ public boolean isClosed() { @Override public boolean isAttached() { - return scope != null; + return tensorScope != null; } /** @@ -146,7 +146,7 @@ static RawTensor allocate(Class type, Shape shape, long size) { scope.attach(nativeHandle); RawTensor t = new RawTensor(typeInfo, shape); t.tensorHandle = nativeHandle; - t.tensorScope = scope.extend(); + t.pointerScope = scope.extend(); return t; } } @@ -162,7 +162,7 @@ static RawTensor fromHandle(TF_Tensor handle) { try (PointerScope scope = new PointerScope()) { scope.attach(handle); t.tensorHandle = handle; - t.tensorScope = scope.extend(); + t.pointerScope = scope.extend(); } return t; } @@ -224,13 +224,13 @@ private static long[] shape(TF_Tensor handle) { TensorScope currentScope = TensorScope.currentScope(); if (currentScope != null) { - this.scope = currentScope.withAttached(this); + this.tensorScope = currentScope.withTensors(this); } } - private PointerScope tensorScope; + private PointerScope pointerScope; private boolean closed; - TensorScope scope; + TensorScope tensorScope; private TF_Tensor tensorHandle; private final TensorTypeInfo typeInfo; private final Shape shape; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java similarity index 96% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java rename to tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java index 5d8344d22a0..248bf5a7e2a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/HasTensors.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java @@ -19,7 +19,7 @@ /** * An interface representing a collection or group of tensors. Provides methods for resource management. */ -public interface HasTensors extends AutoCloseable { +public interface TensorContainer extends AutoCloseable { /** * Get the tensors held by this object. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 5a5f5922ee7..704b71a8ca4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -132,7 +132,7 @@ public static T produceTensorWithCleanup(Function T produceHasTensorsWithCleanup(Supplier block) { + public static T produceHasTensorsWithCleanup(Supplier block) { try (TensorScope scope = new TensorScope()) { return scope.release(block.get()); } @@ -145,7 +145,7 @@ public static T produceHasTensorsWithCleanup(Supplier * * @return the released result of {@code block} */ - public static T produceHasTensorsWithCleanup(Function block) { + public static T produceHasTensorsWithCleanup(Function block) { try (TensorScope scope = new TensorScope()) { return scope.release(block.apply(scope)); } @@ -228,7 +228,7 @@ public synchronized void close() { * @throws IllegalStateException if this scope has no parent. If this happens, * the scope is not closed and no * resources are released. */ - public synchronized void releaseToParent() { + public synchronized void releaseAllToParent() { release(true); } @@ -241,7 +241,7 @@ public synchronized void releaseToParent() { *

    * This will close this scope, but does not close any of it's resources. */ - public synchronized void release() { + public synchronized void releaseAll() { release(false); } @@ -249,13 +249,13 @@ public synchronized void release() { * Release the tensors and child scopes of this scope without closing them, to it's parent if it has one. * *

    WARNING: this method may release resources without assigning them to another scope if - * {@code requireParent} is false. {@link #releaseToParent()} should be used instead wherever possible. + * {@code requireParent} is false. {@link #releaseAllToParent()} should be used instead wherever possible. * * @param requireParent Whether to require a parent scope to release resources to. * @throws IllegalStateException if this scope has no parent, but {@code requireParent} is true. If this happens, the * scope is not closed and no resources are released. */ - public synchronized void release(boolean requireParent) { + private synchronized void release(boolean requireParent) { if (closed) { return; } @@ -284,9 +284,9 @@ public static T detach(T tensor) { // ensure that I'm not attaching or detaching at the same time in different threads RawTensor rt = tensor.asRawTensor(); synchronized (rt) { - if (rt.scope != null) { - rt.scope.tensors.remove(rt); - rt.scope = null; + if (rt.tensorScope != null) { + rt.tensorScope.tensors.remove(rt); + rt.tensorScope = null; } } return tensor; @@ -304,7 +304,7 @@ public static void detach(Tensor... tensors) { /** * @see #detach(Tensor) */ - public static T detach(T tensors) { + public static T detach(T tensors) { detach(tensors.tensors()); return tensors; } @@ -312,8 +312,8 @@ public static T detach(T tensors) { /** * @see #detach(Tensor) */ - public static void detach(HasTensors... tensors) { - for (HasTensors ht : tensors) { + public static void detach(TensorContainer... tensors) { + for (TensorContainer ht : tensors) { detach(ht); } } @@ -350,7 +350,7 @@ public synchronized T attach(T tensor) { // ensure that I'm not attaching or detaching at the same time in different threads synchronized (rt) { detach(tensor); - rt.scope = this; + rt.tensorScope = this; tensors.add(rt); } @@ -371,7 +371,7 @@ public void attach(Tensor... tensors) { /** * @see #attach(Tensor) */ - public T attach(T tensors) { + public T attach(T tensors) { attach(tensors.tensors()); return tensors; } @@ -379,9 +379,9 @@ public T attach(T tensors) { /** * @see #attach(Tensor) */ - public void attach(HasTensors... tensors) { + public void attach(TensorContainer... tensors) { if (tensors != null) { - for (HasTensors ht : tensors) { + for (TensorContainer ht : tensors) { attach(ht); } } @@ -411,7 +411,7 @@ public final void attach(Iterable... tensors) { /** * @see #attach(Tensor) */ - public TensorScope withAttached(Tensor... tensors) { + public TensorScope withTensors(Tensor... tensors) { attach(tensors); return this; } @@ -419,7 +419,7 @@ public TensorScope withAttached(Tensor... tensors) { /** * @see #attach(Tensor) */ - public TensorScope withAttached(HasTensors... tensors) { + public TensorScope withTensors(TensorContainer... tensors) { attach(tensors); return this; } @@ -427,7 +427,7 @@ public TensorScope withAttached(HasTensors... tensors) { /** * @see #attach(Tensor) */ - public TensorScope withAttached(Iterable... tensors) { + public TensorScope withTensors(Iterable... tensors) { attach(tensors); return this; } @@ -442,7 +442,7 @@ public TensorScope withAttached(Iterable... tensors) { * @throws IllegalStateException if there is no current scope or the current scope does not have a parent, but {@code * requireParent} is true. If this happens, the tensor's scope is not changed. */ - public T release(T tensor, boolean requireParent) { + private T release(T tensor, boolean requireParent) { if (parent == null && requireParent) { throw new IllegalStateException( "Can't release to parent: not in a current scope, or the current scope does not have a parent."); @@ -478,7 +478,7 @@ public void release(Tensor... tensors) { /** * @see #release(Tensor) */ - public T release(T tensors) { + public T release(T tensors) { release(tensors.tensors()); return tensors; } @@ -486,8 +486,8 @@ public T release(T tensors) { /** * @see #release(Tensor) */ - public void release(HasTensors... tensors) { - for (HasTensors ht : tensors) { + public void release(TensorContainer... tensors) { + for (TensorContainer ht : tensors) { release(ht); } } @@ -531,6 +531,6 @@ public synchronized boolean isClosed() { private boolean closed = false; private final Set tensors = ConcurrentHashMap.newKeySet(); - TensorScope parent; + private TensorScope parent; private final Set children = ConcurrentHashMap.newKeySet(); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index df63e16bc61..b496c50b221 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -86,7 +86,7 @@ public void testNestedScope() { public void testAttach() { TensorScope firstScope = new TensorScope(); TFloat32 tensor = makeTensor(10); - TensorScope secondScope = new TensorScope().withAttached(tensor); + TensorScope secondScope = new TensorScope().withTensors(tensor); assertTrue(tensor.isAttached()); assertFalse(tensor.isClosed()); @@ -107,7 +107,7 @@ public void testReleaseToParentScope() { assertTrue(tensor.isAttached()); assertFalse(tensor.isClosed()); - scope.releaseToParent(); + scope.releaseAllToParent(); assertTrue(scope.isClosed()); assertTrue(tensor.isAttached()); From b4b3ed462abe6196d7d4908cd800ca287d10bf45 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 24 Jan 2021 18:03:02 -0800 Subject: [PATCH 24/35] remove extra closed tracking Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/RawTensor.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index 040bd6a91d4..d239dffd6f7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -65,15 +65,14 @@ public RawTensor asRawTensor() { @Override public void close() { - if (!closed) { + if (!isClosed()) { pointerScope.close(); - closed = true; } } @Override public boolean isClosed() { - return closed; + return tensorHandle.isNull(); } @Override @@ -229,7 +228,6 @@ private static long[] shape(TF_Tensor handle) { } private PointerScope pointerScope; - private boolean closed; TensorScope tensorScope; private TF_Tensor tensorHandle; private final TensorTypeInfo typeInfo; From 775d8bd6617f4a4646465e813556aca320970018 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Sun, 24 Jan 2021 18:33:57 -0800 Subject: [PATCH 25/35] New tests, make static methods build on eachother Signed-off-by: Ryan Nett --- .../java/org/tensorflow/TensorContainer.java | 2 +- .../main/java/org/tensorflow/TensorScope.java | 32 +++----- .../java/org/tensorflow/TensorScopeTest.java | 79 +++++++++++++++++++ 3 files changed, 92 insertions(+), 21 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java index 248bf5a7e2a..ef1a368cd84 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java @@ -24,7 +24,7 @@ public interface TensorContainer extends AutoCloseable { /** * Get the tensors held by this object. */ - Iterable tensors(); + Iterable tensors(); /** * Detach these tensors from any scopes managing them. They must be manually closed or attached to another scope. diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 704b71a8ca4..4b9f311e979 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -62,9 +62,7 @@ public static TensorScope currentScope() { *

    To release tensors, use {@link #withCleanup(Consumer)} or one of the {@code produceWithCleanup} methods. */ public static void withCleanup(Runnable block) { - try (TensorScope scope = new TensorScope()) { - block.run(); - } + TensorScope.withCleanup((scope) -> block.run()); } /** @@ -79,20 +77,20 @@ public static void withCleanup(Consumer block) { /** * Runs {@code block} and returns the result, then closes any tensors created during its execution. - *

    To release tensors, use {@link #withCleanup(Function)} or one of the {@code produceWithCleanup} methods. + *

    To release tensors, use {@link #getWithCleanup(Function)} or one of the {@code produceWithCleanup} methods. + *

    Does not release or detach the result. If you return a tensor, it will be closed unless otherwise released. */ - public static T withCleanup(Supplier block) { - try (TensorScope scope = new TensorScope()) { - return block.get(); - } + public static T getWithCleanup(Supplier block) { + return TensorScope.getWithCleanup((scope) -> block.get()); } /** * Runs {@code block} and returns the result, then closes any tensors created during its execution (or attached to the * scope). *

    Tensors can be released using the passed scope. + *

    Does not release or detach the result. If you return a tensor, it will be closed unless otherwise released. */ - public static T withCleanup(Function block) { + public static T getWithCleanup(Function block) { try (TensorScope scope = new TensorScope()) { return block.apply(scope); } @@ -106,9 +104,7 @@ public static T withCleanup(Function block) { * @return the released result of {@code block} */ public static T produceTensorWithCleanup(Supplier block) { - try (TensorScope scope = new TensorScope()) { - return scope.release(block.get()); - } + return produceTensorWithCleanup((scope) -> block.get()); } /** @@ -132,10 +128,8 @@ public static T produceTensorWithCleanup(Function T produceHasTensorsWithCleanup(Supplier block) { - try (TensorScope scope = new TensorScope()) { - return scope.release(block.get()); - } + public static T produceTensorContainerWithCleanup(Supplier block) { + return produceTensorContainerWithCleanup((scope) -> block.get()); } /** @@ -145,7 +139,7 @@ public static T produceHasTensorsWithCleanup(Supplie * * @return the released result of {@code block} */ - public static T produceHasTensorsWithCleanup(Function block) { + public static T produceTensorContainerWithCleanup(Function block) { try (TensorScope scope = new TensorScope()) { return scope.release(block.apply(scope)); } @@ -160,9 +154,7 @@ public static T produceHasTensorsWithCleanup(Functio * @return the released result of {@code block} */ public static > T produceTensorsWithCleanup(Supplier block) { - try (TensorScope scope = new TensorScope()) { - return scope.release(block.get()); - } + return TensorScope.produceTensorsWithCleanup((scope) -> block.get()); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index b496c50b221..a9c6d3774fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -19,6 +19,9 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TFloat32; @@ -143,5 +146,81 @@ public void testAttachToParentScope() { assertTrue(outerScope.isClosed()); } + @Test + public void testWithCleanup() { + final Tensor[] tensor = new Tensor[1]; + TensorScope.withCleanup(() -> { + tensor[0] = makeTensor(2); + }); + assertTrue(tensor[0].isClosed()); + } + + @Test + public void testGetWithCleanup() { + Tensor tensor = TensorScope.getWithCleanup(() -> makeTensor(2)); + assertTrue(tensor.isClosed()); + } + + @Test + public void testProduceTensorWithCleanup() { + final Tensor[] closedTensor = new Tensor[1]; + Tensor openTensor = TensorScope.produceTensorWithCleanup(() -> { + closedTensor[0] = makeTensor(2); + return makeTensor(3); + }); + + assertTrue(closedTensor[0].isClosed()); + assertFalse(openTensor.isClosed()); + openTensor.close(); + } + + private static class TestTensorContainer implements TensorContainer { + + private final List tensors; + + TestTensorContainer(List tensors) { + this.tensors = tensors; + } + + @SafeVarargs + TestTensorContainer(T... tensors) { + this(Arrays.asList(tensors)); + } + + @Override + public Iterable tensors() { + return tensors; + } + + public List getTensors() { + return tensors; + } + } + + @Test + public void testProduceTensorContainerWithCleanup() { + final TestTensorContainer[] closedTensor = new TestTensorContainer[1]; + TestTensorContainer openTensor = TensorScope.produceTensorContainerWithCleanup(() -> { + closedTensor[0] = new TestTensorContainer<>(makeTensor(2)); + return new TestTensorContainer<>(makeTensor(3)); + }); + + assertTrue(closedTensor[0].getTensors().get(0).isClosed()); + assertFalse(openTensor.getTensors().get(0).isClosed()); + openTensor.getTensors().get(0).close(); + } + + @Test + public void testProduceTensorsWithCleanup(){ + final List[] closedTensor = new List[1]; + List openTensor = TensorScope.produceTensorsWithCleanup(() -> { + closedTensor[0] = Collections.singletonList(makeTensor(2)); + return Collections.singletonList(makeTensor(2)); + }); + + assertTrue(closedTensor[0].get(0).isClosed()); + assertFalse(openTensor.get(0).isClosed()); + openTensor.get(0).close(); + } } From 8032310c1dd23ec450117de39694694b0d3c72e2 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 28 Jan 2021 20:51:22 -0800 Subject: [PATCH 26/35] Convert to scope-passing style Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 451 +++++++------ .../org/tensorflow/AbstractOperation.java | 3 +- .../java/org/tensorflow/ConcreteFunction.java | 47 +- .../java/org/tensorflow/EagerOperation.java | 17 +- .../java/org/tensorflow/GraphOperation.java | 14 +- .../src/main/java/org/tensorflow/Operand.java | 5 +- .../src/main/java/org/tensorflow/Output.java | 32 +- .../main/java/org/tensorflow/RawTensor.java | 20 +- .../java/org/tensorflow/SavedModelBundle.java | 74 ++- .../src/main/java/org/tensorflow/Session.java | 187 +++--- .../src/main/java/org/tensorflow/Tensor.java | 35 +- .../main/java/org/tensorflow/TensorScope.java | 320 +--------- .../java/org/tensorflow/op/core/Constant.java | 438 +++++++------ .../java/org/tensorflow/types/TBfloat16.java | 42 +- .../main/java/org/tensorflow/types/TBool.java | 35 +- .../java/org/tensorflow/types/TFloat16.java | 42 +- .../java/org/tensorflow/types/TFloat32.java | 35 +- .../java/org/tensorflow/types/TFloat64.java | 35 +- .../java/org/tensorflow/types/TInt32.java | 35 +- .../java/org/tensorflow/types/TInt64.java | 35 +- .../java/org/tensorflow/types/TString.java | 49 +- .../java/org/tensorflow/types/TUint8.java | 35 +- .../org/tensorflow/ConcreteFunctionTest.java | 35 +- .../java/org/tensorflow/DeviceSpecTest.java | 302 ++++----- .../tensorflow/EagerOperationBuilderTest.java | 109 ++-- .../org/tensorflow/EagerOperationTest.java | 13 +- .../tensorflow/GraphOperationBuilderTest.java | 62 +- .../org/tensorflow/GraphOperationTest.java | 17 +- .../test/java/org/tensorflow/GraphTest.java | 136 ++-- .../java/org/tensorflow/RawTensorTest.java | 79 +-- .../org/tensorflow/SavedModelBundleTest.java | 86 +-- .../test/java/org/tensorflow/SessionTest.java | 151 +++-- .../java/org/tensorflow/TensorScopeTest.java | 160 +---- .../test/java/org/tensorflow/TensorTest.java | 601 +++++++++--------- .../tensorflow/benchmark/TensorBenchmark.java | 195 +++--- .../java/org/tensorflow/op/ScopeTest.java | 50 +- .../org/tensorflow/op/core/ConstantTest.java | 77 +-- .../op/core/GeneratedOperationsTest.java | 31 +- .../org/tensorflow/op/core/GradientsTest.java | 40 +- .../org/tensorflow/op/core/ShapesTest.java | 460 +++++++------- .../org/tensorflow/op/core/ZerosTest.java | 65 +- .../types/NumericTypesTestBase.java | 121 ++-- .../org/tensorflow/types/TBfloat16Test.java | 5 +- .../org/tensorflow/types/TFloat16Test.java | 5 +- .../org/tensorflow/types/TFloat32Test.java | 5 +- .../org/tensorflow/types/TFloat64Test.java | 5 +- .../java/org/tensorflow/types/TInt32Test.java | 5 +- .../java/org/tensorflow/types/TInt64Test.java | 5 +- .../org/tensorflow/types/TStringTest.java | 110 ++-- .../java/org/tensorflow/types/TUint8Test.java | 5 +- 50 files changed, 2377 insertions(+), 2544 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index ea3ef31313e..47be3383364 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -527,8 +527,8 @@ public Constant array(byte... data) { * * @param scope is a scope used to add the underlying operation. * @param charset charset for encoding/decoding strings bytes. - * @param data An array containing the values to put into the new constant. String elements are - * sequences of bytes from the last array dimension. + * @param data An array containing the values to put into the new constant. String elements are sequences of bytes + * from the last array dimension. * @return the {@code String} constant */ public Constant array(Charset charset, String... data) { @@ -1189,62 +1189,62 @@ public Constant constant(LongNdArray data) { } /** - * Creates a rank-1 constant of {@code int} elements. + * Creates a rank-3 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant */ - public Constant constant(int[] data) { - return Constant.vectorOf(scope, data); + public Constant constant(double[][][] data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-3 constant of {@code int} elements. + * Creates a rank-2 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant */ - public Constant constant(int[][][] data) { + public Constant constant(double[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant containing a single {@code double} element. + * Creates a rank-5 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data The value to put into the new constant. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ - public Constant constant(double data) { - return Constant.scalarOf(scope, data); + public Constant constant(double[][][][][] data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code long} elements. + * Creates a rank-2 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ - public Constant constant(long[][][][][] data) { + public Constant constant(long[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code boolean} elements. + * Creates a constant containing a single {@code double} element. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a boolean constant + * @param data The value to put into the new constant. + * @return a double constant */ - public Constant constant(boolean[][][][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(double data) { + return Constant.scalarOf(scope, data); } /** @@ -1270,26 +1270,26 @@ public Constant constant(DoubleNdArray data) { } /** - * Creates a rank-4 constant of {@code int} elements. + * Creates a rank-3 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a boolean constant */ - public Constant constant(int[][][][] data) { + public Constant constant(boolean[][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code float} elements. + * Creates a rank-5 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a byte constant */ - public Constant constant(float[][][][][][] data) { + public Constant constant(byte[][][][][] data) { return Constant.tensorOf(scope, data); } @@ -1305,168 +1305,180 @@ public Constant constant(byte data) { } /** - * Creates a rank-3 constant of {@code boolean} elements. + * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, using the default + * UTF-8 encoding. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a boolean constant + * @param data an n-dimensional array of {@code String} elements. + * @return a string constant */ - public Constant constant(boolean[][][] data) { + public Constant constant(NdArray data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-4 constant of {@code float} elements. + * Creates a rank-1 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(float[][][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(int[] data) { + return Constant.vectorOf(scope, data); } /** - * Creates a rank-2 constant of {@code long} elements. + * Creates a rank-4 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a long constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a boolean constant */ - public Constant constant(long[][] data) { + public Constant constant(boolean[][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code byte} elements. + * Creates a rank-4 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a byte constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant */ - public Constant constant(byte[][][][][] data) { + public Constant constant(double[][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant of {@code boolean} elements that is a copy of a given n-dimensional array. + * Creates a rank-2 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data an n-dimensional array of {@code boolean} elements. - * @return a boolean constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a byte constant */ - public Constant constant(BooleanNdArray data) { + public Constant constant(byte[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-2 constant of {@code float} elements. + * Creates a rank-6 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a byte constant */ - public Constant constant(float[][] data) { + public Constant constant(byte[][][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant of {@code byte} elements that is a copy of a given n-dimensional array. + * Creates a constant of {@code boolean} elements that is a copy of a given n-dimensional array. * * @param scope is a scope used to add the underlying operation. - * @param data an n-dimensional array of {@code byte} elements. - * @return a byte constant + * @param data an n-dimensional array of {@code boolean} elements. + * @return a boolean constant */ - public Constant constant(ByteNdArray data) { + public Constant constant(BooleanNdArray data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-2 constant of {@code byte} elements. + * Creates a rank-3 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ - public Constant constant(byte[][] data) { + public Constant constant(byte[][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code double} elements. + * Creates a rank-5 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(double[][][][][] data) { + public Constant constant(int[][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-3 constant of {@code float} elements. + * Creates a rank-1 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ - public Constant constant(float[][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(float[] data) { + return Constant.vectorOf(scope, data); } /** - * Creates a rank-1 constant of {@code byte} elements. + * Creates a constant of {@code byte} elements that is a copy of a given n-dimensional array. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data an n-dimensional array of {@code byte} elements. * @return a byte constant */ - public Constant constant(byte[] data) { - return Constant.vectorOf(scope, data); + public Constant constant(ByteNdArray data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-1 constant of {@code float} elements. + * Creates a rank-4 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(float[] data) { + public Constant constant(int[][][][] data) { + return Constant.tensorOf(scope, data); + } + + /** + * Creates a rank-1 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant + */ + public Constant constant(double[] data) { return Constant.vectorOf(scope, data); } /** - * Creates a rank-2 constant of {@code boolean} elements. + * Creates a rank-5 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a boolean constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a float constant */ - public Constant constant(boolean[][] data) { + public Constant constant(float[][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, - * using the default UTF-8 encoding. + * Creates a rank-3 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data an n-dimensional array of {@code String} elements. - * @return a string constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a float constant */ - public Constant constant(NdArray data) { + public Constant constant(float[][][] data) { return Constant.tensorOf(scope, data); } @@ -1482,46 +1494,46 @@ public Constant constant(String data) { } /** - * Creates a rank-4 constant of {@code double} elements. + * Creates a rank-2 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a boolean constant */ - public Constant constant(double[][][][] data) { + public Constant constant(boolean[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-2 constant of {@code double} elements. + * Creates a constant containing a single {@code int} element. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data The value to put into the new constant. + * @return an integer constant */ - public Constant constant(double[][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(int data) { + return Constant.scalarOf(scope, data); } /** - * Creates a constant containing a single {@code int} element. + * Creates a rank-1 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data The value to put into the new constant. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a long constant */ - public Constant constant(int data) { - return Constant.scalarOf(scope, data); + public Constant constant(long[] data) { + return Constant.vectorOf(scope, data); } /** * Creates a rank-4 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ public Constant constant(byte[][][][] data) { @@ -1529,14 +1541,14 @@ public Constant constant(byte[][][][] data) { } /** - * Creates a rank-6 constant of {@code int} elements. + * Creates a rank-6 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant */ - public Constant constant(int[][][][][][] data) { + public Constant constant(double[][][][][][] data) { return Constant.tensorOf(scope, data); } @@ -1563,169 +1575,157 @@ public Constant constant(float data) { } /** - * Creates a rank-5 constant of {@code float} elements. + * Creates a rank-1 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a byte constant */ - public Constant constant(float[][][][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(byte[] data) { + return Constant.vectorOf(scope, data); } /** - * Creates a rank-3 constant of {@code double} elements. + * Creates a rank-2 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(double[][][] data) { + public Constant constant(int[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code long} elements. + * Creates a rank-4 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a long constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a float constant */ - public Constant constant(long[][][][][][] data) { + public Constant constant(float[][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-4 constant of {@code long} elements. + * Creates a rank-6 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a long constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a boolean constant */ - public Constant constant(long[][][][] data) { + public Constant constant(boolean[][][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-1 constant of {@code long} elements. + * Creates a constant of {@code float} elements that is a copy of a given n-dimensional array. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a long constant + * @param data an n-dimensional array of {@code float} elements. + * @return a float constant */ - public Constant constant(long[] data) { - return Constant.vectorOf(scope, data); + public Constant constant(FloatNdArray data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-1 constant of {@code boolean} elements. + * Creates a rank-5 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ - public Constant constant(boolean[] data) { - return Constant.vectorOf(scope, data); - } - - /** - * Creates a rank-3 constant of {@code byte} elements. - * - * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a byte constant - */ - public Constant constant(byte[][][] data) { + public Constant constant(boolean[][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code byte} elements. + * Creates a rank-6 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a byte constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a long constant */ - public Constant constant(byte[][][][][][] data) { + public Constant constant(long[][][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-2 constant of {@code int} elements. + * Creates a rank-4 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a long constant */ - public Constant constant(int[][] data) { + public Constant constant(long[][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant of {@code float} elements that is a copy of a given n-dimensional array. + * Creates a rank-2 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data an n-dimensional array of {@code float} elements. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ - public Constant constant(FloatNdArray data) { + public Constant constant(float[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code int} elements. + * Creates a rank-5 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a long constant */ - public Constant constant(int[][][][][] data) { + public Constant constant(long[][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-1 constant of {@code double} elements. + * Creates a rank-6 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a float constant */ - public Constant constant(double[] data) { - return Constant.vectorOf(scope, data); + public Constant constant(float[][][][][][] data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code boolean} elements. + * Creates a rank-3 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a boolean constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(boolean[][][][][][] data) { + public Constant constant(int[][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code double} elements. + * Creates a rank-6 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(double[][][][][][] data) { + public Constant constant(int[][][][][][] data) { return Constant.tensorOf(scope, data); } @@ -1741,23 +1741,23 @@ public Constant constant(boolean data) { } /** - * Creates a rank-4 constant of {@code boolean} elements. + * Creates a rank-1 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ - public Constant constant(boolean[][][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(boolean[] data) { + return Constant.vectorOf(scope, data); } /** * Creates a rank-3 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ public Constant constant(long[][][] data) { @@ -1765,8 +1765,7 @@ public Constant constant(long[][][] data) { } /** - * Creates a rank-1 constant of {@code long} elements representing the size of each dimensions of - * the given shape. + * Creates a rank-1 constant of {@code long} elements representing the size of each dimensions of the given shape. * * @param scope is a scope used to add the underlying operation. * @param shape a shape @@ -1781,8 +1780,8 @@ public Constant constant(Shape shape) { * * @param scope is a scope used to add the underlying operation. * @param charset charset for encoding/decoding strings bytes. - * @param data An array containing the values to put into the new constant. String elements are - * sequences of bytes from the last array dimension. + * @param data An array containing the values to put into the new constant. String elements are sequences of bytes + * from the last array dimension. * @return the {@code String} constant */ public Constant constant(Charset charset, String[] data) { @@ -1802,8 +1801,8 @@ public Constant constant(Charset charset, String data) { } /** - * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, - * using the given encoding. + * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, using the given + * encoding. * * @param scope is a scope used to add the underlying operation. * @param charset charset used to encode/decode string bytes. @@ -1854,29 +1853,28 @@ public Constant constant(Shape shape, ByteDataBuffer data) { } /** - * Create a {@link TInt64} constant with data from the given buffer. + * Create a {@link TString} constant with data from the given buffer, using the default UTF-8 encoding. * * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. - * @return a long constant + * @return a string constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public Constant constant(Shape shape, LongDataBuffer data) { + public Constant constant(Shape shape, DataBuffer data) { return Constant.tensorOf(scope, shape, data); } /** - * Create a {@link TString} constant with data from the given buffer, using the default UTF-8 - * encoding. + * Create a {@link TInt64} constant with data from the given buffer. * * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. - * @return a string constant + * @return a long constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public Constant constant(Shape shape, DataBuffer data) { + public Constant constant(Shape shape, LongDataBuffer data) { return Constant.tensorOf(scope, shape, data); } @@ -1943,8 +1941,7 @@ public Constant constant(Charset charset, Shape shape, DataBuffer Constant constant(Class type, Shape shape, ByteDataBuffer data) { return Constant.tensorOf(scope, type, shape, data); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java index 0ffd6c2205e..eba6d755ce0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java @@ -84,8 +84,9 @@ public String toString() { * *

    This is only supported in an eager execution environment. * + * @param scope the {@link TensorScope} to create the tensor in * @param outputIdx index of the output of this operation * @return output tensor */ - abstract Tensor tensor(int outputIdx); + abstract Tensor tensor(TensorScope scope, int outputIdx); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 71dc0f7cefc..ea529d4c374 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -16,9 +16,9 @@ package org.tensorflow; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.ListIterator; -import java.util.HashMap; import java.util.Map; import java.util.function.Function; import org.tensorflow.op.Ops; @@ -43,12 +43,12 @@ public class ConcreteFunction implements AutoCloseable { * Creates a function by building a new graph. * *

    The {@code functionBuilder} must initialize the function graph from the provided - * {@link Ops} instance and return a valid signature that will be used to feed the input tensors - * and fetch the output tensors on execution. + * {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output + * tensors on execution. * *

    The function will be the owner of the new graph and its resulting session. Therefore, - * the function must be enclosed properly with a try-with-resources block to guarantee that - * all native resources will be freed once the function is discarded. For example: + * the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will + * be freed once the function is discarded. For example: * *

    {@code
        * public class MyModel {
    @@ -87,8 +87,8 @@ public static ConcreteFunction create(Function functionBuilder)
        * Create a function from a signature and an existing graph.
        *
        * 

    The function will keep the ownership of the session used to run the graph but not - * the graph itself, meaning that the lifetime of the latter can extend beyond the scope - * of the function. For example: + * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For + * example: * *

    {@code
        * try (Graph g = new Graph()) {
    @@ -116,8 +116,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
        * Create a function from a signature and a valid graph session.
        *
        * 

    The function will not own the session nor its graph, meaning that their lifetime - * can extend beyond the scope of the function. Therefore the function does not need to be - * closed after its usage. For example: + * can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For + * example: * *

    {@code
        * try (Graph g = new Graph()) {
    @@ -158,12 +158,11 @@ public Signature signature() {
        *
        * 

    Caller is responsible for closing all Tensors. * - * @param arguments list of tensors to pass in input to the function, - * mapped by their signature name - * @return output tensors resulting from the execution of the function, - * mapped by their signature name + * @param scope the {@link TensorScope} to create the outputs in + * @param arguments list of tensors to pass in input to the function, mapped by their signature name + * @return output tensors resulting from the execution of the function, mapped by their signature name */ - public Map call(Map arguments) + public Map call(TensorScope scope, Map arguments) throws IllegalArgumentException { final SignatureDef signatureDef = signature.asSignatureDef(); @@ -180,13 +179,13 @@ public Map call(Map arguments) Map outputToNode = signatureDef.getOutputsMap(); outputToNode.values().forEach(t -> runner.fetch(t.getName())); - List resultTensors = runner.run(); + List resultTensors = runner.run(scope); try { ListIterator resultTensorIter = resultTensors.listIterator(); Map returnMap = new HashMap(); // Use the output names as present in the signature definition - for (String nodeName: outputToNode.keySet()) { + for (String nodeName : outputToNode.keySet()) { returnMap.put(nodeName, resultTensorIter.next()); } return returnMap; @@ -205,27 +204,27 @@ public Map call(Map arguments) * *

    Caller is responsible for closing all Tensors. * + * @param scope the {@link TensorScope} to create the output in * @param tensor input tensor * @return output tensor - * @throws IllegalArgumentException if there are multiple input or output parameters defined - * in the function + * @throws IllegalArgumentException if there are multiple input or output parameters defined in the function */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { + public Tensor call(TensorScope scope, Tensor tensor) throws IllegalArgumentException { final SignatureDef signatureDef = signature.asSignatureDef(); if (signatureDef.getInputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); + String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); } String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName(); if (signatureDef.getOutputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); + String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); } String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName(); - return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); + return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run(scope).get(0); } /** @@ -245,8 +244,8 @@ public void save(String exportDir) throws IOException { * Returns the session used to execute the graph when calling this function * *

    In general, a user does not need to handle directly the session of a function and rely - * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to - * the session might be necessary, as it allows more running options. + * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to the session might be + * necessary, as it allows more running options. * * @return the function session */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 4e9394b7df0..407efe9bf32 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -30,14 +30,13 @@ import org.tensorflow.internal.c_api.TF_Tensor; import org.tensorflow.ndarray.Shape; import org.tensorflow.proto.framework.DataType; -import org.tensorflow.types.family.TType; /** * Implementation of an {@link Operation} executed eagerly. * *

    EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of - * is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the - * EagerOperation instance may fail with an {@code IllegalStateException}. + * is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the EagerOperation instance may + * fail with an {@code IllegalStateException}. * *

    EagerOperation instances are thread-safe. */ @@ -121,10 +120,10 @@ DataType dtype(int outputIndex) { } @Override - Tensor tensor(int outputIndex) { + Tensor tensor(TensorScope scope, int outputIndex) { Tensor tensor = outputTensors.get(outputIndex); if (tensor == null) { - tensor = resolveTensor(outputIndex); + tensor = resolveTensor(scope, outputIndex); } return tensor; } @@ -134,11 +133,11 @@ Tensor tensor(int outputIndex) { private final String name; private final AtomicReferenceArray outputTensors; - private Tensor resolveTensor(int outputIndex) { + private Tensor resolveTensor(TensorScope scope, int outputIndex) { // Take an optimistic approach, where we attempt to resolve the output tensor without locking. // If another thread has resolved it meanwhile, release our copy and reuse the existing one // instead. - Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session); + Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), scope); if (!outputTensors.compareAndSet(outputIndex, null, tensor)) { session.detach(tensor.asRawTensor().nativeHandle()); tensor = outputTensors.get(outputIndex); @@ -161,13 +160,13 @@ private static void requireTensorHandle(TFE_TensorHandle handle) { } } - private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) { + private static Tensor resolveTensorHandle(TFE_TensorHandle handle, TensorScope tensorScope) { requireTensorHandle(handle); try (PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(); status.throwExceptionIfNotOK(); - return RawTensor.fromHandle(tensor).asTypedTensor(); + return RawTensor.fromHandle(tensorScope, tensor).asTypedTensor(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java index fbad92160a2..220abc8e590 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java @@ -37,8 +37,8 @@ * Implementation for an {@link Operation} added as a node to a {@link Graph}. * *

    GraphOperation instances are valid only as long as the {@link Graph} they are a part of is - * valid. Thus, if {@link Graph#close()} has been invoked, then methods on the GraphOperation - * instance may fail with an {@code IllegalStateException}. + * valid. Thus, if {@link Graph#close()} has been invoked, then methods on the GraphOperation instance may fail with an + * {@code IllegalStateException}. * *

    GraphOperation instances are immutable and thread-safe. */ @@ -166,7 +166,7 @@ DataType dtype(int outputIdx) { } @Override - Tensor tensor(int outputIdx) { + Tensor tensor(TensorScope scope, int outputIdx) { throw new IllegalStateException("Graph tensors must be fetched by running a session"); } @@ -236,7 +236,9 @@ private static long[] shape(TF_Graph graphHandle, TF_Operation opHandle, int out TF_Status status = TF_Status.newStatus(); int numDims = TF_GraphGetTensorNumDims(graphHandle, output, status); status.throwExceptionIfNotOK(); - if (numDims < 0) return null; + if (numDims < 0) { + return null; + } long[] dims = new long[numDims]; TF_GraphGetTensorShape(graphHandle, output, dims, numDims, status); status.throwExceptionIfNotOK(); @@ -250,8 +252,8 @@ private static int dtype(TF_Graph graphHandle, TF_Operation opHandle, int output int numOutputs = TF_OperationNumOutputs(opHandle); if (outputIndex < 0 || outputIndex >= numOutputs) { - throw new IndexOutOfBoundsException("invalid output index (" + outputIndex - + ") for an operation that has " + numOutputs + " outputs"); + throw new IndexOutOfBoundsException("invalid output index (" + outputIndex + + ") for an operation that has " + numOutputs + " outputs"); } try (PointerScope scope = new PointerScope()) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java index 80f62eb5acc..5cdaae9e22e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java @@ -58,11 +58,12 @@ public interface Operand extends Op, Shaped { * * Only works when running in an eager execution * + * @param scope the {@link TensorScope} to create the tensor in * @return the tensor * @throws IllegalStateException if this is an operand of a graph */ - default T asTensor() { - return asOutput().asTensor(); + default T asTensor(TensorScope scope) { + return asOutput().asTensor(scope); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index 9e7dedfdc75..5e40850acd8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -33,31 +33,36 @@ */ public final class Output implements Operand { - /** Returns the index into the outputs of the Operation. */ + /** + * Returns the index into the outputs of the Operation. + */ public int index() { return index; } - /** Returns the DataType of the tensor referred to by this Output. */ + /** + * Returns the DataType of the tensor referred to by this Output. + */ @SuppressWarnings("unchecked") public DataType dataType() { return operation.dtype(index); } - /** Returns the type of the tensor referred to by this Output. */ + /** + * Returns the type of the tensor referred to by this Output. + */ @SuppressWarnings("unchecked") @Override public Class type() { - return (Class)TensorTypeRegistry.find(dataType()).type(); + return (Class) TensorTypeRegistry.find(dataType()).type(); } /** - * Returns this Output object with the type {@code Output}. This method is useful when given a - * value of type {@code Output}. + * Returns this Output object with the type {@code Output}. This method is useful when given a value of type {@code + * Output}. * * @param type any supported tensor type - * @throws IllegalArgumentException if the actual data type of this object does not match the type - * {@code U}. + * @throws IllegalArgumentException if the actual data type of this object does not match the type {@code U}. */ @SuppressWarnings("unchecked") public Output expect(Class type) { @@ -72,8 +77,7 @@ public Output expect(Class type) { * Returns the tensor at this output. * *

    This operation is only supported on the outputs of an operation executed eagerly. For graph - * environments, output tensors must be fetched by running a session, using {@link - * Session.Runner#fetch(Output)}. + * environments, output tensors must be fetched by running a session, using {@link Session.Runner#fetch(Output)}. * *

    It is recommended to close explicitly the returned tensor as soon as possible, since the * garbage collector is not aware of the amount of memory it consumes, which can be significant. @@ -84,8 +88,8 @@ public Output expect(Class type) { * @see EagerSession */ @SuppressWarnings("unchecked") - public T asTensor() { - return (T)operation.tensor(index); + public T asTensor(TensorScope scope) { + return (T) operation.tensor(scope, index); } /** @@ -130,7 +134,9 @@ public String toString() { operation.type(), operation.name(), index, shape().toString(), dataType()); } - /** Handle to the idx-th output of the Operation {@code op}. */ + /** + * Handle to the idx-th output of the Operation {@code op}. + */ Output(AbstractOperation op, int idx) { operation = op; index = idx; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index d239dffd6f7..c738aac3e4c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -108,6 +108,7 @@ public String toString() { * given type and shape. More memory can also be allocated to store also metadata within the tensor itself, e.g. a * lookup table in a string tensor. * + * @param tensorScope the {@link TensorScope} to create the tensor in * @param type tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor, or -1 to compute the size from the shape @@ -120,7 +121,7 @@ public String toString() { * unknown} * @throws IllegalStateException if tensor failed to be allocated */ - static RawTensor allocate(Class type, Shape shape, long size) { + static RawTensor allocate(TensorScope tensorScope, Class type, Shape shape, long size) { if (shape.hasUnknownDimension()) { throw new IllegalArgumentException( "Cannot allocate a tensor from a totally or partially unknown shape"); @@ -143,7 +144,7 @@ static RawTensor allocate(Class type, Shape shape, long size) { TF_Tensor nativeHandle = allocate(typeInfo.dataType().getNumber(), shape.asArray(), allocatedSize); try (PointerScope scope = new PointerScope()) { scope.attach(nativeHandle); - RawTensor t = new RawTensor(typeInfo, shape); + RawTensor t = new RawTensor(typeInfo, shape, tensorScope); t.tensorHandle = nativeHandle; t.pointerScope = scope.extend(); return t; @@ -155,9 +156,9 @@ static RawTensor allocate(Class type, Shape shape, long size) { * *

    Takes ownership of the handle. */ - static RawTensor fromHandle(TF_Tensor handle) { + static RawTensor fromHandle(TensorScope tensorScope, TF_Tensor handle) { TensorTypeInfo typeInfo = TensorTypeRegistry.find(DataType.forNumber(dtype(handle))); - RawTensor t = new RawTensor(typeInfo, Shape.of(shape(handle))); + RawTensor t = new RawTensor(typeInfo, Shape.of(shape(handle)), tensorScope); try (PointerScope scope = new PointerScope()) { scope.attach(handle); t.tensorHandle = handle; @@ -217,14 +218,15 @@ private static long[] shape(TF_Tensor handle) { return dims; } - RawTensor(TensorTypeInfo typeInfo, Shape shape) { + RawTensor(TensorTypeInfo typeInfo, Shape shape, TensorScope tensorScope) { this.typeInfo = typeInfo; this.shape = shape; - - TensorScope currentScope = TensorScope.currentScope(); - if (currentScope != null) { - this.tensorScope = currentScope.withTensors(this); + if (tensorScope == null) { + throw new NullPointerException("Can't create a tensor with a null TensorScope"); } + + tensorScope.attach(this); + this.tensorScope = tensorScope; } private PointerScope pointerScope; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 0974cc94a24..dc213a227e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -51,19 +51,22 @@ * SavedModelBundle represents a model loaded from storage. * *

    The model consists of a description of the computation (a {@link Graph}), a {@link Session} - * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, - * and a description of the model as a MetaGraphDef + * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, and a description + * of the model as a MetaGraphDef * protocol buffer. */ public class SavedModelBundle implements AutoCloseable { public static final String DEFAULT_TAG = "serve"; - /** Options for loading a SavedModel. */ + /** + * Options for loading a SavedModel. + */ public static final class Loader { - /** Load a SavedModelBundle with the configured options. */ + /** + * Load a SavedModelBundle with the configured options. + */ public SavedModelBundle load() { return SavedModelBundle.load(exportDir, tags, configProto, runOptions); } @@ -71,9 +74,8 @@ public SavedModelBundle load() { /** * Sets options to use when executing model initialization operations. * - * @param options A RunOptions - * protocol buffer. + * @param options A RunOptions + * protocol buffer. * @return this object */ public Loader withRunOptions(RunOptions options) { @@ -84,9 +86,8 @@ public Loader withRunOptions(RunOptions options) { /** * Set configuration of the Session object created when loading the model. * - * @param configProto A ConfigProto - * protocol buffer. + * @param configProto A ConfigProto + * protocol buffer. * @return this object */ public Loader withConfigProto(ConfigProto configProto) { @@ -114,12 +115,14 @@ private Loader(String exportDir) { } private String exportDir = null; - private String[] tags = { DEFAULT_TAG }; + private String[] tags = {DEFAULT_TAG}; private ConfigProto configProto = null; private RunOptions runOptions = null; } - /** Options for exporting a SavedModel. */ + /** + * Options for exporting a SavedModel. + */ public static final class Exporter { /** @@ -144,9 +147,9 @@ public Exporter withTags(String... tags) { * names to a graph) and a valid session to a graph to be saved in the model. * *

    Note:Eventually, TensorFlow for Java will support the export of functions objects like - * the Python API does but right now, only session-centric models are supported (i.e. models that - * has a single main graph and one or more signatures). These models are compatible with those - * exported by TensorFlow 1.x or by TensorFlow 2.x estimators. + * the Python API does but right now, only session-centric models are supported (i.e. models that has a single main + * graph and one or more signatures). These models are compatible with those exported by TensorFlow 1.x or by + * TensorFlow 2.x estimators. * *
    Therefore, all functions exported in a model should share the same session at the moment * or an exception will be thrown.
    @@ -154,8 +157,8 @@ public Exporter withTags(String... tags) { * @param function a function carrying a signature and a valid session to the graph to be saved * @return this object * @throws IllegalArgumentException if a function with the same name has already been added to the model - * @throws UnsupportedOperationException if this function does not share the same session with the other - * functions added to this model + * @throws UnsupportedOperationException if this function does not share the same session with the other functions + * added to this model */ public Exporter withFunction(ConcreteFunction function) { Signature signature = function.signature(); @@ -166,7 +169,8 @@ public Exporter withFunction(ConcreteFunction function) { if (session == null) { session = function.session(); } else if (session != function.session()) { - throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); + throw new UnsupportedOperationException( + "Saving multiple functions with different graphs/sessions is not supported yet."); } metaGraphDefBuilder.putSignatureDef(signature.key(), signature.asSignatureDef()); return this; @@ -213,16 +217,15 @@ public void export() throws IOException { } private final String exportDir; - private String[] tags = { DEFAULT_TAG }; + private String[] tags = {DEFAULT_TAG}; private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); private final Map functions = new LinkedHashMap<>(); private Session session; } /** - * Load a saved model from an export directory. The model that is being loaded should be created - * using the Saved Model - * API. + * Load a saved model from an export directory. The model that is being loaded should be created using the Saved Model API. * *

    This method is a shorthand for: * @@ -267,15 +270,16 @@ public static Exporter exporter(String exportDir) { } /** - * Returns the MetaGraphDef + * Returns the MetaGraphDef * protocol buffer associated with the saved model. */ public MetaGraphDef metaGraphDef() { return metaGraphDef; } - /** Returns the graph that describes the computation performed by the model. */ + /** + * Returns the graph that describes the computation performed by the model. + */ public Graph graph() { return graph; } @@ -306,8 +310,7 @@ public List signatures() { * * @param signatureKey name of the {@code SignatureDef} in the saved model. * @return object that can be used to make calls to a function - * @throws IllegalArgumentException if {@code signatureKey} is not found in this - * saved model. + * @throws IllegalArgumentException if {@code signatureKey} is not found in this saved model. */ public ConcreteFunction function(String signatureKey) { ConcreteFunction function = functions.get(signatureKey); @@ -330,11 +333,12 @@ public ConcreteFunction function(String signatureKey) { * *

    Caller is responsible for closing all returned Tensors. * + * @param scope the {@link TensorScope} to create the outputs in * @param arguments list of input tensors, mapped by their signature name * @return list of output tensors, mapped by the signature name * @throws IllegalArgumentException if no function can be selected by default */ - public Map call(Map arguments) { + public Map call(TensorScope scope, Map arguments) { ConcreteFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); @@ -344,12 +348,11 @@ public Map call(Map arguments) { if (function == null) { throw new IllegalArgumentException("Cannot elect a default function for this model"); } - return function.call(arguments); + return function.call(scope, arguments); } /** - * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model - * bundle. + * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model bundle. */ @Override public void close() { @@ -362,7 +365,8 @@ public void close() { private final MetaGraphDef metaGraphDef; private final Map functions; - private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) { + private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, + Map functions) { this.graph = graph; this.session = session; this.metaGraphDef = metaGraphDef; @@ -370,8 +374,8 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef } /** - * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session - * object, plus the MetaGraphDef. + * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session object, plus the + * MetaGraphDef. * *

    Invoked from the native load method. Takes ownership of the handles. */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index b67f4a611e6..7cf3f39e144 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -15,7 +15,16 @@ package org.tensorflow; +import static org.tensorflow.Graph.resolveOutputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; + import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; +import java.util.List; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -33,15 +42,9 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; - -import java.util.ArrayList; -import java.util.List; import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; -import static org.tensorflow.Graph.resolveOutputs; -import static org.tensorflow.internal.c_api.global.tensorflow.*; - /** * Driver for {@link Graph} execution. * @@ -84,11 +87,9 @@ public Session(Graph g) { * Construct a new session with the associated {@link Graph} and configuration options. * * @param g The {@link Graph} the created Session will operate on. - * @param config Configuration parameters for the session specified as a ConfigProto - * protocol buffer. - * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto - * protocol buffer. + * @param config Configuration parameters for the session specified as a ConfigProto + * protocol buffer. + * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto protocol buffer. */ public Session(Graph g, ConfigProto config) { graph = g; @@ -101,7 +102,9 @@ public Session(Graph g, ConfigProto config) { } } - /** Wrap an existing session with the associated {@link Graph}. */ + /** + * Wrap an existing session with the associated {@link Graph}. + */ Session(Graph g, TF_Session nativeHandle) { graph = g; this.nativeHandle = nativeHandle; @@ -111,7 +114,7 @@ public Session(Graph g, ConfigProto config) { /** * Release resources associated with the Session. * - *

    Blocks until there are no active executions ({@link Session.Runner#run()} calls). A Session + *

    Blocks until there are no active executions ({@link Session.Runner#run(TensorScope)} calls). A Session * is not usable after close returns. */ @Override @@ -139,22 +142,20 @@ public void close() { * Run {@link Operation}s and evaluate {@link Tensor Tensors}. * *

    A Runner runs the necessary graph fragments to execute every {@link Operation} required to - * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String,int,Tensor)} call allows - * callers to override the value of {@link Tensor Tensors} in the graph by substituting the - * provided {@link Tensor Tensors} for the outputs of the operations provided to {@link - * #feed(String,int,Tensor)}. + * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String, int, Tensor)} call allows callers to override + * the value of {@link Tensor Tensors} in the graph by substituting the provided {@link Tensor Tensors} for the + * outputs of the operations provided to {@link #feed(String, int, Tensor)}. */ public final class Runner { /** * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces. * - * @param operation Is either the string name of the operation, in which case this method is a - * shorthand for {@code feed(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * feed(operation_name, output_index)}. These colon-separated names are commonly used in the - * {@code SignatureDef} protocol buffer messages that are included in {@link - * SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code + * feed(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * feed(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} + * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. * @param t the tensor substituting the operation * @return this session runner */ @@ -163,8 +164,8 @@ public Runner feed(String operation, Tensor t) { } /** - * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} - * for the value it produces. + * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} for the value it + * produces. * *

    Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one {@code t} is being provided for. @@ -183,8 +184,7 @@ public Runner feed(String operation, int index, Tensor t) { } /** - * Use {@code t} instead of the Tensor referred to by executing the operation referred to by - * {@code operand}. + * Use {@code t} instead of the Tensor referred to by executing the operation referred to by {@code operand}. * * @param operand the node in the graph representing the operation to substitute * @param t the tensor substituting the operation @@ -197,14 +197,13 @@ public Runner feed(Operand operand, Tensor t) { } /** - * Make {@link #run()} return the output of {@code operation}. + * Make {@link #run(TensorScope)} return the output of {@code operation}. * - * @param operation Is either the string name of the operation, in which case this method is a - * shorthand for {@code fetch(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * fetch(operation_name, output_index)}. These colon-separated names are commonly used in - * the {@code SignatureDef} protocol buffer messages that are included in {@link - * SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code + * fetch(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * fetch(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} + * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. * @return this session runner */ public Runner fetch(String operation) { @@ -212,7 +211,7 @@ public Runner fetch(String operation) { } /** - * Make {@link #run()} return the {@code index}-th output of {@code operation}. + * Make {@link #run(TensorScope)} return the {@code index}-th output of {@code operation}. * *

    Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one to return. @@ -229,7 +228,7 @@ public Runner fetch(String operation, int index) { } /** - * Makes {@link #run()} return the Tensor referred to by {@code output}. + * Makes {@link #run(TensorScope)} return the Tensor referred to by {@code output}. * * @param output the node to fetch the tensor from * @return this session runner @@ -240,7 +239,7 @@ public Runner fetch(Output output) { } /** - * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. + * Makes {@link #run(TensorScope)} return the Tensor referred to by the output of {@code operand}. * * @param operand the node to fetch the tensor from, as an operand * @return this session runner @@ -250,8 +249,7 @@ public Runner fetch(Operand operand) { } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor - * Tensors}. + * Make {@link #run(TensorScope)} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. * * @param operation the string name of the operation to execute * @return this session runner @@ -265,8 +263,7 @@ public Runner addTarget(String operation) { } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor - * Tensors}. + * Make {@link #run(TensorScope)} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. * * @param operation the operation to execute * @return this session runner @@ -297,8 +294,7 @@ public Runner addTarget(Op op) { * Set options (typically for debugging) for this run. * *

    The options are presented as a RunOptions - * protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions protocol buffer. * * @param options a {@code RunOptions} proto * @return this session runner @@ -312,38 +308,37 @@ public Runner setOptions(RunOptions options) { * Execute the graph fragments necessary to compute all requested fetches. * *

    WARNING: The caller assumes ownership of all returned {@link Tensor Tensors}, i.e., - * the caller must call {@link Tensor#close} on all elements of the returned list to free up - * resources. + * the caller must call {@link Tensor#close} on all elements of the returned list to free up resources. * *

    TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it - * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in - * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a - * {@code Map}? + * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in SessionTest.java), and + * (b) Evaluate whether the return value should be a list, or maybe a {@code Map}? * *

    TODO(andrewmyers): It would also be good if whatever is returned here made it easier to * extract output tensors in a type-safe way. * + * @param scope the {@link TensorScope} to create outputs in. May be null if there are no outputs. * @return list of resulting tensors fetched by this session runner */ - public List run() { - return runHelper(false).outputs; + public List run(TensorScope scope) { + return runHelper(scope, false).outputs; } /** * Execute graph fragments to compute requested fetches and return metadata about the run. * - *

    This is exactly like {@link #run()}, but in addition to the requested Tensors, also - * returns metadata about the graph execution in the form of a RunMetadata + *

    This is exactly like {@link #run(TensorScope)}, but in addition to the requested Tensors, also + * returns metadata about the graph execution in the form of a RunMetadata * protocol buffer. * + * @param scope the {@link TensorScope} to create outputs in. May be null if there are no outputs. * @return list of resulting tensors fetched by this session runner, with execution metadata */ - public Run runAndFetchMetadata() { - return runHelper(true); + public Run runAndFetchMetadata(TensorScope scope) { + return runHelper(scope, true); } - private Run runHelper(boolean wantMetadata) { + private Run runHelper(TensorScope tensorScope, boolean wantMetadata) { TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()]; TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()]; int[] inputOpIndices = new int[inputs.size()]; @@ -351,6 +346,10 @@ private Run runHelper(boolean wantMetadata) { int[] outputOpIndices = new int[outputs.size()]; TF_Operation[] targetOpHandles = new TF_Operation[targets.size()]; + if (outputs.size() > 0 && tensorScope == null) { + throw new NullPointerException("tensorScope must not be null when outputs are requested"); + } + // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the // validity of the Graph and graphRef ensures that. int idx = 0; @@ -388,7 +387,8 @@ private Run runHelper(boolean wantMetadata) { outputOpIndices, targetOpHandles, wantMetadata, - outputs); + outputs, + tensorScope); } catch (Exception e) { for (Tensor t : outputs) { t.close(); @@ -405,6 +405,7 @@ private Run runHelper(boolean wantMetadata) { } private class Reference implements AutoCloseable { + public Reference() { synchronized (nativeHandleLock) { if (nativeHandle == null || nativeHandle.isNull()) { @@ -457,7 +458,9 @@ private Output parseOutput(String opName) { private RunOptions runOptions = null; } - /** Create a Runner to execute graph operations and evaluate Tensors. */ + /** + * Create a Runner to execute graph operations and evaluate Tensors. + */ public Runner runner() { return new Runner(); } @@ -476,7 +479,7 @@ public void run(String opName) { throw new IllegalArgumentException( "Operation named '" + opName + "' cannot be found in the graph"); } - runner().addTarget(operation).run(); + runner().addTarget(operation).run(null); } /** @@ -487,7 +490,7 @@ public void run(String opName) { * @param op the operation to run. */ public void run(Op op) { - runner().addTarget(op.op()).run(); + runner().addTarget(op.op()).run(null); } @@ -495,12 +498,11 @@ public void run(Op op) { * Execute the graph's initializers. * *

    This method is equivalent to {@code session.run(Ops.create(session.graph).init())}. - * */ - public void runInit(){ + public void runInit() { Runner runner = runner(); graph.initializers().forEach(runner::addTarget); - runner.run(); + runner.run(null); } /** @@ -518,14 +520,16 @@ public void runInit(){ */ public void save(String prefix) { SaverDef saverDef = graph.saverDef(); - runner().addTarget(saverDef.getSaveTensorName()) - .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) - .run(); + try (TensorScope scope = new TensorScope()) { + runner().addTarget(saverDef.getSaveTensorName()) + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(scope, prefix)) + .run(scope); + } } /** * Restore the actual state of the variables of this session's graph. - * + * *

    {@code prefix} is the path where the files containing the variables state live, * followed by the filename prefix. For example, if {@code prefix} is set to * mymodel/myvariables/variables, then the files are loaded from @@ -538,26 +542,30 @@ public void save(String prefix) { */ public void restore(String prefix) { SaverDef saverDef = graph.saverDef(); - runner().addTarget(saverDef.getRestoreOpName()) - .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) - .run(); + try (TensorScope scope = new TensorScope()) { + runner().addTarget(saverDef.getRestoreOpName()) + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(scope, prefix)) + .run(null); + } } /** * Output tensors and metadata obtained when executing a session. * - *

    See {@link Runner#runAndFetchMetadata()} + *

    See {@link Runner#runAndFetchMetadata(TensorScope)} */ public static final class Run { - /** Tensors from requested fetches. */ + + /** + * Tensors from requested fetches. + */ public List outputs; /** * Metadata about the run. * *

    A RunMetadata - * protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata protocol buffer. */ public RunMetadata metadata; } @@ -633,20 +641,20 @@ private static void delete(TF_Session handle) { * @param runOptions A RunOptions protocol buffer, or null * @param inputOpHandles (see inputOpIndices) * @param inputOpIndices (see inputTensorHandles) - * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values - * that are being "fed" (do not need to be computed) during graph execution. - * inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the - * inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that - * inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. + * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed" + * (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a + * Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, + * it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. * @param outputOpHandles (see outputOpIndices) - * @param outputOpIndices together with outputOpHandles identifies the set of values that should - * be computed. The outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is - * required that outputOpHandles.length == outputOpIndices.length. - * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose - * output will not be returned + * @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The + * outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length == + * outputOpIndices.length. + * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose output will not be + * returned * @param wantRunMetadata indicates whether metadata about this execution should be returned. - * @param outputTensors will be filled in with tensors to the outputs requested. It is required - * that outputs.length == outputOpHandles.length. + * @param outputTensors will be filled in with tensors to the outputs requested. It is required that outputs.length == + * outputOpHandles.length. + * @param tensorScope the {@link TensorScope} to create tensors in * @return if wantRunMetadata is true, a RunMetadata protocol buffer, false otherwise. */ private static RunMetadata run( @@ -659,7 +667,8 @@ private static RunMetadata run( int[] outputOpIndices, TF_Operation[] targetOpHandles, boolean wantRunMetadata, - List outputTensors) { + List outputTensors, + TensorScope tensorScope) { requireHandle(handle); int ninputs = inputTensorHandles.length; @@ -698,8 +707,8 @@ private static RunMetadata run( status.throwExceptionIfNotOK(); for (int i = 0; i < noutputs; ++i) { - TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(); - outputTensors.add(RawTensor.fromHandle(h).asTypedTensor()); + TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(false); + outputTensors.add(RawTensor.fromHandle(tensorScope, h).asTypedTensor()); } try { return runMetadata != null ? RunMetadata.parseFrom(runMetadata.dataAsByteBuffer()) : null; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 2910349aa7a..4c649079d2d 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -51,6 +51,7 @@ public interface Tensor extends Shaped, AutoCloseable { * and is left uninitialized. * * @param the tensor type + * @param scope the {@link TensorScope} to create the tensor in * @param type the tensor type class * @param shape shape of the tensor * @return an allocated but uninitialized tensor @@ -59,18 +60,19 @@ public interface Tensor extends Shaped, AutoCloseable { * unknown} * @throws IllegalStateException if tensor failed to be allocated */ - static T of(Class type, Shape shape) { - return of(type, shape, -1); + static T of(TensorScope scope, Class type, Shape shape) { + return of(scope, type, shape, -1); } /** * Allocates a tensor of a given datatype, shape and size. * - *

    This method is identical to {@link #of(Class, Shape)}, except that the final size of the + *

    This method is identical to {@link #of(TensorScope, Class, Shape)}, except that the final size of the * tensor can be explicitly set instead of computing it from the datatype and shape, which could be larger than the * actual space required to store the data but not smaller. * * @param the tensor type + * @param scope the {@link TensorScope} to create the tensor in * @param type the tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor or -1 to compute the size from the shape @@ -82,10 +84,10 @@ static T of(Class type, Shape shape) { * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension() * unknown} * @throws IllegalStateException if tensor failed to be allocated - * @see #of(Class, Shape) + * @see #of(TensorScope, Class, Shape) */ - static T of(Class type, Shape shape, long size) { - RawTensor tensor = RawTensor.allocate(type, shape, size); + static T of(TensorScope scope, Class type, Shape shape, long size) { + RawTensor tensor = RawTensor.allocate(scope, type, shape, size); try { return (T) tensor.asTypedTensor(); } catch (Exception e) { @@ -112,6 +114,7 @@ static T of(Class type, Shape shape, long size) { * automatically released before rethrowing the same exception. * * @param the tensor type + * @param scope the {@link TensorScope} to create the tensor in * @param type the tensor type class * @param shape shape of the tensor * @param dataInitializer method receiving accessor to the allocated tensor data for initialization @@ -121,20 +124,21 @@ static T of(Class type, Shape shape, long size) { * unknown} * @throws IllegalStateException if tensor failed to be allocated */ - static T of(Class type, Shape shape, Consumer dataInitializer) { - return of(type, shape, -1, dataInitializer); + static T of(TensorScope scope, Class type, Shape shape, Consumer dataInitializer) { + return of(scope, type, shape, -1, dataInitializer); } /** * Allocates a tensor of a given datatype, shape and size. * - *

    This method is identical to {@link #of(Class, Shape, Consumer)}, except that the final + *

    This method is identical to {@link #of(TensorScope, Class, Shape, Consumer)}, except that the final * size for the tensor can be explicitly set instead of being computed from the datatype and shape. * *

    This could be useful for tensor types that stores data but also metadata in the tensor memory, * such as the lookup table in a tensor of strings. * * @param the tensor type + * @param scope the {@link TensorScope} to create the tensor in * @param type the tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor or -1 to compute the size from the shape @@ -147,10 +151,10 @@ static T of(Class type, Shape shape, Consumer dataInitia * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension() * unknown} * @throws IllegalStateException if tensor failed to be allocated - * @see #of(Class, Shape, long, Consumer) + * @see #of(TensorScope, Class, Shape, long, Consumer) */ - static T of(Class type, Shape shape, long size, Consumer dataInitializer) { - T tensor = of(type, shape, size); + static T of(TensorScope scope, Class type, Shape shape, long size, Consumer dataInitializer) { + T tensor = of(scope, type, shape, size); try { dataInitializer.accept(tensor); return tensor; @@ -167,6 +171,7 @@ static T of(Class type, Shape shape, long size, Consumer * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API. * * @param the tensor type + * @param scope the {@link TensorScope} to create the tensor in * @param type the tensor type class * @param shape the tensor shape. * @param rawData a buffer containing the tensor raw data. @@ -175,8 +180,8 @@ static T of(Class type, Shape shape, long size, Consumer * unknown} * @throws IllegalStateException if tensor failed to be allocated with the given parameters */ - static T of(Class type, Shape shape, ByteDataBuffer rawData) { - return of(type, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size())); + static T of(TensorScope scope, Class type, Shape shape, ByteDataBuffer rawData) { + return of(scope, type, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size())); } /** @@ -218,8 +223,6 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData /** * Detach this tensor from any scopes managing it. It must be manually closed or attached to another scope. - * - *

    Semantically, this makes the tensor everyone's responsibility: whoever uses it last needs to close it. */ default void detach() { TensorScope.detach(this); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 4b9f311e979..294e025e96a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -16,11 +16,8 @@ */ package org.tensorflow; +import java.util.HashSet; import java.util.Set; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Consumer; -import java.util.function.Function; -import java.util.function.Supplier; /** @@ -35,155 +32,17 @@ * {@link Tensor#detach()} detaches the tensor from it's scope, requiring the user to close it manually or attach it to * another scope. *

    - * Scopes will be inherited at thread creation, but further scope creation on different threads will be independent, - * other than having the same parent. Closing a scope will close it's children regardless of which threads they are on. + * Like Tensors, TensorScope is not thread safe. */ -public class TensorScope implements AutoCloseable { - - private static final InheritableThreadLocal currentScope = new InheritableThreadLocal<>(); - - public static TensorScope currentScope() { - TensorScope scope = currentScope.get(); - - if (scope == null || !scope.closed) { - return scope; - } - - // scope could be closed in another thread, in which case this thread's currentScope wouldn't be updated - while (scope != null && scope.closed) { - scope = scope.parent; - } - currentScope.set(scope); - return scope; - } - - /** - * Runs {@code block}, then closes any tensors created during its execution. - *

    To release tensors, use {@link #withCleanup(Consumer)} or one of the {@code produceWithCleanup} methods. - */ - public static void withCleanup(Runnable block) { - TensorScope.withCleanup((scope) -> block.run()); - } - - /** - * Runs {@code block}, then closes any tensors created during its execution (or attached to the scope). - *

    Tensors can be released using the passed scope. - */ - public static void withCleanup(Consumer block) { - try (TensorScope scope = new TensorScope()) { - block.accept(scope); - } - } - - /** - * Runs {@code block} and returns the result, then closes any tensors created during its execution. - *

    To release tensors, use {@link #getWithCleanup(Function)} or one of the {@code produceWithCleanup} methods. - *

    Does not release or detach the result. If you return a tensor, it will be closed unless otherwise released. - */ - public static T getWithCleanup(Supplier block) { - return TensorScope.getWithCleanup((scope) -> block.get()); - } - - /** - * Runs {@code block} and returns the result, then closes any tensors created during its execution (or attached to the - * scope). - *

    Tensors can be released using the passed scope. - *

    Does not release or detach the result. If you return a tensor, it will be closed unless otherwise released. - */ - public static T getWithCleanup(Function block) { - try (TensorScope scope = new TensorScope()) { - return block.apply(scope); - } - } - - /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its - * execution. - *

    To release other tensors, use {@link #produceTensorWithCleanup(Function)}. - * - * @return the released result of {@code block} - */ - public static T produceTensorWithCleanup(Supplier block) { - return produceTensorWithCleanup((scope) -> block.get()); - } - - /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its - * execution (or attached to the scope). - *

    Tensors can be released using the passed scope. - * - * @return the released result of {@code block} - */ - public static T produceTensorWithCleanup(Function block) { - try (TensorScope scope = new TensorScope()) { - return scope.release(block.apply(scope)); - } - } +public final class TensorScope implements AutoCloseable { - /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its - * execution. - *

    To release other tensors, use {@link #produceTensorWithCleanup(Function)}. - * - * @return the released result of {@code block} - */ - public static T produceTensorContainerWithCleanup(Supplier block) { - return produceTensorContainerWithCleanup((scope) -> block.get()); - } - - /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its - * execution (or attached to the scope). - *

    Tensors can be released using the passed scope. - * - * @return the released result of {@code block} - */ - public static T produceTensorContainerWithCleanup(Function block) { - try (TensorScope scope = new TensorScope()) { - return scope.release(block.apply(scope)); - } - } - - - /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its - * execution. - *

    To release other tensors, use {@link #produceTensorWithCleanup(Function)}. - * - * @return the released result of {@code block} - */ - public static > T produceTensorsWithCleanup(Supplier block) { - return TensorScope.produceTensorsWithCleanup((scope) -> block.get()); - } - - /** - * Runs {@code block} and releases and returns the result, then closes any other tensors created during its - * execution (or attached to the scope). - *

    Tensors can be released using the passed scope. - * - * @return the released result of {@code block} - */ - public static > T produceTensorsWithCleanup(Function block) { - try (TensorScope scope = new TensorScope()) { - return scope.release(block.apply(scope)); - } - } - /** * Create a new tensor scope. If {@code autoAttach} is false, will not automatically manage tensors. * * @see TensorScope */ - TensorScope() { - this.parent = currentScope(); - currentScope.set(this); - - if (this.parent != null) { - synchronized (this.parent) { - this.parent.children.add(this); - } - } + public TensorScope() { } /** @@ -194,82 +53,24 @@ public synchronized void close() { if (closed) { return; } - - children.forEach(TensorScope::close); tensors.forEach(Tensor::close); closed = true; - - if (parent != null) { - parent.children.remove(this); - } - - if (currentScope() == this) { - currentScope.set(this.parent); - } - } - - /** - * Release the tensors and child scopes of this scope to it's parent, without closing them. - *

    - * Semantically, calling this method makes all of the resources in this scope the parent's responsibility, as if this - * scope had never existed. - *

    - * This will close this scope, but does not close any of it's resources. - * - * @throws IllegalStateException if this scope has no parent. If this happens, * the scope is not closed and no - * resources are released. - */ - public synchronized void releaseAllToParent() { - release(true); } /** - * Release the tensors and child scopes of this scope to it's parent, or detach them if this scope has no parent. + * Detach all of this scope's tensors, then close the scope. *

    - * Semantically, calling this method makes all of the resources in this scope the parent's responsibility, as if this - * scope had never existed. It can be used in a method to transfer control to the caller, leaving how the resources - * are managed up to the caller. - *

    - * This will close this scope, but does not close any of it's resources. - */ - public synchronized void releaseAll() { - release(false); - } - - /** - * Release the tensors and child scopes of this scope without closing them, to it's parent if it has one. - * - *

    WARNING: this method may release resources without assigning them to another scope if - * {@code requireParent} is false. {@link #releaseAllToParent()} should be used instead wherever possible. + * EXTREMELY DANGEROUS: this will close this scope, but does not close any of it's resources. * - * @param requireParent Whether to require a parent scope to release resources to. - * @throws IllegalStateException if this scope has no parent, but {@code requireParent} is true. If this happens, the - * scope is not closed and no resources are released. + * @return All of this scope's now-detached tensors */ - private synchronized void release(boolean requireParent) { - if (closed) { - return; - } - - if (this.parent == null && requireParent) { - throw new IllegalStateException("Can't release to parent: scope does not have parent."); - } - - if (this.parent != null) { - TensorScope newParent = this.parent; - newParent.children.addAll(children); - children.forEach(x -> x.parent = newParent); - tensors.forEach(newParent::attach); - } else { - children.forEach(x -> x.parent = null); - tensors.forEach(TensorScope::detach); - } - - children.clear(); + public synchronized Set detachAll() { + Set detachedTensors = new HashSet<>(this.tensors); + detachedTensors.forEach(TensorScope::detach); + closed = true; tensors.clear(); - - close(); + return detachedTensors; } public static T detach(T tensor) { @@ -419,101 +220,12 @@ public TensorScope withTensors(TensorContainer... tensors) { /** * @see #attach(Tensor) */ - public TensorScope withTensors(Iterable... tensors) { + @SafeVarargs + public final TensorScope withTensors(Iterable... tensors) { attach(tensors); return this; } - /** - * Attach this tensor to the parent of this scope, removing it from its current scope, or detach it if there is no - * current scope or the current scope does not have a parent. - * - *

    Semantically, this makes the tensor's resources this scope's parent's responsibility. - * - * @param requireParent whether to require a parent scope to release resources to. - * @throws IllegalStateException if there is no current scope or the current scope does not have a parent, but {@code - * requireParent} is true. If this happens, the tensor's scope is not changed. - */ - private T release(T tensor, boolean requireParent) { - if (parent == null && requireParent) { - throw new IllegalStateException( - "Can't release to parent: not in a current scope, or the current scope does not have a parent."); - } - - detach(tensor); - if (parent != null) { - parent.attach(tensor); - } - return tensor; - } - - - /** - * Attach this tensor to the parent of this scope, removing it from its current scope, or detach it if there is no - * current scope or the current scope does not have a parent. - * - *

    Semantically, this makes the tensor's resources this scope's parent's responsibility. - */ - public T release(T tensor) { - return release(tensor, false); - } - - /** - * @see #release(Tensor) - */ - public void release(Tensor... tensors) { - for (Tensor t : tensors) { - release(t); - } - } - - /** - * @see #release(Tensor) - */ - public T release(T tensors) { - release(tensors.tensors()); - return tensors; - } - - /** - * @see #release(Tensor) - */ - public void release(TensorContainer... tensors) { - for (TensorContainer ht : tensors) { - release(ht); - } - } - - /** - * @see #release(Tensor) - */ - public > T release(T tensors) { - tensors.forEach(this::release); - return tensors; - } - - /** - * @see #release(Tensor) - */ - @SafeVarargs - public final void release(Iterable... tensors) { - for (Iterable iterable : tensors) { - release(iterable); - } - } - - /** - * Attach this tensor to the parent of this scope, removing it from its current scope. - * - *

    Semantically, this makes the tensor's resources this scope's parent's responsibility. - * - * @throws IllegalStateException if there is no current scope or the current scope does not have a parent, but {@code - * requireParent} is true. If this happens, the tensor's scope is not changed. - */ - public T releaseToParent(T tensor) { - return release(tensor, true); - } - /** * Gets whether the scope is closed. */ @@ -522,7 +234,5 @@ public synchronized boolean isClosed() { } private boolean closed = false; - private final Set tensors = ConcurrentHashMap.newKeySet(); - private TensorScope parent; - private final Set children = ConcurrentHashMap.newKeySet(); + private final Set tensors = new HashSet<>(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 497ee5f2d46..21f5794186b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -20,6 +20,7 @@ import org.tensorflow.Operation; import org.tensorflow.Output; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.DoubleNdArray; @@ -37,13 +38,13 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; -import org.tensorflow.op.Ops; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; -import org.tensorflow.types.TBfloat16; +import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; +import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TFloat16; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -80,7 +81,8 @@ public final class Constant extends RawOp implements Operand */ @Endpoint public static Constant scalarOf(Scope scope, int data) { - try (TInt32 value = TInt32.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -89,13 +91,14 @@ public static Constant scalarOf(Scope scope, int data) { * Creates a rank-1 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant vectorOf(Scope scope, int[] data) { - try (TInt32 value = TInt32.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -119,14 +122,15 @@ public static Constant arrayOf(Scope scope, int... data) { * Creates a rank-2 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -135,14 +139,15 @@ public static Constant tensorOf(Scope scope, int[][] data) { * Creates a rank-3 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -151,14 +156,15 @@ public static Constant tensorOf(Scope scope, int[][][] data) { * Creates a rank-4 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -167,14 +173,15 @@ public static Constant tensorOf(Scope scope, int[][][][] data) { * Creates a rank-5 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -183,14 +190,15 @@ public static Constant tensorOf(Scope scope, int[][][][][] data) { * Creates a rank-6 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][][][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -207,7 +215,8 @@ public static Constant tensorOf(Scope scope, IntNdArray data) { if (data instanceof TInt32) { return create(scope, (TInt32) data); } - try (TInt32 value = TInt32.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -223,7 +232,8 @@ public static Constant tensorOf(Scope scope, IntNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, IntDataBuffer data) { - try (TInt32 value = TInt32.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -237,7 +247,8 @@ public static Constant tensorOf(Scope scope, Shape shape, IntDataBuffer */ @Endpoint public static Constant scalarOf(Scope scope, float data) { - try (TFloat32 value = TFloat32.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -246,13 +257,14 @@ public static Constant scalarOf(Scope scope, float data) { * Creates a rank-1 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant vectorOf(Scope scope, float[] data) { - try (TFloat32 value = TFloat32.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -276,14 +288,15 @@ public static Constant arrayOf(Scope scope, float... data) { * Creates a rank-2 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -292,14 +305,15 @@ public static Constant tensorOf(Scope scope, float[][] data) { * Creates a rank-3 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -308,14 +322,15 @@ public static Constant tensorOf(Scope scope, float[][][] data) { * Creates a rank-4 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -324,14 +339,15 @@ public static Constant tensorOf(Scope scope, float[][][][] data) { * Creates a rank-5 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -340,14 +356,15 @@ public static Constant tensorOf(Scope scope, float[][][][][] data) { * Creates a rank-6 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][][][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -364,7 +381,8 @@ public static Constant tensorOf(Scope scope, FloatNdArray data) { if (data instanceof TFloat32) { return create(scope, (TFloat32) data); } - try (TFloat32 value = TFloat32.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -380,7 +398,8 @@ public static Constant tensorOf(Scope scope, FloatNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, FloatDataBuffer data) { - try (TFloat32 value = TFloat32.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -394,7 +413,8 @@ public static Constant tensorOf(Scope scope, Shape shape, FloatDataBuf */ @Endpoint public static Constant scalarOf(Scope scope, double data) { - try (TFloat64 value = TFloat64.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -403,13 +423,14 @@ public static Constant scalarOf(Scope scope, double data) { * Creates a rank-1 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant vectorOf(Scope scope, double[] data) { - try (TFloat64 value = TFloat64.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -433,14 +454,15 @@ public static Constant arrayOf(Scope scope, double... data) { * Creates a rank-2 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -449,14 +471,15 @@ public static Constant tensorOf(Scope scope, double[][] data) { * Creates a rank-3 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -465,14 +488,15 @@ public static Constant tensorOf(Scope scope, double[][][] data) { * Creates a rank-4 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -481,14 +505,15 @@ public static Constant tensorOf(Scope scope, double[][][][] data) { * Creates a rank-5 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -497,14 +522,15 @@ public static Constant tensorOf(Scope scope, double[][][][][] data) { * Creates a rank-6 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][][][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -521,7 +547,8 @@ public static Constant tensorOf(Scope scope, DoubleNdArray data) { if (data instanceof TFloat64) { return create(scope, (TFloat64) data); } - try (TFloat64 value = TFloat64.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -537,7 +564,8 @@ public static Constant tensorOf(Scope scope, DoubleNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, DoubleDataBuffer data) { - try (TFloat64 value = TFloat64.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -551,7 +579,8 @@ public static Constant tensorOf(Scope scope, Shape shape, DoubleDataBu */ @Endpoint public static Constant scalarOf(Scope scope, long data) { - try (TInt64 value = TInt64.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -560,13 +589,14 @@ public static Constant scalarOf(Scope scope, long data) { * Creates a rank-1 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant vectorOf(Scope scope, long[] data) { - try (TInt64 value = TInt64.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -575,14 +605,15 @@ public static Constant vectorOf(Scope scope, long[] data) { * Creates a rank-2 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -606,14 +637,15 @@ public static Constant arrayOf(Scope scope, long... data) { * Creates a rank-3 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -622,14 +654,15 @@ public static Constant tensorOf(Scope scope, long[][][] data) { * Creates a rank-4 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -638,14 +671,15 @@ public static Constant tensorOf(Scope scope, long[][][][] data) { * Creates a rank-5 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -654,14 +688,15 @@ public static Constant tensorOf(Scope scope, long[][][][][] data) { * Creates a rank-6 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][][][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -678,7 +713,8 @@ public static Constant tensorOf(Scope scope, LongNdArray data) { if (data instanceof TInt64) { return create(scope, (TInt64) data); } - try (TInt64 value = TInt64.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -694,7 +730,8 @@ public static Constant tensorOf(Scope scope, LongNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, LongDataBuffer data) { - try (TInt64 value = TInt64.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -708,7 +745,8 @@ public static Constant tensorOf(Scope scope, Shape shape, LongDataBuffer */ @Endpoint public static Constant scalarOf(Scope scope, boolean data) { - try (TBool value = TBool.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -717,13 +755,14 @@ public static Constant scalarOf(Scope scope, boolean data) { * Creates a rank-1 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant vectorOf(Scope scope, boolean[] data) { - try (TBool value = TBool.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -747,14 +786,15 @@ public static Constant arrayOf(Scope scope, boolean... data) { * Creates a rank-2 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -763,14 +803,15 @@ public static Constant tensorOf(Scope scope, boolean[][] data) { * Creates a rank-3 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -779,14 +820,15 @@ public static Constant tensorOf(Scope scope, boolean[][][] data) { * Creates a rank-4 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -795,14 +837,15 @@ public static Constant tensorOf(Scope scope, boolean[][][][] data) { * Creates a rank-5 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -811,14 +854,15 @@ public static Constant tensorOf(Scope scope, boolean[][][][][] data) { * Creates a rank-6 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][][][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -835,7 +879,8 @@ public static Constant tensorOf(Scope scope, BooleanNdArray data) { if (data instanceof TBool) { return create(scope, (TBool) data); } - try (TBool value = TBool.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -851,7 +896,8 @@ public static Constant tensorOf(Scope scope, BooleanNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, BooleanDataBuffer data) { - try (TBool value = TBool.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -865,7 +911,8 @@ public static Constant tensorOf(Scope scope, Shape shape, BooleanDataBuff */ @Endpoint public static Constant scalarOf(Scope scope, byte data) { - try (TUint8 value = TUint8.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -874,13 +921,14 @@ public static Constant scalarOf(Scope scope, byte data) { * Creates a rank-1 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant vectorOf(Scope scope, byte[] data) { - try (TUint8 value = TUint8.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -904,14 +952,15 @@ public static Constant arrayOf(Scope scope, byte... data) { * Creates a rank-2 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -920,14 +969,15 @@ public static Constant tensorOf(Scope scope, byte[][] data) { * Creates a rank-3 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -936,14 +986,15 @@ public static Constant tensorOf(Scope scope, byte[][][] data) { * Creates a rank-4 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -952,14 +1003,15 @@ public static Constant tensorOf(Scope scope, byte[][][][] data) { * Creates a rank-5 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -968,14 +1020,15 @@ public static Constant tensorOf(Scope scope, byte[][][][][] data) { * Creates a rank-6 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][][][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -992,7 +1045,8 @@ public static Constant tensorOf(Scope scope, ByteNdArray data) { if (data instanceof TUint8) { return create(scope, (TUint8) data); } - try (TUint8 value = TUint8.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -1008,7 +1062,8 @@ public static Constant tensorOf(Scope scope, ByteNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer data) { - try (TUint8 value = TUint8.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -1022,13 +1077,13 @@ public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @return a constant of type `type` - * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the - * buffer + * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the buffer */ @Endpoint public static Constant tensorOf(Scope scope, Class type, Shape shape, ByteDataBuffer data) { - try (T value = Tensor.of(type, shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + T value = Tensor.of(tensorScope, type, shape, data)) { return create(scope, value); } } @@ -1042,7 +1097,8 @@ public static Constant tensorOf(Scope scope, Class type, */ @Endpoint public static Constant scalarOf(Scope scope, String data) { - try (TString value = TString.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -1057,7 +1113,8 @@ public static Constant scalarOf(Scope scope, String data) { */ @Endpoint public static Constant scalarOf(Scope scope, Charset charset, String data) { - try (TString value = TString.tensorOf(charset, NdArrays.scalarOfObject(data))) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, charset, NdArrays.scalarOfObject(data))) { return create(scope, value); } } @@ -1071,7 +1128,8 @@ public static Constant scalarOf(Scope scope, Charset charset, String da */ public static Constant vectorOf(Scope scope, String[] data) { NdArray src = NdArrays.vectorOfObjects(data); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1081,13 +1139,14 @@ public static Constant vectorOf(Scope scope, String[] data) { * * @param scope is a scope used to add the underlying operation. * @param charset charset for encoding/decoding strings bytes. - * @param data An array containing the values to put into the new constant. String elements are - * sequences of bytes from the last array dimension. + * @param data An array containing the values to put into the new constant. String elements are sequences of bytes + * from the last array dimension. * @return the {@code String} constant */ @Endpoint public static Constant vectorOf(Scope scope, Charset charset, String[] data) { - try (TString value = TString.tensorOf(charset, NdArrays.vectorOfObjects(data))) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, charset, NdArrays.vectorOfObjects(data))) { return Constant.create(scope, value); } } @@ -1112,8 +1171,8 @@ public static Constant arrayOf(Scope scope, String... data) { * * @param scope is a scope used to add the underlying operation. * @param charset charset for encoding/decoding strings bytes. - * @param data An array containing the values to put into the new constant. String elements are - * sequences of bytes from the last array dimension. + * @param data An array containing the values to put into the new constant. String elements are sequences of bytes + * from the last array dimension. * @return the {@code String} constant */ @Endpoint(name = "array") @@ -1134,7 +1193,8 @@ public static Constant arrayOf(Scope scope, Charset charset, String... public static Constant tensorOf(Scope scope, String[][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1149,7 +1209,8 @@ public static Constant tensorOf(Scope scope, String[][] data) { public static Constant tensorOf(Scope scope, String[][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1164,7 +1225,8 @@ public static Constant tensorOf(Scope scope, String[][][] data) { public static Constant tensorOf(Scope scope, String[][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1179,7 +1241,8 @@ public static Constant tensorOf(Scope scope, String[][][][] data) { public static Constant tensorOf(Scope scope, String[][][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1194,14 +1257,15 @@ public static Constant tensorOf(Scope scope, String[][][][][] data) { public static Constant tensorOf(Scope scope, String[][][][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } /** - * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, - * using the default UTF-8 encoding. + * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, using the default + * UTF-8 encoding. * * @param scope is a scope used to add the underlying operation. * @param data an n-dimensional array of {@code String} elements. @@ -1212,14 +1276,15 @@ public static Constant tensorOf(Scope scope, NdArray data) { if (data instanceof TString) { return create(scope, (TString) data); } - try (TString value = TString.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, data)) { return create(scope, value); } } /** - * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, - * using the given encoding. + * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, using the given + * encoding. * * @param scope is a scope used to add the underlying operation. * @param charset charset used to encode/decode string bytes. @@ -1228,14 +1293,14 @@ public static Constant tensorOf(Scope scope, NdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Charset charset, NdArray data) { - try (TString value = TString.tensorOf(charset, data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, charset, data)) { return create(scope, value); } } /** - * Create a {@link TString} constant with data from the given buffer, using the default UTF-8 - * encoding. + * Create a {@link TString} constant with data from the given buffer, using the default UTF-8 encoding. * * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. @@ -1245,7 +1310,8 @@ public static Constant tensorOf(Scope scope, Charset charset, NdArray tensorOf(Scope scope, Shape shape, DataBuffer data) { - try (TString value = TString.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -1263,14 +1329,14 @@ public static Constant tensorOf(Scope scope, Shape shape, DataBuffer tensorOf(Scope scope, Charset charset, Shape shape, DataBuffer data) { - try (TString value = TString.tensorOf(charset, shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, charset, shape, data)) { return create(scope, value); } } /** - * Creates a rank-1 constant of {@code long} elements representing the size of each dimensions of - * the given shape. + * Creates a rank-1 constant of {@code long} elements representing the size of each dimensions of the given shape. * * @param scope is a scope used to add the underlying operation. * @param shape a shape diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java index ef20b5ec2b6..5203c0892bf 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TBfloat16Mapper; import org.tensorflow.ndarray.FloatNdArray; @@ -34,14 +35,13 @@ * Brain 16-bit float tensor type. * *

    This type differs from {@link TFloat16} as it truncates the mantissa of a 32-bit float and - * preserve all exponent bits for faster conversion, while the latter shrink the exponent and have a - * longer mantissa for more precision. + * preserve all exponent bits for faster conversion, while the latter shrink the exponent and have a longer mantissa for + * more precision. * *

    Since there is no floating-point type that fits in 16 bits in Java, a conversion (with - * potentially a precision loss) is required for each 32 bits value written or read on a tensor of - * this type from the JVM. Therefore, if a lot of I/O operations are to be expected on a tensor, - * performances will be improved by working with {@link TFloat32} or {@link TFloat64} data types - * whenever possible. + * potentially a precision loss) is required for each 32 bits value written or read on a tensor of this type from the + * JVM. Therefore, if a lot of I/O operations are to be expected on a tensor, performances will be improved by working + * with {@link TFloat32} or {@link TFloat64} data types whenever possible. * *

    Note that some CPUs support the bfloat16 format natively, which can result in faster * computation compared to {@link TFloat16} when GPUs are not used. @@ -52,24 +52,26 @@ public interface TBfloat16 extends FloatNdArray, TFloating { /** * Allocates a new tensor for storing a single float value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value float to store in the new tensor * @return the new tensor */ - static TBfloat16 scalarOf(float value) { - return Tensor.of(TBfloat16.class, Shape.scalar(), data -> data.setFloat(value)); + static TBfloat16 scalarOf(TensorScope scope, float value) { + return Tensor.of(scope, TBfloat16.class, Shape.scalar(), data -> data.setFloat(value)); } /** * Allocates a new tensor for storing a vector of floats. * + * @param scope the {@link TensorScope} to create the tensor in * @param values floats to store in the new tensor * @return the new tensor */ - static TBfloat16 vectorOf(float... values) { + static TBfloat16 vectorOf(TensorScope scope, float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TBfloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TBfloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -77,44 +79,48 @@ static TBfloat16 vectorOf(float... values) { * *

    The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TBfloat16 tensorOf(NdArray src) { - return Tensor.of(TBfloat16.class, src.shape(), src::copyTo); + static TBfloat16 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TBfloat16.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TBfloat16 tensorOf(Shape shape) { - return Tensor.of(TBfloat16.class, shape); + static TBfloat16 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TBfloat16.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static TBfloat16 tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(TBfloat16.class, shape, d -> d.write(data)); + static TBfloat16 tensorOf(TensorScope scope, Shape shape, FloatDataBuffer data) { + return Tensor.of(scope, TBfloat16.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TBfloat16 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TBfloat16.class, shape, dataInit); + static TBfloat16 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TBfloat16.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java index 0158c12b910..47179faf045 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TBoolMapper; import org.tensorflow.ndarray.BooleanNdArray; @@ -35,8 +36,8 @@ * Boolean tensor type. * *

    If direct memory mapping is not available in the JVM, tensors of this type might require an - * explicit mapping between Java boolean values and byte buffers using the {@link DataLayouts#BOOL - * BOOL} layout, which may impact I/O performances. + * explicit mapping between Java boolean values and byte buffers using the {@link DataLayouts#BOOL BOOL} layout, which + * may impact I/O performances. */ @TensorType(dataType = DataType.DT_BOOL, byteSize = 1, mapperClass = TBoolMapper.class) public interface TBool extends BooleanNdArray, TType { @@ -44,24 +45,26 @@ public interface TBool extends BooleanNdArray, TType { /** * Allocates a new tensor for storing a single boolean value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value boolean to store in the new tensor * @return the new tensor */ - static TBool scalarOf(boolean value) { - return Tensor.of(TBool.class, Shape.scalar(), data -> data.setBoolean(value)); + static TBool scalarOf(TensorScope scope, boolean value) { + return Tensor.of(scope, TBool.class, Shape.scalar(), data -> data.setBoolean(value)); } /** * Allocates a new tensor for storing a vector of booleans. * + * @param scope the {@link TensorScope} to create the tensor in * @param values booleans to store in the new tensor * @return the new tensor */ - static TBool vectorOf(boolean... values) { + static TBool vectorOf(TensorScope scope, boolean... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TBool.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TBool.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -69,43 +72,47 @@ static TBool vectorOf(boolean... values) { * *

    The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TBool tensorOf(NdArray src) { - return Tensor.of(TBool.class, src.shape(), src::copyTo); + static TBool tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TBool.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TBool tensorOf(Shape shape) { - return Tensor.of(TBool.class, shape); + static TBool tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TBool.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of booleans to initialize the tensor with * @return the new tensor */ - static TBool tensorOf(Shape shape, BooleanDataBuffer data) { - return Tensor.of(TBool.class, shape, d -> d.write(data)); + static TBool tensorOf(TensorScope scope, Shape shape, BooleanDataBuffer data) { + return Tensor.of(scope, TBool.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TBool tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TBool.class, shape, dataInit); + static TBool tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TBool.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java index a43a0831f10..9e907d7e77c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TFloat16Mapper; import org.tensorflow.ndarray.FloatNdArray; @@ -34,14 +35,13 @@ * IEEE-754 half-precision 16-bit float tensor type. * *

    Since there is no floating-point type that fits in 16 bits in Java, a conversion (with - * potentially a precision loss) is required for each 32 bits value written or read on a tensor of - * this type from the JVM. Therefore, if a lot of I/O operations are to be expected on a tensor, - * performances will be improved by working with {@link TFloat32} or {@link TFloat64} data types - * whenever possible. + * potentially a precision loss) is required for each 32 bits value written or read on a tensor of this type from the + * JVM. Therefore, if a lot of I/O operations are to be expected on a tensor, performances will be improved by working + * with {@link TFloat32} or {@link TFloat64} data types whenever possible. * *

    Also, {@code TFloat16} tensors normally perform better if they are located in GPU memory since - * most CPUs do not support this format natively. For CPU computation on 16-bit floats, the {@link - * TBfloat16} tensor type might be a better option. + * most CPUs do not support this format natively. For CPU computation on 16-bit floats, the {@link TBfloat16} tensor + * type might be a better option. */ @TensorType(dataType = DataType.DT_HALF, byteSize = 2, mapperClass = TFloat16Mapper.class) public interface TFloat16 extends FloatNdArray, TFloating { @@ -49,24 +49,26 @@ public interface TFloat16 extends FloatNdArray, TFloating { /** * Allocates a new tensor for storing a single float value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value float to store in the new tensor * @return the new tensor */ - static TFloat16 scalarOf(float value) { - return Tensor.of(TFloat16.class, Shape.scalar(), data -> data.setFloat(value)); + static TFloat16 scalarOf(TensorScope scope, float value) { + return Tensor.of(scope, TFloat16.class, Shape.scalar(), data -> data.setFloat(value)); } /** * Allocates a new tensor for storing a vector of floats. * + * @param scope the {@link TensorScope} to create the tensor in * @param values floats to store in the new tensor * @return the new tensor */ - static TFloat16 vectorOf(float... values) { + static TFloat16 vectorOf(TensorScope scope, float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TFloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TFloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -74,43 +76,47 @@ static TFloat16 vectorOf(float... values) { * *

    The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TFloat16 tensorOf(NdArray src) { - return Tensor.of(TFloat16.class, src.shape(), src::copyTo); + static TFloat16 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TFloat16.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TFloat16 tensorOf(Shape shape) { - return Tensor.of(TFloat16.class, shape); + static TFloat16 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TFloat16.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static TFloat16 tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(TFloat16.class, shape, d -> d.write(data)); + static TFloat16 tensorOf(TensorScope scope, Shape shape, FloatDataBuffer data) { + return Tensor.of(scope, TFloat16.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TFloat16 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TFloat16.class, shape, dataInit); + static TFloat16 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TFloat16.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java index 35208f7de43..0e8496fd0a1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TFloat32Mapper; import org.tensorflow.ndarray.FloatNdArray; @@ -30,31 +31,35 @@ import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; -/** IEEE-754 single-precision 32-bit float tensor type. */ +/** + * IEEE-754 single-precision 32-bit float tensor type. + */ @TensorType(dataType = DataType.DT_FLOAT, byteSize = 4, mapperClass = TFloat32Mapper.class) public interface TFloat32 extends FloatNdArray, TFloating { /** * Allocates a new tensor for storing a single float value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value float to store in the new tensor * @return the new tensor */ - static TFloat32 scalarOf(float value) { - return Tensor.of(TFloat32.class, Shape.scalar(), data -> data.setFloat(value)); + static TFloat32 scalarOf(TensorScope scope, float value) { + return Tensor.of(scope, TFloat32.class, Shape.scalar(), data -> data.setFloat(value)); } /** * Allocates a new tensor for storing a vector of floats. * + * @param scope the {@link TensorScope} to create the tensor in * @param values floats to store in the new tensor * @return the new tensor */ - static TFloat32 vectorOf(float... values) { + static TFloat32 vectorOf(TensorScope scope, float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TFloat32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TFloat32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -62,43 +67,47 @@ static TFloat32 vectorOf(float... values) { * *

    The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TFloat32 tensorOf(NdArray src) { - return Tensor.of(TFloat32.class, src.shape(), src::copyTo); + static TFloat32 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TFloat32.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TFloat32 tensorOf(Shape shape) { - return Tensor.of(TFloat32.class, shape); + static TFloat32 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TFloat32.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static TFloat32 tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(TFloat32.class, shape, d -> d.write(data)); + static TFloat32 tensorOf(TensorScope scope, Shape shape, FloatDataBuffer data) { + return Tensor.of(scope, TFloat32.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TFloat32 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TFloat32.class, shape, dataInit); + static TFloat32 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TFloat32.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java index 957612691e5..2c2f6f95f78 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TFloat64Mapper; import org.tensorflow.ndarray.DoubleNdArray; @@ -31,31 +32,35 @@ import org.tensorflow.types.family.TFloating; -/** IEEE-754 double-precision 64-bit float tensor type. */ +/** + * IEEE-754 double-precision 64-bit float tensor type. + */ @TensorType(dataType = DataType.DT_DOUBLE, byteSize = 8, mapperClass = TFloat64Mapper.class) public interface TFloat64 extends DoubleNdArray, TFloating { /** * Allocates a new tensor for storing a single double value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value double to store in the new tensor * @return the new tensor */ - static TFloat64 scalarOf(double value) { - return Tensor.of(TFloat64.class, Shape.scalar(), data -> data.setDouble(value)); + static TFloat64 scalarOf(TensorScope scope, double value) { + return Tensor.of(scope, TFloat64.class, Shape.scalar(), data -> data.setDouble(value)); } /** * Allocates a new tensor for storing a vector of doubles. * + * @param scope the {@link TensorScope} to create the tensor in * @param values doubles to store in the new tensor * @return the new tensor */ - static TFloat64 vectorOf(double... values) { + static TFloat64 vectorOf(TensorScope scope, double... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TFloat64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TFloat64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -63,43 +68,47 @@ static TFloat64 vectorOf(double... values) { * *

    The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TFloat64 tensorOf(NdArray src) { - return Tensor.of(TFloat64.class, src.shape(), src::copyTo); + static TFloat64 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TFloat64.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TFloat64 tensorOf(Shape shape) { - return Tensor.of(TFloat64.class, shape); + static TFloat64 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TFloat64.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of doubles to initialize the tensor with * @return the new tensor */ - static TFloat64 tensorOf(Shape shape, DoubleDataBuffer data) { - return Tensor.of(TFloat64.class, shape, d -> d.write(data)); + static TFloat64 tensorOf(TensorScope scope, Shape shape, DoubleDataBuffer data) { + return Tensor.of(scope, TFloat64.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TFloat64 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TFloat64.class, shape, dataInit); + static TFloat64 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TFloat64.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java index 8f6b587795b..6b005b25630 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.internal.types.TInt32Mapper; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.NdArray; @@ -29,32 +30,36 @@ import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TIntegral; -/** 32-bit signed integer tensor type. */ +/** + * 32-bit signed integer tensor type. + */ @TensorType(dataType = DataType.DT_INT32, byteSize = 4, mapperClass = TInt32Mapper.class) public interface TInt32 extends IntNdArray, TIntegral { /** * Allocates a new tensor for storing a single int value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value int to store in the new tensor * @return the new tensor */ - static TInt32 scalarOf(int value) { - return Tensor.of(TInt32.class, Shape.scalar(), data -> data.setInt(value)); + static TInt32 scalarOf(TensorScope scope, int value) { + return Tensor.of(scope, TInt32.class, Shape.scalar(), data -> data.setInt(value)); } /** * Allocates a new tensor for storing a vector of ints. * + * @param scope the {@link TensorScope} to create the tensor in * @param values ints to store in the new tensor * @return the new tensor * @throws IllegalArgumentException if no values are provided */ - static TInt32 vectorOf(int... values) { + static TInt32 vectorOf(TensorScope scope, int... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TInt32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TInt32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -62,43 +67,47 @@ static TInt32 vectorOf(int... values) { * *

    The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TInt32 tensorOf(NdArray src) { - return Tensor.of(TInt32.class, src.shape(), src::copyTo); + static TInt32 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TInt32.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TInt32 tensorOf(Shape shape) { - return Tensor.of(TInt32.class, shape); + static TInt32 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TInt32.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of ints to initialize the tensor with * @return the new tensor */ - static TInt32 tensorOf(Shape shape, IntDataBuffer data) { - return Tensor.of(TInt32.class, shape, d -> d.write(data)); + static TInt32 tensorOf(TensorScope scope, Shape shape, IntDataBuffer data) { + return Tensor.of(scope, TInt32.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor */ - static TInt32 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TInt32.class, shape, dataInit); + static TInt32 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TInt32.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java index 867248c5392..05a4434df65 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TInt64Mapper; import org.tensorflow.ndarray.LongNdArray; @@ -30,31 +31,35 @@ import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TIntegral; -/** 64-bit signed integer tensor type. */ +/** + * 64-bit signed integer tensor type. + */ @TensorType(dataType = DataType.DT_INT64, byteSize = 8, mapperClass = TInt64Mapper.class) public interface TInt64 extends LongNdArray, TIntegral { /** * Allocates a new tensor for storing a single long value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value long to store in the new tensor * @return the new tensor */ - static TInt64 scalarOf(long value) { - return Tensor.of(TInt64.class, Shape.scalar(), data -> data.setLong(value)); + static TInt64 scalarOf(TensorScope scope, long value) { + return Tensor.of(scope, TInt64.class, Shape.scalar(), data -> data.setLong(value)); } /** * Allocates a new tensor for storing a vector of longs. * + * @param scope the {@link TensorScope} to create the tensor in * @param values longs to store in the new tensor * @return the new tensor */ - static TInt64 vectorOf(long... values) { + static TInt64 vectorOf(TensorScope scope, long... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TInt64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TInt64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -62,43 +67,47 @@ static TInt64 vectorOf(long... values) { * *

    The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TInt64 tensorOf(NdArray src) { - return Tensor.of(TInt64.class, src.shape(), src::copyTo); + static TInt64 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TInt64.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TInt64 tensorOf(Shape shape) { - return Tensor.of(TInt64.class, shape); + static TInt64 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TInt64.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of longs to initialize the tensor with * @return the new tensor */ - static TInt64 tensorOf(Shape shape, LongDataBuffer data) { - return Tensor.of(TInt64.class, shape, d -> d.write(data)); + static TInt64 tensorOf(TensorScope scope, Shape shape, LongDataBuffer data) { + return Tensor.of(scope, TInt64.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TInt64 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TInt64.class, shape, dataInit); + static TInt64 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TInt64.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java index b3000cc2f8a..7cb098257a1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets; import java.util.function.Function; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.internal.types.TStringInitializer; import org.tensorflow.internal.types.TStringMapper; import org.tensorflow.ndarray.NdArray; @@ -37,8 +38,8 @@ *

    This type can be used to store any arbitrary byte sequence of variable length. * *

    Since the size of a tensor is fixed, creating a tensor of this type requires to provide all of - * its values initially, so TensorFlow can compute and allocate the right amount of memory. Then the - * data in the tensor is initialized once and cannot be modified afterwards. + * its values initially, so TensorFlow can compute and allocate the right amount of memory. Then the data in the tensor + * is initialized once and cannot be modified afterwards. */ @TensorType(dataType = DataType.DT_STRING, byteSize = -1, mapperClass = TStringMapper.class) public interface TString extends NdArray, TType { @@ -48,11 +49,12 @@ public interface TString extends NdArray, TType { * *

    The string is encoded into bytes using the UTF-8 charset. * + * @param scope the {@link TensorScope} to create the tensor in * @param value scalar value to store in the new tensor * @return the new tensor */ - static TString scalarOf(String value) { - return tensorOf(NdArrays.scalarOfObject(value)); + static TString scalarOf(TensorScope scope, String value) { + return tensorOf(scope, NdArrays.scalarOfObject(value)); } /** @@ -60,14 +62,15 @@ static TString scalarOf(String value) { * *

    The strings are encoded into bytes using the UTF-8 charset. * + * @param scope the {@link TensorScope} to create the tensor in * @param values values to store in the new tensor * @return the new tensor */ - static TString vectorOf(String... values) { + static TString vectorOf(TensorScope scope, String... values) { if (values == null) { throw new IllegalArgumentException(); } - return tensorOf(NdArrays.vectorOfObjects(values)); + return tensorOf(scope, NdArrays.vectorOfObjects(values)); } /** @@ -76,11 +79,12 @@ static TString vectorOf(String... values) { *

    The tensor will have the same shape as the source array and its data will be copied. The * strings are encoded into bytes using the UTF-8 charset. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TString tensorOf(NdArray src) { - return tensorOf(StandardCharsets.UTF_8, src); + static TString tensorOf(TensorScope scope, NdArray src) { + return tensorOf(scope, StandardCharsets.UTF_8, src); } /** @@ -100,13 +104,14 @@ static TString tensorOf(NdArray src) { * assertEquals(originalStrings.getObject(0), tensorStrings.getObject(0)); * }

    * + * @param scope the {@link TensorScope} to create the tensor in * @param charset charset to use for encoding the strings into bytes * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TString tensorOf(Charset charset, NdArray src) { + static TString tensorOf(TensorScope scope, Charset charset, NdArray src) { TStringInitializer initializer = new TStringInitializer<>(src, s -> s.getBytes(charset)); - return Tensor.of(TString.class, src.shape(), initializer.computeRequiredSize(), initializer); + return Tensor.of(scope, TString.class, src.shape(), initializer.computeRequiredSize(), initializer); } /** @@ -115,12 +120,13 @@ static TString tensorOf(Charset charset, NdArray src) { *

    The data will be copied from the provided buffer to the tensor after it is allocated. The * strings are encoded into bytes using the UTF-8 charset. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor * @param data buffer of strings to initialize the tensor with * @return the new tensor */ - static TString tensorOf(Shape shape, DataBuffer data) { - return tensorOf(NdArrays.wrap(shape, data)); + static TString tensorOf(TensorScope scope, Shape shape, DataBuffer data) { + return tensorOf(scope, NdArrays.wrap(shape, data)); } /** @@ -141,13 +147,14 @@ static TString tensorOf(Shape shape, DataBuffer data) { * assertEquals(originalStrings.getObject(0), tensorStrings.getObject(0)); * }

    * + * @param scope the {@link TensorScope} to create the tensor in * @param charset charset to use for encoding the strings into bytes * @param shape shape of the tensor * @param data buffer of strings to initialize the tensor with * @return the new tensor */ - static TString tensorOf(Charset charset, Shape shape, DataBuffer data) { - return tensorOf(charset, NdArrays.wrap(shape, data)); + static TString tensorOf(TensorScope scope, Charset charset, Shape shape, DataBuffer data) { + return tensorOf(scope, charset, NdArrays.wrap(shape, data)); } /** @@ -162,12 +169,13 @@ static TString tensorOf(Charset charset, Shape shape, DataBuffer data) { * byte[] bytes = tensor.data().asBytes().getObject(0); // returns first sequence of bytes in the tensor * }
    * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TString tensorOfBytes(NdArray src) { + static TString tensorOfBytes(TensorScope scope, NdArray src) { TStringInitializer initializer = new TStringInitializer<>(src, Function.identity()); - return Tensor.of(TString.class, src.shape(), initializer.computeRequiredSize(), initializer); + return Tensor.of(scope, TString.class, src.shape(), initializer.computeRequiredSize(), initializer); } /** @@ -182,12 +190,13 @@ static TString tensorOfBytes(NdArray src) { * byte[] bytes = tensor.data().asBytes().getObject(0); // returns first sequence of bytes in the tensor * }
* + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to create * @param data the source array giving the shape and data to the new tensor * @return the new tensor */ - static TString tensorOfBytes(Shape shape, DataBuffer data) { - return tensorOfBytes(NdArrays.wrap(shape, data)); + static TString tensorOfBytes(TensorScope scope, Shape shape, DataBuffer data) { + return tensorOfBytes(scope, NdArrays.wrap(shape, data)); } /** @@ -208,6 +217,8 @@ static TString tensorOfBytes(Shape shape, DataBuffer data) { */ TString using(Charset charset); - /** @return the tensor data as a n-dimensional array of raw byte sequences. */ + /** + * @return the tensor data as a n-dimensional array of raw byte sequences. + */ NdArray asBytes(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java index eae86414cb4..8744b8016a6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TUint8Mapper; import org.tensorflow.ndarray.ByteNdArray; @@ -30,31 +31,35 @@ import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TIntegral; -/** 8-bit unsigned integer tensor type. */ +/** + * 8-bit unsigned integer tensor type. + */ @TensorType(dataType = DataType.DT_UINT8, byteSize = 1, mapperClass = TUint8Mapper.class) public interface TUint8 extends ByteNdArray, TIntegral { /** * Allocates a new tensor for storing a single byte value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value byte to store in the new tensor * @return the new tensor */ - static TUint8 scalarOf(byte value) { - return Tensor.of(TUint8.class, Shape.scalar(), data -> data.setByte(value)); + static TUint8 scalarOf(TensorScope scope, byte value) { + return Tensor.of(scope, TUint8.class, Shape.scalar(), data -> data.setByte(value)); } /** * Allocates a new tensor for storing a vector of bytes. * + * @param scope the {@link TensorScope} to create the tensor in * @param values bytes to store in the new tensor * @return the new tensor */ - static TUint8 vectorOf(byte... values) { + static TUint8 vectorOf(TensorScope scope, byte... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TUint8.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TUint8.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -62,43 +67,47 @@ static TUint8 vectorOf(byte... values) { * *

The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TUint8 tensorOf(NdArray src) { - return Tensor.of(TUint8.class, src.shape(), src::copyTo); + static TUint8 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TUint8.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TUint8 tensorOf(Shape shape) { - return Tensor.of(TUint8.class, shape); + static TUint8 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TUint8.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of bytes to initialize the tensor with * @return the new tensor */ - static TUint8 tensorOf(Shape shape, ByteDataBuffer data) { - return Tensor.of(TUint8.class, shape, d -> d.write(data)); + static TUint8 tensorOf(TensorScope scope, Shape shape, ByteDataBuffer data) { + return Tensor.of(scope, TUint8.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TUint8 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TUint8.class, shape, dataInit); + static TUint8 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TUint8.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index b2b2c34e223..515a5faa067 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -43,32 +43,30 @@ private static Signature minusTwo(Ops tf) { @Test public void createFunction() { try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive); - TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); + TensorScope scope = new TensorScope()) { + TFloat32 x = TFloat32.scalarOf(scope, 3.0f); + assertEquals(8.0f, ((TFloat32) f.call(scope, x)).getFloat()); } } @Test public void createFunctionFromGraph() { - try (Graph g = new Graph()) { - Signature signature = plusFive(Ops.create(g)); - try (ConcreteFunction f = ConcreteFunction.create(signature, g); - TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); - } + try (Graph g = new Graph(); + TensorScope scope = new TensorScope(); + ConcreteFunction f = ConcreteFunction.create(plusFive(Ops.create(g)), g)) { + TFloat32 x = TFloat32.scalarOf(scope, 3.0f); + assertEquals(8.0f, ((TFloat32) f.call(scope, x)).getFloat()); } } @Test public void createFunctionFromSession() { - try (Graph g = new Graph()) { - Signature signature = plusFive(Ops.create(g)); - try (Session s = new Session(g)) { - try (ConcreteFunction f = ConcreteFunction.create(signature, s); - TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); - } - } + try (Graph g = new Graph(); + Session s = new Session(g); + TensorScope scope = new TensorScope(); + ConcreteFunction f = ConcreteFunction.create(plusFive(Ops.create(g)), s)) { + TFloat32 x = TFloat32.scalarOf(scope, 3.0f); + assertEquals(8.0f, ((TFloat32) f.call(scope, x)).getFloat()); } } @@ -76,8 +74,9 @@ public void createFunctionFromSession() { public void chainFunctions() { try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive); ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo); - TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(6.0f, ((TFloat32)f2.call(f1.call(x))).getFloat()); + TensorScope scope = new TensorScope()) { + TFloat32 x = TFloat32.scalarOf(scope, 3.0f); + assertEquals(6.0f, ((TFloat32) f2.call(scope, f1.call(scope, x))).getFloat()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index e4340da3275..47e289efffd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -14,6 +14,12 @@ ==============================================================================*/ package org.tensorflow; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import static org.tensorflow.DeviceSpec.DeviceType; + +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; @@ -21,200 +27,200 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.types.TInt32; -import static com.google.common.truth.Truth.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; -import static org.tensorflow.DeviceSpec.DeviceType; - -/** Tests for {@link DeviceSpec}. */ +/** + * Tests for {@link DeviceSpec}. + */ public class DeviceSpecTest { + @Test public void withDeviceMethod() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) - .build(); + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - Ops tf = Ops.create(g).withSubScope("testScope"); + try (Graph g = new Graph(); Session session = new Session(g, config)) { + Ops tf = Ops.create(g).withSubScope("testScope"); - Constant aOps = tf.constant(-1); + Constant aOps = tf.constant(-1); - DeviceSpec deviceSpec = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); + DeviceSpec deviceSpec = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) + .build(); - Output absOps = tf - .withName("absWithDevice") - .withDevice(deviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = tf + .withName("absWithDevice") + .withDevice(deviceSpec) + .math + .abs(aOps) + .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + List t = session.runner().fetch(absOps).run(scope); + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } @Test public void withEmptyDeviceSpec() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) - .build(); + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - Ops tf = Ops.create(g).withSubScope("testScope"); + try (Graph g = new Graph(); Session session = new Session(g, config)) { + Ops tf = Ops.create(g).withSubScope("testScope"); - Constant aOps = tf.constant(-1); + Constant aOps = tf.constant(-1); - DeviceSpec deviceSpec = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); + DeviceSpec deviceSpec = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) + .build(); - Output absOps = tf - .withName("absWithDevice") - .withDevice(deviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = tf + .withName("absWithDevice") + .withDevice(deviceSpec) + .math + .abs(aOps) + .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + List t = session.runner().fetch(absOps).run(scope); + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } @Test public void withTwoScopes() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); + + try (Graph g = new Graph(); Session session = new Session(g, config)) { + DeviceSpec deviceSpec1 = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) + .build(); + + DeviceSpec deviceSpec2 = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec deviceSpec1 = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); - - DeviceSpec deviceSpec2 = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); - - Ops tf1 = Ops.create(g).withSubScope("testScope1").withDevice(deviceSpec1); - Ops tf2 = Ops.create(g).withSubScope("testScope2").withDevice(deviceSpec2); - - Constant aOps = tf1.constant(-1); - Constant bOps = tf2.constant(10); - - Output absOps = tf1 - .withName("absWithDevice") - .math - .abs(aOps) - .asOutput(); - - Output mulOps = tf2 - .withName("mulWithDevice") - .math - .mul(absOps, bOps) - .asOutput(); - - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { - assertEquals(10, ((TInt32)t.get(0)).getInt()); + Ops tf1 = Ops.create(g).withSubScope("testScope1").withDevice(deviceSpec1); + Ops tf2 = Ops.create(g).withSubScope("testScope2").withDevice(deviceSpec2); + + Constant aOps = tf1.constant(-1); + Constant bOps = tf2.constant(10); + + Output absOps = tf1 + .withName("absWithDevice") + .math + .abs(aOps) + .asOutput(); + + Output mulOps = tf2 + .withName("mulWithDevice") + .math + .mul(absOps, bOps) + .asOutput(); + + List t = session.runner().fetch(mulOps).run(scope); + assertEquals(10, ((TInt32) t.get(0)).getInt()); } } } @Test public void withIncorrectDeviceSpec() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); + + try (Graph g = new Graph(); Session session = new Session(g, config)) { + DeviceSpec correctDeviceSpec = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec correctDeviceSpec = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); - - // Incorrect device spec, it will never be executed - DeviceSpec incorrectDeviceSpec = DeviceSpec.newBuilder() - .job("UNKNOWN") - .replica(1) - .task(1000) - .deviceType(DeviceType.TPU) - .build(); - - Ops tf = Ops.create(g); - - Constant aOps = tf.constant(-1); - Constant bOps = tf.constant(10); - - Output absOps = tf - .withName("absWithDevice") - .withDevice(incorrectDeviceSpec) - .math - .abs(aOps) - .asOutput(); - - Output mulOps = tf - .withName("mulWithDevice") - .withDevice(correctDeviceSpec) - .math - .mul(absOps, bOps) - .asOutput(); - - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { - fail(); - } catch (TFInvalidArgumentException e) { - // ok + // Incorrect device spec, it will never be executed + DeviceSpec incorrectDeviceSpec = DeviceSpec.newBuilder() + .job("UNKNOWN") + .replica(1) + .task(1000) + .deviceType(DeviceType.TPU) + .build(); + + Ops tf = Ops.create(g); + + Constant aOps = tf.constant(-1); + Constant bOps = tf.constant(10); + + Output absOps = tf + .withName("absWithDevice") + .withDevice(incorrectDeviceSpec) + .math + .abs(aOps) + .asOutput(); + + Output mulOps = tf + .withName("mulWithDevice") + .withDevice(correctDeviceSpec) + .math + .mul(absOps, bOps) + .asOutput(); + + try { + List t = session.runner().fetch(mulOps).run(scope); + fail(); + } catch (TFInvalidArgumentException e) { + // ok + } } } } @Test public void withDeviceSpecInScope() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); + + try (Graph g = new Graph(); Session session = new Session(g, config)) { + DeviceSpec deviceSpec = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec deviceSpec = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); - - Ops tf = Ops.create(g).withSubScope("testScope").withDevice(deviceSpec); + Ops tf = Ops.create(g).withSubScope("testScope").withDevice(deviceSpec); - Constant aOps = tf.constant(-1); + Constant aOps = tf.constant(-1); - Output absOps = tf - .withName("absWithDevice") - .math - .abs(aOps) - .asOutput(); + Output absOps = tf + .withName("absWithDevice") + .math + .abs(aOps) + .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + List t = session.runner().fetch(absOps).run(scope); + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java index b39ecec9881..0b474ec795a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -23,7 +23,9 @@ import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TInt32; -/** Unit tests for {@link EagerOperationBuilder} class. */ +/** + * Unit tests for {@link EagerOperationBuilder} class. + */ public class EagerOperationBuilderTest { @Test @@ -59,7 +61,7 @@ public void addInputs() { Operation asrt = opBuilder(session, "Assert", "assert") .addInput(tf.constant(true).asOutput()) - .addInputList(new Output[] {tf.constant(-1).asOutput()}) + .addInputList(new Output[]{tf.constant(-1).asOutput()}) .build(); opBuilder(session, "Const", "var").addControlInput(asrt); } @@ -79,62 +81,63 @@ public void setDevice() { @Test public void setAttrs() { - // The effect of setting an attribute may not easily be visible from the other parts of this - // package's API. Thus, for now, the test simply executes the various setAttr variants to see - // that there are no exceptions. - // - // This is a bit of an awkward test since it has to find operations with attributes of specific - // types that aren't inferred from the input arguments. - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); - // dtype, tensor attributes. - try (TInt32 t = TInt32.scalarOf(1)) { + try (TensorScope scope = new TensorScope()) { + // The effect of setting an attribute may not easily be visible from the other parts of this + // package's API. Thus, for now, the test simply executes the various setAttr variants to see + // that there are no exceptions. + // + // This is a bit of an awkward test since it has to find operations with attributes of specific + // types that aren't inferred from the input arguments. + try (EagerSession session = EagerSession.create()) { + Ops tf = Ops.create(session); + // dtype, tensor attributes. + TInt32 t = TInt32.scalarOf(scope, 1); opBuilder(session, "Const", "DataTypeAndTensor") .setAttr("dtype", t.dataType()) .setAttr("value", t) .build(); + // type, int (TF "int" attributes are 64-bit signed, so a Java long). + opBuilder(session, "RandomUniform", "DataTypeAndInt") + .addInput(tf.array(1).asOutput()) + .setAttr("seed", 10) + .setAttr("dtype", DataType.DT_FLOAT) + .build(); + // list(int), string + opBuilder(session, "MaxPool", "IntListAndString") + .addInput(tf.constant(new float[2][2][2][2]).asOutput()) + .setAttr("ksize", new long[]{1, 1, 1, 1}) + .setAttr("strides", new long[]{1, 1, 1, 1}) + .setAttr("padding", "SAME") + .build(); + // list(float), device + opBuilder(session, "FractionalMaxPool", "FloatList") + .addInput(tf.constant(new float[2][2][2][2]).asOutput()) + .setAttr("pooling_ratio", new float[]{1.0f, 1.44f, 1.73f, 1.0f}) + .build(); + // shape + opBuilder(session, "EnsureShape", "ShapeAttr") + .addInput(tf.constant(new int[2][2]).asOutput()) + .setAttr("shape", Shape.of(2, 2)) + .build(); + // list(shape) + opBuilder(session, "FIFOQueue", "queue") + .setAttr("component_types", new DataType[]{DataType.DT_INT32, DataType.DT_INT32}) + .setAttr("shapes", new Shape[]{Shape.of(2, 2), Shape.of(2, 2, 2)}) + .build(); + // bool + opBuilder(session, "All", "Bool") + .addInput(tf.constant(new boolean[]{true, true, false}).asOutput()) + .addInput(tf.constant(0).asOutput()) + .setAttr("keep_dims", false) + .build(); + // float + opBuilder(session, "ApproximateEqual", "Float") + .addInput(tf.constant(10.00001f).asOutput()) + .addInput(tf.constant(10.00000f).asOutput()) + .setAttr("tolerance", 0.1f) + .build(); + // Missing tests: list(string), list(byte), list(bool), list(type) } - // type, int (TF "int" attributes are 64-bit signed, so a Java long). - opBuilder(session, "RandomUniform", "DataTypeAndInt") - .addInput(tf.array(1).asOutput()) - .setAttr("seed", 10) - .setAttr("dtype", DataType.DT_FLOAT) - .build(); - // list(int), string - opBuilder(session, "MaxPool", "IntListAndString") - .addInput(tf.constant(new float[2][2][2][2]).asOutput()) - .setAttr("ksize", new long[] {1, 1, 1, 1}) - .setAttr("strides", new long[] {1, 1, 1, 1}) - .setAttr("padding", "SAME") - .build(); - // list(float), device - opBuilder(session, "FractionalMaxPool", "FloatList") - .addInput(tf.constant(new float[2][2][2][2]).asOutput()) - .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) - .build(); - // shape - opBuilder(session, "EnsureShape", "ShapeAttr") - .addInput(tf.constant(new int[2][2]).asOutput()) - .setAttr("shape", Shape.of(2, 2)) - .build(); - // list(shape) - opBuilder(session, "FIFOQueue", "queue") - .setAttr("component_types", new DataType[] {DataType.DT_INT32, DataType.DT_INT32}) - .setAttr("shapes", new Shape[] {Shape.of(2, 2), Shape.of(2, 2, 2)}) - .build(); - // bool - opBuilder(session, "All", "Bool") - .addInput(tf.constant(new boolean[] {true, true, false}).asOutput()) - .addInput(tf.constant(0).asOutput()) - .setAttr("keep_dims", false) - .build(); - // float - opBuilder(session, "ApproximateEqual", "Float") - .addInput(tf.constant(10.00001f).asOutput()) - .addInput(tf.constant(10.00000f).asOutput()) - .setAttr("tolerance", 0.1f) - .build(); - // Missing tests: list(string), list(byte), list(bool), list(type) } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java index 38714b86599..8ebb4789e8c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java @@ -50,7 +50,8 @@ public void failToCreateIfSessionIsClosed() { @Test public void outputDataTypeAndShape() { try (EagerSession session = EagerSession.create(); - TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { + TensorScope scope = new TensorScope()) { + TInt32 t = TInt32.tensorOf(scope, Shape.of(2, 3)); EagerOperation op = opBuilder(session, "Const", "OutputAttrs") .setAttr("dtype", t.dataType()) @@ -64,14 +65,15 @@ public void outputDataTypeAndShape() { @Test public void outputTensor() { - try (EagerSession session = EagerSession.create()) { + try (EagerSession session = EagerSession.create(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(session); EagerOperation add = opBuilder(session, "Add", "CompareResult") .addInput(tf.constant(2).asOutput()) .addInput(tf.constant(4).asOutput()) .build(); - assertEquals(6, ((TInt32)add.tensor(0)).getInt()); + assertEquals(6, ((TInt32) add.tensor(scope, 0)).getInt()); // Validate that we retrieve the right shape and datatype from the tensor // that has been resolved @@ -154,7 +156,8 @@ public void opNotAccessibleIfSessionIsClosed() { @Test public void outputIndexOutOfBounds() { - try (EagerSession session = EagerSession.create()) { + try (EagerSession session = EagerSession.create(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(session); EagerOperation add = opBuilder(session, "Add", "OutOfRange") @@ -180,7 +183,7 @@ public void outputIndexOutOfBounds() { // expected } try { - add.tensor(1); + add.tensor(scope, 1); fail(); } catch (IndexOutOfBoundsException e) { // expected diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index 33ae979ccbd..d19781d8fb1 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -27,13 +27,16 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.GraphOperationBuilder}. */ +/** + * Unit tests for {@link org.tensorflow.GraphOperationBuilder}. + */ public class GraphOperationBuilderTest { @Test public void failOnUseAfterBuild() { try (Graph g = new Graph(); - TInt32 t = TInt32.scalarOf(1)) { + TensorScope scope = new TensorScope()) { + TInt32 t = TInt32.scalarOf(scope, 1); OperationBuilder b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); b.build(); @@ -49,7 +52,8 @@ public void failOnUseAfterBuild() { public void failOnUseAfterGraphClose() { OperationBuilder b = null; try (Graph g = new Graph(); - TInt32 t = TInt32.scalarOf(1)) { + TensorScope scope = new TensorScope()) { + TInt32 t = TInt32.scalarOf(scope, 1); b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); } try { @@ -68,17 +72,17 @@ public void setAttr() { // // This is a bit of an awkward test since it has to find operations with attributes of specific // types that aren't inferred from the input arguments. - try (Graph g = new Graph()) { + try (Graph g = new Graph(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); // dtype, tensor attributes. - try (TInt32 t = TInt32.scalarOf(1)) { - g.opBuilder("Const", "DataTypeAndTensor") - .setAttr("dtype", t.dataType()) - .setAttr("value", t) - .build() - .output(0); - assertTrue(hasNode(g, "DataTypeAndTensor")); - } + TInt32 t = TInt32.scalarOf(scope, 1); + g.opBuilder("Const", "DataTypeAndTensor") + .setAttr("dtype", t.dataType()) + .setAttr("value", t) + .build() + .output(0); + assertTrue(hasNode(g, "DataTypeAndTensor")); // string, bool attributes. g.opBuilder("Abort", "StringAndBool") .setAttr("error_msg", "SomeErrorMessage") @@ -95,15 +99,15 @@ public void setAttr() { // list(int) g.opBuilder("MaxPool", "IntList") .addInput(tf.constant(new float[2][2][2][2]).asOutput()) - .setAttr("ksize", new long[] {1, 1, 1, 1}) - .setAttr("strides", new long[] {1, 1, 1, 1}) + .setAttr("ksize", new long[]{1, 1, 1, 1}) + .setAttr("strides", new long[]{1, 1, 1, 1}) .setAttr("padding", "SAME") .build(); assertTrue(hasNode(g, "IntList")); // list(float) g.opBuilder("FractionalMaxPool", "FloatList") .addInput(tf.constant(new float[2][2][2][2]).asOutput()) - .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) + .setAttr("pooling_ratio", new float[]{1.0f, 1.44f, 1.73f, 1.0f}) .build(); assertTrue(hasNode(g, "FloatList")); // Missing tests: float, list(dtype), list(tensor), list(string), list(bool) @@ -138,10 +142,10 @@ public void setAttrShape() { @Test public void setAttrShapeList() { // Those shapes match tensors ones, so no exception is thrown - testSetAttrShapeList(new Shape[] {Shape.of(2, 2), Shape.of(2, 2, 2)}); + testSetAttrShapeList(new Shape[]{Shape.of(2, 2), Shape.of(2, 2, 2)}); try { // Those shapes do not match tensors ones, exception is thrown - testSetAttrShapeList(new Shape[] {Shape.of(2, 2), Shape.of(2, 2, 2, 2)}); + testSetAttrShapeList(new Shape[]{Shape.of(2, 2), Shape.of(2, 2, 2, 2)}); fail("Shapes are incompatible and an exception was expected"); } catch (TFInvalidArgumentException e) { // expected @@ -152,23 +156,24 @@ public void setAttrShapeList() { public void addControlInput() { try (Graph g = new Graph(); Session s = new Session(g); - TBool yes = TBool.scalarOf(true); - TBool no = TBool.scalarOf(false)) { + TensorScope scope = new TensorScope()) { + TBool yes = TBool.scalarOf(scope, true); + TBool no = TBool.scalarOf(scope, false); Ops tf = Ops.create(g); Output placeholder = tf.placeholder(TBool.class).asOutput(); GraphOperation check = g.opBuilder("Assert", "assert") .addInput(placeholder) - .addInputList(new Output[] {placeholder}) + .addInputList(new Output[]{placeholder}) .build(); Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build(); // No problems when the Assert check succeeds - s.runner().feed(placeholder, yes).addTarget(noop).run(); + s.runner().feed(placeholder, yes).addTarget(noop).run(scope); // Exception thrown by the execution of the Assert node try { - s.runner().feed(placeholder, no).addTarget(noop).run(); + s.runner().feed(placeholder, no).addTarget(noop).run(scope); fail("Did not run control operation."); } catch (TFInvalidArgumentException e) { // expected @@ -178,26 +183,27 @@ public void addControlInput() { private static void testSetAttrShapeList(Shape[] shapes) { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); - int[][] matrix = new int[][] {{0, 0}, {0, 0}}; + int[][] matrix = new int[][]{{0, 0}, {0, 0}}; Output queue = g.opBuilder("FIFOQueue", "queue") - .setAttr("component_types", new DataType[] {DataType.DT_INT32, DataType.DT_INT32}) + .setAttr("component_types", new DataType[]{DataType.DT_INT32, DataType.DT_INT32}) .setAttr("shapes", shapes) .build() .output(0); assertTrue(hasNode(g, "queue")); Output c1 = tf.constant(matrix).asOutput(); - Output c2 = tf.constant(new int[][][] {matrix, matrix}).asOutput(); + Output c2 = tf.constant(new int[][][]{matrix, matrix}).asOutput(); Operation enqueue = g.opBuilder("QueueEnqueue", "enqueue") .addInput(queue) - .addInputList(new Output[] {c1, c2}) + .addInputList(new Output[]{c1, c2}) .build(); assertTrue(hasNode(g, "enqueue")); - s.runner().addTarget(enqueue).run(); + s.runner().addTarget(enqueue).run(scope); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java index b164c129745..7bc0e0953fd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java @@ -29,7 +29,9 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.GraphOperation}. */ +/** + * Unit tests for {@link org.tensorflow.GraphOperation}. + */ public class GraphOperationTest { @Test @@ -53,8 +55,8 @@ public void operationEquality() { GraphOperation op1; try (Graph g = new Graph()) { Ops tf = Ops.create(g); - op1 = (GraphOperation)tf.withName("op1").constant(1).op(); - GraphOperation op2 = (GraphOperation)tf.withName("op2").constant(2).op(); + op1 = (GraphOperation) tf.withName("op1").constant(1).op(); + GraphOperation op2 = (GraphOperation) tf.withName("op2").constant(2).op(); GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle()); GraphOperation op4 = g.operation("op1"); assertEquals(op1, op1); @@ -78,8 +80,8 @@ public void operationEquality() { public void operationCollection() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); - GraphOperation op1 = (GraphOperation)tf.withName("op1").constant(1).op(); - GraphOperation op2 = (GraphOperation)tf.withName("op2").constant(2).op(); + GraphOperation op1 = (GraphOperation) tf.withName("op1").constant(1).op(); + GraphOperation op2 = (GraphOperation) tf.withName("op2").constant(2).op(); GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle()); GraphOperation op4 = g.operation("op1"); Set ops = new HashSet<>(); @@ -179,11 +181,12 @@ public void outputList() { @Test public void outputTensorNotSupported() { - try (Graph g = new Graph()) { + try (Graph g = new Graph(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Operation split = tf.split(tf.constant(0), tf.array(0, 1, 2), 3L).op(); try { - split.output(0).asTensor(); + split.output(0).asTensor(scope); fail(); } catch (IllegalStateException e) { } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index d8ffc1a475b..838c6ea1279 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; @@ -32,7 +33,9 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.Graph}. */ +/** + * Unit tests for {@link org.tensorflow.Graph}. + */ public class GraphTest { @Test @@ -138,7 +141,8 @@ public void failOnUseAfterClose() { @Test public void addGradientsToGraph() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x1 = tf.placeholder(TFloat32.class).output(); @@ -146,7 +150,7 @@ public void addGradientsToGraph() { Output y0 = tf.math.square(x1).y(); Output y1 = tf.math.square(y0).y(); Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum(); - + Output[] grads0 = g.addGradients(y1, toArray(x1)); assertNotNull(grads0); assertEquals(1, grads0.length); @@ -157,30 +161,29 @@ public void addGradientsToGraph() { assertEquals(2, grads1.length); assertEquals(DataType.DT_FLOAT, grads1[0].dataType()); assertEquals(DataType.DT_FLOAT, grads1[1].dataType()); - - try (TFloat32 c1 = TFloat32.scalarOf(3.0f); - TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = new AutoCloseableList<>( - s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run())) { - - assertEquals(3, outputs.size()); - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - assertEquals(6.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); - assertEquals(1.0f, ((TFloat32)outputs.get(2)).getFloat(), 0.0f); - } + + TFloat32 c1 = TFloat32.scalarOf(scope, 3.0f); + TFloat32 c2 = TFloat32.scalarOf(scope, 2.0f); + List outputs = s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run(scope); + + assertEquals(3, outputs.size()); + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); + assertEquals(1.0f, ((TFloat32) outputs.get(2)).getFloat(), 0.0f); } } @Test public void addGradientSumsToGraph() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); @@ -192,27 +195,27 @@ public void addGradientSumsToGraph() { assertEquals(1, grad.length); assertEquals(DataType.DT_FLOAT, grad[0].dataType()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - TFloat32 output = (TFloat32)s.runner() - .feed(x, c) - .fetch(grad[0]) - .run() - .get(0)) { - assertEquals(114.0f, output.getFloat(), 0.0f); - } + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + TFloat32 output = (TFloat32) s.runner() + .feed(x, c) + .fetch(grad[0]) + .run(scope) + .get(0); + assertEquals(114.0f, output.getFloat(), 0.0f); } } @Test public void addGradientsWithInitialValuesToGraph() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); - + Output[] grad0 = g.addGradients(y1, toArray(y0)); assertNotNull(grad0); assertEquals(1, grad0.length); @@ -223,14 +226,13 @@ public void addGradientsWithInitialValuesToGraph() { assertEquals(1, grad1.length); assertEquals(DataType.DT_FLOAT, grad1[0].dataType()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - TFloat32 output = (TFloat32)s.runner() - .feed(x, c) - .fetch(grad1[0]) - .run() - .get(0)) { - assertEquals(108.0f, output.getFloat(), 0.0f); - } + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + TFloat32 output = (TFloat32) s.runner() + .feed(x, c) + .fetch(grad1[0]) + .run(scope) + .get(0); + assertEquals(108.0f, output.getFloat(), 0.0f); } } @@ -265,7 +267,8 @@ public void validateGradientsNames() { @Test public void buildWhileLoopSingleInput() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output input = tf.placeholder(TInt32.class).output(); @@ -275,29 +278,29 @@ public void buildWhileLoopSingleInput() { toArray(input), (condGraph, condInputs, condOutputs) -> { Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); + condOutputs[0] = tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); }, (bodyGraph, bodyInputs, bodyOutputs) -> { Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); }, "test_loop"); - try (TInt32 c = TInt32.scalarOf(2); - TInt32 output = (TInt32)s.runner() - .feed(input, c) - .fetch(loopOutputs[0]) - .run() - .get(0)) { - assertEquals(16, output.getInt()); // ((2^2)^2) - } + TInt32 c = TInt32.scalarOf(scope, 2); + TInt32 output = (TInt32) s.runner() + .feed(input, c) + .fetch(loopOutputs[0]) + .run(scope) + .get(0); + assertEquals(16, output.getInt()); // ((2^2)^2) } } @Test public void buildWhileLoopMultipleInputs() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output input1 = tf.placeholder(TInt32.class).output(); @@ -309,29 +312,26 @@ public void buildWhileLoopMultipleInputs() { inputs, (condGraph, condInputs, condOutputs) -> { Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); + condOutputs[0] = tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); }, (bodyGraph, bodyInputs, bodyOutputs) -> { Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); - bodyOutputs[1] = tfb.math.square((Output)bodyInputs[1]).y(); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); + bodyOutputs[1] = tfb.math.square((Output) bodyInputs[1]).y(); }, "test_loop"); - try (TInt32 c1 = TInt32.scalarOf(2); - TInt32 c2 = TInt32.scalarOf(5); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() - .feed(input1, c1) - .feed(input2, c2) - .fetch(loopOutputs[0]) - .fetch(loopOutputs[1]) - .run())) { - assertEquals(2, outputs.size()); - assertEquals(16, ((TInt32)outputs.get(0)).getInt()); // ((2^2)^2) - assertEquals(625, ((TInt32)outputs.get(1)).getInt()); // ((5^2)^2) - } + TInt32 c1 = TInt32.scalarOf(scope, 2); + TInt32 c2 = TInt32.scalarOf(scope, 5); + List outputs = s.runner() + .feed(input1, c1) + .feed(input2, c2) + .fetch(loopOutputs[0]) + .fetch(loopOutputs[1]) + .run(scope); + assertEquals(2, outputs.size()); + assertEquals(16, ((TInt32) outputs.get(0)).getInt()); // ((2^2)^2) + assertEquals(625, ((TInt32) outputs.get(1)).getInt()); // ((5^2)^2) } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java index 0d2d8af8b1c..7c365517aa3 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java @@ -30,61 +30,66 @@ public class RawTensorTest { @Test public void rawToTypedTensor() { - RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), -1); - TFloat32 floatTensor = (TFloat32)rawTensor.asTypedTensor(); - assertSame(floatTensor.asRawTensor(), rawTensor); - try { - TInt32 intTensor = (TInt32)rawTensor.asTypedTensor(); - fail(); - } catch (ClassCastException e) { - // ok + try (TensorScope scope = new TensorScope()) { + RawTensor rawTensor = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), -1); + TFloat32 floatTensor = (TFloat32) rawTensor.asTypedTensor(); + assertSame(floatTensor.asRawTensor(), rawTensor); + try { + TInt32 intTensor = (TInt32) rawTensor.asTypedTensor(); + fail(); + } catch (ClassCastException e) { + // ok + } } } @Test public void allocateTensorWithSize() { - try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), 16)) { + try (TensorScope scope = new TensorScope()) { + RawTensor rawTensor = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), 16); assertEquals(16, rawTensor.numBytes()); - } - try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), 100)) { + rawTensor = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), 100); assertEquals(100, rawTensor.numBytes()); - } - try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), 10)) { - fail(); - } catch (IllegalArgumentException e) { - // ok - } - try (RawTensor rawTensor = RawTensor.allocate(TString.class, Shape.of(2, 2), 100)) { + try (RawTensor rawTensor2 = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), 10)) { + fail(); + } catch (IllegalArgumentException e) { + // ok + } + rawTensor = RawTensor.allocate(scope, TString.class, Shape.of(2, 2), 100); assertEquals(100, rawTensor.numBytes()); } } @Test public void allocateTensorWithoutSize() { - try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), -1)) { - assertEquals(16, rawTensor.numBytes()); - // ok - } - try (RawTensor rawTensor = RawTensor.allocate(TString.class, Shape.of(2, 2), -1)) { - fail(); - } catch (IllegalArgumentException e) { - // ok + try (TensorScope scope = new TensorScope()) { + try (RawTensor rawTensor = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), -1)) { + assertEquals(16, rawTensor.numBytes()); + // ok + } + try (RawTensor rawTensor = RawTensor.allocate(scope, TString.class, Shape.of(2, 2), -1)) { + fail(); + } catch (IllegalArgumentException e) { + // ok + } } } @Test public void failToAllocateTensorFromUnknownShape() { - try { - RawTensor.allocate(TFloat32.class, Shape.of(3, -1, 3), -1); - fail(); - } catch (IllegalArgumentException e) { - // ok - } - try { - RawTensor.allocate(TString.class, Shape.unknown(), 100); - fail(); - } catch (IllegalArgumentException e) { - // ok + try (TensorScope scope = new TensorScope()) { + try { + RawTensor.allocate(scope, TFloat32.class, Shape.of(3, -1, 3), -1); + fail(); + } catch (IllegalArgumentException e) { + // ok + } + try { + RawTensor.allocate(scope, TString.class, Shape.unknown(), 100); + fail(); + } catch (IllegalArgumentException e) { + // ok + } } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index cd8ac7e2ae4..726294547b5 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -27,8 +27,8 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Collections; -import java.util.Map; import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.ndarray.FloatNdArray; @@ -46,7 +46,9 @@ import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.types.TFloat32; -/** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ +/** + * Unit tests for {@link org.tensorflow.SavedModelBundle}. + */ public class SavedModelBundleTest { private static final float EPSILON = 1e-7f; @@ -56,7 +58,8 @@ public class SavedModelBundleTest { static { try { SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString(); - SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()).toString(); + SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()) + .toString(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -102,15 +105,15 @@ public void exportFunctionWithVariables() throws IOException { float reducedSum; FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); Shape xyShape = Shape.of(2, 3L); - try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { + try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape)); + TensorScope scope = new TensorScope()) { // Init variable state by running the Init operation directly f.session().run(Init.DEFAULT_NAME); // Call the graph and remember the result of computation for later - try (TFloat32 xTensor = TFloat32.tensorOf(xValue); - TFloat32 zTensor = (TFloat32)f.call(xTensor)) { - reducedSum = zTensor.getFloat(); - } + TFloat32 xTensor = TFloat32.tensorOf(scope, xValue); + TFloat32 zTensor = (TFloat32) f.call(scope, xTensor); + reducedSum = zTensor.getFloat(); // Save/export the model (which is a single function in this case) f.save(testFolder.toString()); } @@ -121,7 +124,8 @@ public void exportFunctionWithVariables() throws IOException { // Reload the model just saved and validate its data try (SavedModelBundle savedModel = - SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) { + SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG); + TensorScope scope = new TensorScope()) { assertNotNull(savedModel.metaGraphDef()); assertNotNull(savedModel.metaGraphDef().getSaverDef()); assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); @@ -153,17 +157,14 @@ public void exportFunctionWithVariables() throws IOException { assertNotNull(outputInfo); assertEquals(0, outputInfo.getTensorShape().getDimCount()); - try (TFloat32 xTensor = TFloat32.tensorOf(xValue)) { - // Call the saved model function and make sure it returns the same result as before - try (TFloat32 zTensor = (TFloat32)function.call(xTensor)) { - assertEquals(reducedSum, zTensor.getFloat(), EPSILON); - } - // Now call the same function directly from the model - try (TFloat32 zTensor = - (TFloat32)savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { - assertEquals(reducedSum, zTensor.getFloat(), EPSILON); - } - } + TFloat32 xTensor = TFloat32.tensorOf(scope, xValue); + // Call the saved model function and make sure it returns the same result as before + TFloat32 zTensor = (TFloat32) function.call(scope, xTensor); + assertEquals(reducedSum, zTensor.getFloat(), EPSILON); + // Now call the same function directly from the model + TFloat32 zTensor2 = + (TFloat32) savedModel.call(scope, Collections.singletonMap("input", xTensor)).get("reducedSum"); + assertEquals(reducedSum, zTensor2.getFloat(), EPSILON); } } @@ -177,32 +178,32 @@ public void exportMultipleFunctions() throws IOException { Signature f2Signature = buildIdentityGraph(tf, "identity"); try (Session s = new Session(g); ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); - ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s); + TensorScope scope = new TensorScope()) { f1.session().run(Init.DEFAULT_NAME); - try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - TFloat32 t = (TFloat32)f1.call(x)) { - reducedSum = t.getFloat(); - } + TFloat32 x = TFloat32.tensorOf(scope, StdArrays.ndCopyOf(new float[]{2, 2})); + TFloat32 t = (TFloat32) f1.call(scope, x); + reducedSum = t.getFloat(); SavedModelBundle.exporter(testFolder.toString()) .withFunction(f1) .withFunction(f2) .export(); } } - try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { + try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString()); + TensorScope scope = new TensorScope()) { assertEquals(2, model.signatures().size()); ConcreteFunction f1 = model.function(Signature.DEFAULT_KEY); assertNotNull(f1); - try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - TFloat32 t = (TFloat32)f1.call(x)) { + try (TFloat32 x = TFloat32.tensorOf(scope, StdArrays.ndCopyOf(new float[]{2, 2})); + TFloat32 t = (TFloat32) f1.call(scope, x)) { assertEquals(reducedSum, t.getFloat(), EPSILON); } ConcreteFunction f2 = model.function("identity"); assertNotNull(f2); - try (TFloat32 x = TFloat32.scalarOf(10.0f); - TFloat32 t = (TFloat32)f2.call(x)) { - assertEquals(10.0f, t.getFloat(), 0.0f); - } + TFloat32 x = TFloat32.scalarOf(scope, 10.0f); + TFloat32 t = (TFloat32) f2.call(scope, x); + assertEquals(10.0f, t.getFloat(), 0.0f); try { model.function("NoSuchFunction"); fail(); @@ -284,23 +285,22 @@ public void cannotExportOrImportInvalidTags() { @Test public void pythonTfFunction() { // ConcreteFunctions on models saved using python - try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PY_PATH, "serve")) { + try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PY_PATH, "serve"); + TensorScope scope = new TensorScope()) { /* * Test model was created in python * Signature name used for saving 'add', argument names 'a' and 'b' */ ConcreteFunction add = bundle.function("add"); Map args = new HashMap(); - try (TFloat32 a = TFloat32.scalarOf(10.0f); - TFloat32 b = TFloat32.scalarOf(15.5f)) { - args.put("a", a); - args.put("b", b); - Map result = add.call(args); - assertEquals(result.size(), 1); - try (TFloat32 c = (TFloat32)result.values().iterator().next()) { - assertEquals(25.5f, c.getFloat()); - } - } + TFloat32 a = TFloat32.scalarOf(scope, 10.0f); + TFloat32 b = TFloat32.scalarOf(scope, 15.5f); + args.put("a", a); + args.put("b", b); + Map result = add.call(scope, args); + assertEquals(result.size(), 1); + TFloat32 c = (TFloat32) result.values().iterator().next(); + assertEquals(25.5f, c.getFloat()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index d7ea381d315..ca922cec6b8 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -26,9 +26,13 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Comparator; - +import java.util.List; import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Split; @@ -38,120 +42,114 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.GraphDef; import org.tensorflow.proto.framework.RunOptions; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.StdArrays; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.Session}. */ +/** + * Unit tests for {@link org.tensorflow.Session}. + */ public class SessionTest { @Test public void runUsingOperationNames() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { - assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); - } + transpose_A_times_X(tf, new int[][]{{2}, {3}}); + TInt32 x = TInt32.tensorOf(scope, StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + List outputs = s.runner().feed("X", x).fetch("Y").run(scope); + assertEquals(1, outputs.size()); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } } @Test public void runUsingOperationHandles() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); + transpose_A_times_X(tf, new int[][]{{2}, {3}}); Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { - assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); - } + TInt32 x = TInt32.tensorOf(scope, StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + List outputs = s.runner().feed(feed, x).fetch(fetch).run(scope); + assertEquals(1, outputs.size()); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } } @Test public void runUsingColonSeparatedNames() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Split split = tf.split(tf.constant(0), tf.array(1, 2, 3, 4), 2L); tf.math.add(split.output().get(0), split.output().get(1)); // Fetch using colon separated names. - try (TInt32 fetched = (TInt32)s.runner().fetch("Split:1").run().get(0)) { - assertEquals(3, fetched.getInt(0)); - assertEquals(4, fetched.getInt(1)); - } + TInt32 fetched = (TInt32) s.runner().fetch("Split:1").run(scope).get(0); + assertEquals(3, fetched.getInt(0)); + assertEquals(4, fetched.getInt(1)); // Feed using colon separated names. - try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1); - TInt32 fetched = (TInt32) s.runner() - .feed("Split:0", fed) - .feed("Split:1", fed) - .fetch("Add") - .run() - .get(0)) { - assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched); - } + TInt32 fed = TInt32.vectorOf(scope, 4, 3, 2, 1); + TInt32 fetched2 = (TInt32) s.runner() + .feed("Split:0", fed) + .feed("Split:1", fed) + .fetch("Add") + .run(scope) + .get(0); + assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched2); } } @Test public void runWithMetadata() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { - Session.Run result = s.runner() - .feed("X", x) - .fetch("Y") - .setOptions(fullTraceRunOptions()) - .runAndFetchMetadata(); - // Sanity check on outputs. - AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); - assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); - // Sanity check on metadata - assertNotNull(result.metadata); - assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); - outputs.close(); - } + transpose_A_times_X(tf, new int[][]{{2}, {3}}); + TInt32 x = TInt32.tensorOf(scope, StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + Session.Run result = s.runner() + .feed("X", x) + .fetch("Y") + .setOptions(fullTraceRunOptions()) + .runAndFetchMetadata(scope); + // Sanity check on outputs. + assertEquals(1, result.outputs.size()); + assertEquals(31, ((TInt32) result.outputs.get(0)).getInt(0, 0)); + // Sanity check on metadata + assertNotNull(result.metadata); + assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); } } @Test public void runMultipleOutputs() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); tf.withName("c1").constant(2718); tf.withName("c2").constant(31415); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); + List outputs = s.runner().fetch("c2").fetch("c1").run(scope); assertEquals(2, outputs.size()); - assertEquals(31415, ((TInt32)outputs.get(0)).getInt()); - assertEquals(2718, ((TInt32)outputs.get(1)).getInt()); - outputs.close(); + assertEquals(31415, ((TInt32) outputs.get(0)).getInt()); + assertEquals(2718, ((TInt32) outputs.get(1)).getInt()); } } @Test public void failOnUseAfterClose() { - try (Graph g = new Graph()) { + try (Graph g = new Graph(); + TensorScope scope = new TensorScope()) { Session s = new Session(g); s.close(); try { - s.runner().run(); + s.runner().run(scope); fail("methods on a session should fail after close() is called"); } catch (IllegalStateException e) { // expected exception @@ -162,7 +160,8 @@ public void failOnUseAfterClose() { @Test public void createWithConfigProto() { try (Graph g = new Graph(); - Session s = new Session(g, singleThreadConfigProto())) {} + Session s = new Session(g, singleThreadConfigProto())) { + } } @Test @@ -175,12 +174,12 @@ public void runInit() { Variable var2 = tf.variable(tf.constant(20)); Add add = tf.math.add(var1, var2); - try (Session s = new Session(g)) { + try (Session s = new Session(g); + TensorScope scope = new TensorScope()) { s.run(tf.init()); - try (TInt32 t = (TInt32) s.runner().fetch(add).run().get(0)) { - assertEquals(30, t.getInt()); - } + TInt32 t = (TInt32) s.runner().fetch(add).run(scope).get(0); + assertEquals(30, t.getInt()); } } } @@ -196,12 +195,12 @@ public void runInitByName() { Add add = tf.math.add(var1, var2); tf.withName("init_test").init(); - try (Session s = new Session(g)) { + try (Session s = new Session(g); + TensorScope scope = new TensorScope()) { s.run("init_test"); - try (TInt32 t = (TInt32) s.runner().fetch(add).run().get(0)) { - assertEquals(30, t.getInt()); - } + TInt32 t = (TInt32) s.runner().fetch(add).run(scope).get(0); + assertEquals(30, t.getInt()); try { s.run("wrong_name"); fail(); @@ -213,9 +212,10 @@ public void runInitByName() { } @Test - public void saveAndRestore() throws IOException { + public void saveAndRestore() throws IOException { Path testFolder = Files.createTempDirectory("tf-session-save-restore-test"); - try (Graph g = new Graph()) { + try (Graph g = new Graph(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Variable x = tf.withName("x").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); Variable y = tf.withName("y").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); @@ -230,11 +230,10 @@ public void saveAndRestore() throws IOException { restoredGraph.importGraphDef(graphDef); try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); - try (AutoCloseableList oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); - AutoCloseableList newList = new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())){ - assertEquals(oldList.get(0),newList.get(0)); - assertEquals(oldList.get(1),newList.get(1)); - } + List oldList = s.runner().fetch("x").fetch("y").run(scope); + List newList = restoredSession.runner().fetch("x").fetch("y").run(scope); + assertEquals(oldList.get(0),newList.get(0)); + assertEquals(oldList.get(1),newList.get(1)); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java index a9c6d3774fa..8c285392d7b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -19,9 +19,6 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.ndarray.Shape; import org.tensorflow.types.TFloat32; @@ -31,8 +28,8 @@ */ public class TensorScopeTest { - private static TFloat32 makeTensor(long size) { - return TFloat32.tensorOf(Shape.of(size), x -> { + private static TFloat32 makeTensor(TensorScope scope, long size) { + return TFloat32.tensorOf(scope, Shape.of(size), x -> { for (long i = 0; i < size; i++) { x.setFloat(0, i); } @@ -43,8 +40,8 @@ private static TFloat32 makeTensor(long size) { public void testBasicScope() { TensorScope scope = new TensorScope(); - TFloat32 tensor = makeTensor(10); - TFloat32 detachTensor = makeTensor(10); + TFloat32 tensor = makeTensor(scope, 10); + TFloat32 detachTensor = makeTensor(scope, 10); detachTensor.detach(); assertTrue(tensor.isAttached()); @@ -61,34 +58,10 @@ public void testBasicScope() { detachTensor.close(); } - @Test - public void testNestedScope() { - TensorScope outerScope = new TensorScope(); - TensorScope scope = new TensorScope(); - - TFloat32 tensor = makeTensor(10); - TFloat32 detachTensor = makeTensor(10); - detachTensor.detach(); - - assertTrue(tensor.isAttached()); - assertFalse(tensor.isClosed()); - - assertFalse(detachTensor.isAttached()); - assertFalse(detachTensor.isClosed()); - - outerScope.close(); - - assertTrue(tensor.isClosed()); - assertTrue(scope.isClosed()); - assertTrue(outerScope.isClosed()); - assertFalse(detachTensor.isClosed()); - detachTensor.close(); - } - @Test public void testAttach() { TensorScope firstScope = new TensorScope(); - TFloat32 tensor = makeTensor(10); + TFloat32 tensor = makeTensor(firstScope, 10); TensorScope secondScope = new TensorScope().withTensors(tensor); assertTrue(tensor.isAttached()); @@ -100,127 +73,4 @@ public void testAttach() { firstScope.close(); } - @Test - public void testReleaseToParentScope() { - TensorScope outerScope = new TensorScope(); - TensorScope scope = new TensorScope(); - - TFloat32 tensor = makeTensor(10); - - assertTrue(tensor.isAttached()); - assertFalse(tensor.isClosed()); - - scope.releaseAllToParent(); - - assertTrue(scope.isClosed()); - assertTrue(tensor.isAttached()); - assertFalse(tensor.isClosed()); - - outerScope.close(); - - assertTrue(tensor.isClosed()); - assertTrue(outerScope.isClosed()); - } - - @Test - public void testAttachToParentScope() { - TensorScope outerScope = new TensorScope(); - TensorScope scope = new TensorScope(); - - TFloat32 tensor = makeTensor(10); - - assertTrue(tensor.isAttached()); - assertFalse(tensor.isClosed()); - - scope.release(tensor); - - scope.close(); - - assertTrue(scope.isClosed()); - assertTrue(tensor.isAttached()); - assertFalse(tensor.isClosed()); - - outerScope.close(); - - assertTrue(tensor.isClosed()); - assertTrue(outerScope.isClosed()); - } - - @Test - public void testWithCleanup() { - final Tensor[] tensor = new Tensor[1]; - TensorScope.withCleanup(() -> { - tensor[0] = makeTensor(2); - }); - assertTrue(tensor[0].isClosed()); - } - - @Test - public void testGetWithCleanup() { - Tensor tensor = TensorScope.getWithCleanup(() -> makeTensor(2)); - assertTrue(tensor.isClosed()); - } - - @Test - public void testProduceTensorWithCleanup() { - final Tensor[] closedTensor = new Tensor[1]; - Tensor openTensor = TensorScope.produceTensorWithCleanup(() -> { - closedTensor[0] = makeTensor(2); - return makeTensor(3); - }); - - assertTrue(closedTensor[0].isClosed()); - assertFalse(openTensor.isClosed()); - openTensor.close(); - } - - private static class TestTensorContainer implements TensorContainer { - - private final List tensors; - - TestTensorContainer(List tensors) { - this.tensors = tensors; - } - - @SafeVarargs - TestTensorContainer(T... tensors) { - this(Arrays.asList(tensors)); - } - - @Override - public Iterable tensors() { - return tensors; - } - - public List getTensors() { - return tensors; - } - } - - @Test - public void testProduceTensorContainerWithCleanup() { - final TestTensorContainer[] closedTensor = new TestTensorContainer[1]; - TestTensorContainer openTensor = TensorScope.produceTensorContainerWithCleanup(() -> { - closedTensor[0] = new TestTensorContainer<>(makeTensor(2)); - return new TestTensorContainer<>(makeTensor(3)); - }); - - assertTrue(closedTensor[0].getTensors().get(0).isClosed()); - assertFalse(openTensor.getTensors().get(0).isClosed()); - openTensor.getTensors().get(0).close(); - } - - @Test - public void testProduceTensorsWithCleanup(){ - final List[] closedTensor = new List[1]; - List openTensor = TensorScope.produceTensorsWithCleanup(() -> { - closedTensor[0] = Collections.singletonList(makeTensor(2)); - return Collections.singletonList(makeTensor(2)); - }); - - assertTrue(closedTensor[0].get(0).isClosed()); - assertFalse(openTensor.get(0).isClosed()); - openTensor.get(0).close(); - } - } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java index 6b9cb202b97..c6f26afaeda 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java @@ -40,7 +40,6 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.DataBuffers; -import org.tensorflow.op.Ops; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; @@ -50,69 +49,73 @@ import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; -/** Unit tests for {@link org.tensorflow.Tensor}. */ +/** + * Unit tests for {@link org.tensorflow.Tensor}. + */ public class TensorTest { + private static final double EPSILON = 1e-7; private static final float EPSILON_F = 1e-7f; @Test public void createWithRawData() { - double[] doubles = {1d, 2d, 3d, 4d}; - Shape doubles_shape = Shape.of(4); - boolean[] bools = {true, false, true, false}; - byte[] bools_ = {1, 0, 1, 0}; - Shape bools_shape = Shape.of(4); - String strings = "test"; - Shape strings_shape = Shape.scalar(); - byte[] strings_; // raw TF_STRING - try (TString t = TString.tensorOf(NdArrays.scalarOfObject(strings))) { - strings_ = new byte[(int)t.numBytes()]; + try (TensorScope scope = new TensorScope()) { + double[] doubles = {1d, 2d, 3d, 4d}; + Shape doubles_shape = Shape.of(4); + boolean[] bools = {true, false, true, false}; + byte[] bools_ = {1, 0, 1, 0}; + Shape bools_shape = Shape.of(4); + String strings = "test"; + Shape strings_shape = Shape.scalar(); + byte[] strings_; // raw TF_STRING + TString t = TString.tensorOf(scope, NdArrays.scalarOfObject(strings)); + strings_ = new byte[(int) t.numBytes()]; t.asRawTensor().data().read(strings_); - } - // validate creating a tensor using a raw data byte buffers - { - try (TBool t = Tensor.of(TBool.class, bools_shape, DataBuffers.of(bools_))) { + // validate creating a tensor using a raw data byte buffers + { + TBool t1 = Tensor.of(scope, TBool.class, bools_shape, DataBuffers.of(bools_)); boolean[] actual = new boolean[bools_.length]; - t.read(DataBuffers.of(actual)); + t1.read(DataBuffers.of(actual)); assertArrayEquals(bools, actual); - } - // note: the buffer is expected to contain raw TF_STRING (as per C API) - try (TString t = Tensor.of(TString.class, strings_shape, DataBuffers.of(strings_))) { - assertEquals(strings, t.getObject()); + // note: the buffer is expected to contain raw TF_STRING (as per C API) + TString t2 = Tensor.of(scope, TString.class, strings_shape, DataBuffers.of(strings_)); + assertEquals(strings, t2.getObject()); } - } - // validate creating a tensor using a direct byte buffer (in host order) - { - DoubleBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder()) - .asDoubleBuffer().put(doubles); - try (TFloat64 t = TFloat64.tensorOf(doubles_shape, d -> d.write(DataBuffers.of(buf)))) { + // validate creating a tensor using a direct byte buffer (in host order) + { + DoubleBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder()) + .asDoubleBuffer().put(doubles); + TFloat64 t1 = TFloat64.tensorOf(scope, doubles_shape, d -> d.write(DataBuffers.of(buf))); double[] actual = new double[doubles.length]; - t.read(DataBuffers.of(actual)); + t1.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); } - } - // validate shape checking - try (TBool t = Tensor.of(TBool.class, Shape.of(bools_.length * 2), DataBuffers.of(bools_))) { - fail("should have failed on incompatible buffer"); - } catch (IllegalArgumentException e) { - // expected + // validate shape checking + try { + TBool t1 = Tensor.of(scope, TBool.class, Shape.of(bools_.length * 2), DataBuffers.of(bools_)); + fail("should have failed on incompatible buffer"); + } catch (IllegalArgumentException e) { + // expected + } } + } @Test public void createFromBufferWithNativeByteOrder() { - double[] doubles = {1d, 2d, 3d, 4d}; - DoubleBuffer buf = - ByteBuffer.allocate(8 * doubles.length) - .order(ByteOrder.nativeOrder()) - .asDoubleBuffer() - .put(doubles); - flipBuffer(buf); - try (TFloat64 t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { + try (TensorScope scope = new TensorScope()) { + double[] doubles = {1d, 2d, 3d, 4d}; + DoubleBuffer buf = + ByteBuffer.allocate(8 * doubles.length) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer() + .put(doubles); + flipBuffer(buf); + TFloat64 t = TFloat64.tensorOf(scope, Shape.of(4), DataBuffers.of(buf)); double[] actual = new double[doubles.length]; t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); @@ -121,17 +124,18 @@ public void createFromBufferWithNativeByteOrder() { @Test public void createFromBufferWithNonNativeByteOrder() { - double[] doubles = {1d, 2d, 3d, 4d}; - DoubleBuffer buf = - ByteBuffer.allocate(8 * doubles.length) - .order( - ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN - ? ByteOrder.BIG_ENDIAN - : ByteOrder.LITTLE_ENDIAN) - .asDoubleBuffer() - .put(doubles); - flipBuffer(buf); - try (TFloat64 t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { + try (TensorScope scope = new TensorScope()) { + double[] doubles = {1d, 2d, 3d, 4d}; + DoubleBuffer buf = + ByteBuffer.allocate(8 * doubles.length) + .order( + ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN + ? ByteOrder.BIG_ENDIAN + : ByteOrder.LITTLE_ENDIAN) + .asDoubleBuffer() + .put(doubles); + flipBuffer(buf); + TFloat64 t = TFloat64.tensorOf(scope, Shape.of(4), DataBuffers.of(buf)); double[] actual = new double[doubles.length]; t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); @@ -140,75 +144,78 @@ public void createFromBufferWithNonNativeByteOrder() { @Test public void createWithTypedBuffer() { - IntBuffer ints = IntBuffer.wrap(new int[]{1, 2, 3, 4}); - FloatBuffer floats = FloatBuffer.wrap(new float[]{1f, 2f, 3f, 4f}); - DoubleBuffer doubles = DoubleBuffer.wrap(new double[]{1d, 2d, 3d, 4d}); - LongBuffer longs = LongBuffer.wrap(new long[]{1L, 2L, 3L, 4L}); - - // validate creating a tensor using a typed buffer - { - Shape shape = Shape.of(4); - try (TFloat64 t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { - DoubleBuffer actual = DoubleBuffer.allocate(doubles.capacity()); - t.read(DataBuffers.of(actual)); - assertEquals(doubles, actual); - } - try (TFloat32 t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { - FloatBuffer actual = FloatBuffer.allocate(floats.capacity()); - t.read(DataBuffers.of(actual)); - assertEquals(floats, actual); - } - try (TInt32 t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { - IntBuffer actual = IntBuffer.allocate(ints.capacity()); - t.read(DataBuffers.of(actual)); - assertEquals(ints, actual); - } - try (TInt64 t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { - LongBuffer actual = LongBuffer.allocate(longs.capacity()); - t.read(DataBuffers.of(actual)); - assertEquals(longs, actual); - } - } + try (TensorScope scope = new TensorScope()) { + IntBuffer ints = IntBuffer.wrap(new int[]{1, 2, 3, 4}); + FloatBuffer floats = FloatBuffer.wrap(new float[]{1f, 2f, 3f, 4f}); + DoubleBuffer doubles = DoubleBuffer.wrap(new double[]{1d, 2d, 3d, 4d}); + LongBuffer longs = LongBuffer.wrap(new long[]{1L, 2L, 3L, 4L}); + + // validate creating a tensor using a typed buffer + { + Shape shape = Shape.of(4); + TFloat64 tDouble = TFloat64.tensorOf(scope, shape, DataBuffers.of(doubles)); + DoubleBuffer doubleActual = DoubleBuffer.allocate(doubles.capacity()); + tDouble.read(DataBuffers.of(doubleActual)); + assertEquals(doubles, doubleActual); + + TFloat32 tFloat = TFloat32.tensorOf(scope, shape, DataBuffers.of(floats)); + FloatBuffer floatActual = FloatBuffer.allocate(floats.capacity()); + tFloat.read(DataBuffers.of(floatActual)); + assertEquals(floats, floatActual); + + TInt32 tInt = TInt32.tensorOf(scope, shape, DataBuffers.of(ints)); + IntBuffer intActual = IntBuffer.allocate(ints.capacity()); + tInt.read(DataBuffers.of(intActual)); + assertEquals(ints, intActual); + + TInt64 tLong = TInt64.tensorOf(scope, shape, DataBuffers.of(longs)); + LongBuffer longActual = LongBuffer.allocate(longs.capacity()); + tLong.read(DataBuffers.of(longActual)); + assertEquals(longs, longActual); - // validate shape-checking - { - Shape shape = Shape.of(5); - try (TFloat64 t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { - fail("should have failed on incompatible buffer"); - } catch (BufferUnderflowException e) { - // expected - } - try (TFloat32 t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { - fail("should have failed on incompatible buffer"); - } catch (BufferUnderflowException e) { - // expected - } - try (TInt32 t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { - fail("should have failed on incompatible buffer"); - } catch (BufferUnderflowException e) { - // expected } - try (TInt64 t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { - fail("should have failed on incompatible buffer"); - } catch (BufferUnderflowException e) { - // expected + + // validate shape-checking + { + Shape shape = Shape.of(5); + try (TFloat64 t = TFloat64.tensorOf(scope, shape, DataBuffers.of(doubles))) { + fail("should have failed on incompatible buffer"); + } catch (BufferUnderflowException e) { + // expected + } + try (TFloat32 t = TFloat32.tensorOf(scope, shape, DataBuffers.of(floats))) { + fail("should have failed on incompatible buffer"); + } catch (BufferUnderflowException e) { + // expected + } + try (TInt32 t = TInt32.tensorOf(scope, shape, DataBuffers.of(ints))) { + fail("should have failed on incompatible buffer"); + } catch (BufferUnderflowException e) { + // expected + } + try (TInt64 t = TInt64.tensorOf(scope, shape, DataBuffers.of(longs))) { + fail("should have failed on incompatible buffer"); + } catch (BufferUnderflowException e) { + // expected + } } } } @Test public void readFromRawData() { - int[] ints = {1, 2, 3}; - float[] floats = {1f, 2f, 3f}; - double[] doubles = {1d, 2d, 3d}; - long[] longs = {1L, 2L, 3L}; - boolean[] bools = {true, false, true}; - - try (TInt32 tints = TInt32.vectorOf(ints); - TFloat32 tfloats = TFloat32.vectorOf(floats); - TFloat64 tdoubles = TFloat64.vectorOf(doubles); - TInt64 tlongs = TInt64.vectorOf(longs); - TBool tbools = TBool.vectorOf(bools)) { + try (TensorScope scope = new TensorScope()) { + int[] ints = {1, 2, 3}; + float[] floats = {1f, 2f, 3f}; + double[] doubles = {1d, 2d, 3d}; + long[] longs = {1L, 2L, 3L}; + boolean[] bools = {true, false, true}; + + TInt32 tints = TInt32.vectorOf(scope, ints); + TFloat32 tfloats = TFloat32.vectorOf(scope, floats); + TFloat64 tdoubles = TFloat64.vectorOf(scope, doubles); + TInt64 tlongs = TInt64.vectorOf(scope, longs); + TBool tbools = TBool.vectorOf(scope, bools); // validate that any datatype is readable with ByteBuffer (content, position) { @@ -243,7 +250,7 @@ public void readFromRawData() { // validate the use of direct buffers { ByteBuffer bbuf = - ByteBuffer.allocateDirect((int)tdoubles.numBytes()).order(ByteOrder.nativeOrder()); + ByteBuffer.allocateDirect((int) tdoubles.numBytes()).order(ByteOrder.nativeOrder()); tdoubles.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tdoubles.numBytes()); assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON); } @@ -251,7 +258,7 @@ public void readFromRawData() { // validate byte order conversion { DoubleBuffer foreignBuf = - ByteBuffer.allocate((int)tdoubles.numBytes()) + ByteBuffer.allocate((int) tdoubles.numBytes()) .order( ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? ByteOrder.BIG_ENDIAN @@ -267,142 +274,136 @@ public void readFromRawData() { @Test public void scalars() { - try (TFloat32 t = TFloat32.scalarOf(2.718f)) { - assertEquals(TFloat32.class, t.type()); - assertEquals(DataType.DT_FLOAT, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals(2.718f, t.getFloat(), EPSILON_F); - } - - try (TFloat64 t = TFloat64.scalarOf(3.1415)) { - assertEquals(TFloat64.class, t.type()); - assertEquals(DataType.DT_DOUBLE, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals(3.1415, t.getDouble(), EPSILON); - } - - try (TInt32 t = TInt32.scalarOf(-33)) { - assertEquals(TInt32.class, t.type()); - assertEquals(DataType.DT_INT32, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals(-33, t.getInt()); - } - - try (TInt64 t = TInt64.scalarOf(8589934592L)) { - assertEquals(TInt64.class, t.type()); - assertEquals(DataType.DT_INT64, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals(8589934592L, t.getLong()); - } - - try (TBool t = TBool.scalarOf(true)) { - assertEquals(TBool.class, t.type()); - assertEquals(DataType.DT_BOOL, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertTrue(t.getBoolean()); - } - - try (TString t = TString.scalarOf("sombrero")) { - assertEquals(TString.class, t.type()); - assertEquals(DataType.DT_STRING, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals("sombrero", t.getObject()); - } - - final byte[] bytes = {1, 2, 3, 4}; - try (TString t = TString.tensorOfBytes(NdArrays.scalarOfObject(bytes))) { - assertEquals(TString.class, t.type()); - assertEquals(DataType.DT_STRING, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertArrayEquals(bytes, t.asBytes().getObject()); + try (TensorScope scope = new TensorScope()) { + TFloat32 tFloat = TFloat32.scalarOf(scope, 2.718f); + assertEquals(TFloat32.class, tFloat.type()); + assertEquals(DataType.DT_FLOAT, tFloat.dataType()); + assertEquals(0, tFloat.shape().numDimensions()); + assertEquals(2.718f, tFloat.getFloat(), EPSILON_F); + + TFloat64 tDouble = TFloat64.scalarOf(scope, 3.1415); + assertEquals(TFloat64.class, tDouble.type()); + assertEquals(DataType.DT_DOUBLE, tDouble.dataType()); + assertEquals(0, tDouble.shape().numDimensions()); + assertEquals(3.1415, tDouble.getDouble(), EPSILON); + + TInt32 tInt = TInt32.scalarOf(scope, -33); + assertEquals(TInt32.class, tInt.type()); + assertEquals(DataType.DT_INT32, tInt.dataType()); + assertEquals(0, tInt.shape().numDimensions()); + assertEquals(-33, tInt.getInt()); + + TInt64 tLong = TInt64.scalarOf(scope, 8589934592L); + assertEquals(TInt64.class, tLong.type()); + assertEquals(DataType.DT_INT64, tLong.dataType()); + assertEquals(0, tLong.shape().numDimensions()); + assertEquals(8589934592L, tLong.getLong()); + + TBool tBool = TBool.scalarOf(scope, true); + assertEquals(TBool.class, tBool.type()); + assertEquals(DataType.DT_BOOL, tBool.dataType()); + assertEquals(0, tBool.shape().numDimensions()); + assertTrue(tBool.getBoolean()); + + TString tString = TString.scalarOf(scope, "sombrero"); + assertEquals(TString.class, tString.type()); + assertEquals(DataType.DT_STRING, tString.dataType()); + assertEquals(0, tString.shape().numDimensions()); + assertEquals("sombrero", tString.getObject()); + + final byte[] bytes = {1, 2, 3, 4}; + TString tByteString = TString.tensorOfBytes(scope, NdArrays.scalarOfObject(bytes)); + assertEquals(TString.class, tByteString.type()); + assertEquals(DataType.DT_STRING, tByteString.dataType()); + assertEquals(0, tByteString.shape().numDimensions()); + assertArrayEquals(bytes, tByteString.asBytes().getObject()); } } @Test public void nDimensional() { - DoubleNdArray vector = StdArrays.ndCopyOf(new double[]{1.414, 2.718, 3.1415}); - try (TFloat64 t = TFloat64.tensorOf(vector)) { - assertEquals(TFloat64.class, t.type()); - assertEquals(DataType.DT_DOUBLE, t.dataType()); - assertEquals(1, t.shape().numDimensions()); - assertEquals(3, t.shape().size(0)); - assertEquals(vector, t); - } - - IntNdArray matrix = StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {4, 5, 6}}); - try (TInt32 t = TInt32.tensorOf(matrix)) { - assertEquals(TInt32.class, t.type()); - assertEquals(DataType.DT_INT32, t.dataType()); - assertEquals(2, t.shape().numDimensions()); - assertEquals(2, t.shape().size(0)); - assertEquals(3, t.shape().size(1)); - assertEquals(matrix, t); - } - - LongNdArray threeD = StdArrays.ndCopyOf(new long[][][]{ - {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, - }); - try (TInt64 t = TInt64.tensorOf(threeD)) { - assertEquals(TInt64.class, t.type()); - assertEquals(DataType.DT_INT64, t.dataType()); - assertEquals(3, t.shape().numDimensions()); - assertEquals(2, t.shape().size(0)); - assertEquals(5, t.shape().size(1)); - assertEquals(1, t.shape().size(2)); - assertEquals(threeD, t); - } - - BooleanNdArray fourD = StdArrays.ndCopyOf(new boolean[][][][]{ - {{{false, false, false, true}, {false, false, true, false}}}, - {{{false, false, true, true}, {false, true, false, false}}}, - {{{false, true, false, true}, {false, true, true, false}}}, - }); - try (TBool t = TBool.tensorOf(fourD)) { - assertEquals(TBool.class, t.type()); - assertEquals(DataType.DT_BOOL, t.dataType()); - assertEquals(4, t.shape().numDimensions()); - assertEquals(3, t.shape().size(0)); - assertEquals(1, t.shape().size(1)); - assertEquals(2, t.shape().size(2)); - assertEquals(4, t.shape().size(3)); - assertEquals(fourD, t); + try (TensorScope scope = new TensorScope()) { + DoubleNdArray vector = StdArrays.ndCopyOf(new double[]{1.414, 2.718, 3.1415}); + TFloat64 tDouble = TFloat64.tensorOf(scope, vector); + assertEquals(TFloat64.class, tDouble.type()); + assertEquals(DataType.DT_DOUBLE, tDouble.dataType()); + assertEquals(1, tDouble.shape().numDimensions()); + assertEquals(3, tDouble.shape().size(0)); + assertEquals(vector, tDouble); + + IntNdArray matrix = StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {4, 5, 6}}); + TInt32 tInt = TInt32.tensorOf(scope, matrix); + assertEquals(TInt32.class, tInt.type()); + assertEquals(DataType.DT_INT32, tInt.dataType()); + assertEquals(2, tInt.shape().numDimensions()); + assertEquals(2, tInt.shape().size(0)); + assertEquals(3, tInt.shape().size(1)); + assertEquals(matrix, tInt); + + LongNdArray threeD = StdArrays.ndCopyOf(new long[][][]{ + {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, + }); + TInt64 tLong = TInt64.tensorOf(scope, threeD); + assertEquals(TInt64.class, tLong.type()); + assertEquals(DataType.DT_INT64, tLong.dataType()); + assertEquals(3, tLong.shape().numDimensions()); + assertEquals(2, tLong.shape().size(0)); + assertEquals(5, tLong.shape().size(1)); + assertEquals(1, tLong.shape().size(2)); + assertEquals(threeD, tLong); + + BooleanNdArray fourD = StdArrays.ndCopyOf(new boolean[][][][]{ + {{{false, false, false, true}, {false, false, true, false}}}, + {{{false, false, true, true}, {false, true, false, false}}}, + {{{false, true, false, true}, {false, true, true, false}}}, + }); + TBool tBool = TBool.tensorOf(scope, fourD); + assertEquals(TBool.class, tBool.type()); + assertEquals(DataType.DT_BOOL, tBool.dataType()); + assertEquals(4, tBool.shape().numDimensions()); + assertEquals(3, tBool.shape().size(0)); + assertEquals(1, tBool.shape().size(1)); + assertEquals(2, tBool.shape().size(2)); + assertEquals(4, tBool.shape().size(3)); + assertEquals(fourD, tBool); } } @Test public void testNDimensionalStringTensor() { - NdArray matrix = NdArrays.ofObjects(String.class, Shape.of(4, 3)); - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < 3; ++j) { - matrix.setObject(String.format("(%d, %d) = %d", i, j, i << j), i, j); + try (TensorScope scope = new TensorScope()) { + NdArray matrix = NdArrays.ofObjects(String.class, Shape.of(4, 3)); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 3; ++j) { + matrix.setObject(String.format("(%d, %d) = %d", i, j, i << j), i, j); + } } - } - try (TString t = TString.tensorOf(matrix)) { - assertEquals(TString.class, t.type()); - assertEquals(DataType.DT_STRING, t.dataType()); - assertEquals(2, t.shape().numDimensions()); - assertEquals(4, t.shape().size(0)); - assertEquals(3, t.shape().size(1)); - assertEquals(matrix, t); - } - - NdArray byteMatrix = NdArrays.ofObjects(byte[].class, matrix.shape()); - matrix.scalars().forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i)); - try (TString t = TString.tensorOfBytes(byteMatrix)) { - assertEquals(TString.class, t.type()); - assertEquals(DataType.DT_STRING, t.dataType()); - assertEquals(2, t.shape().numDimensions()); - assertEquals(4, t.shape().size(0)); - assertEquals(3, t.shape().size(1)); - assertEquals(byteMatrix, t.asBytes()); - assertEquals(matrix, t); + TString tString = TString.tensorOf(scope, matrix); + assertEquals(TString.class, tString.type()); + assertEquals(DataType.DT_STRING, tString.dataType()); + assertEquals(2, tString.shape().numDimensions()); + assertEquals(4, tString.shape().size(0)); + assertEquals(3, tString.shape().size(1)); + assertEquals(matrix, tString); + + NdArray byteMatrix = NdArrays.ofObjects(byte[].class, matrix.shape()); + matrix.scalars().forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i)); + TString tByteString = TString.tensorOfBytes(scope, byteMatrix); + assertEquals(TString.class, tByteString.type()); + assertEquals(DataType.DT_STRING, tByteString.dataType()); + assertEquals(2, tByteString.shape().numDimensions()); + assertEquals(4, tByteString.shape().size(0)); + assertEquals(3, tByteString.shape().size(1)); + assertEquals(byteMatrix, tByteString.asBytes()); + assertEquals(matrix, tByteString); } } @Test public void testUint8TensorFromArray() { - byte[] vector = new byte[] {1, 2, 3, 4}; - try (TUint8 t = TUint8.vectorOf(vector)) { + try (TensorScope scope = new TensorScope()) { + byte[] vector = new byte[]{1, 2, 3, 4}; + TUint8 t = TUint8.vectorOf(scope, vector); assertEquals(TUint8.class, t.type()); assertEquals(DataType.DT_UINT8, t.dataType()); assertEquals(1, t.shape().numDimensions()); @@ -416,8 +417,9 @@ public void testUint8TensorFromArray() { @Test public void testCreateFromArrayOfBoxed() { - Integer[] vector = new Integer[] {1, 2, 3, 4}; - try (TInt32 t = TInt32.tensorOf(Shape.of(4), d -> d.write(DataBuffers.ofObjects(vector)))) { + try (TensorScope scope = new TensorScope()) { + Integer[] vector = new Integer[]{1, 2, 3, 4}; + TInt32 t = TInt32.tensorOf(scope, Shape.of(4), d -> d.write(DataBuffers.ofObjects(vector))); assertEquals(TInt32.class, t.type()); assertEquals(DataType.DT_INT32, t.dataType()); assertEquals(1, t.shape().numDimensions()); @@ -431,96 +433,109 @@ public void testCreateFromArrayOfBoxed() { @Test public void failCreateOnMismatchedDimensions() { - int[][][] invalid = new int[3][1][]; - for (int x = 0; x < invalid.length; ++x) { - for (int y = 0; y < invalid[x].length; ++y) { - invalid[x][y] = new int[x + y + 1]; + try (TensorScope scope = new TensorScope()) { + int[][][] invalid = new int[3][1][]; + for (int x = 0; x < invalid.length; ++x) { + for (int y = 0; y < invalid[x].length; ++y) { + invalid[x][y] = new int[x + y + 1]; + } + } + try (TInt32 t = TInt32.tensorOf(scope, StdArrays.ndCopyOf(invalid))) { + fail("Tensor.create() should fail because of differing sizes in the 3rd dimension"); + } catch (IllegalArgumentException e) { + // The expected exception. } - } - try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(invalid))) { - fail("Tensor.create() should fail because of differing sizes in the 3rd dimension"); - } catch (IllegalArgumentException e) { - // The expected exception. } } @Test public void tensorWithZeroDimension() { - // Note: Historically, TF Java failed on purpose when trying to allocate a tensor with a shape - // that has one or more dimensions set to 0 elements. But Python API allows it, so we should do - // the same. - try (TInt32 t = TInt32.tensorOf(Shape.of(3, 0, 1))) { - assertEquals(0, t.numBytes()); - assertEquals(0, t.shape().size()); - } - try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[3][0][1]))) { + try (TensorScope scope = new TensorScope()) { + // Note: Historically, TF Java failed on purpose when trying to allocate a tensor with a shape + // that has one or more dimensions set to 0 elements. But Python API allows it, so we should do + // the same. + TInt32 t = TInt32.tensorOf(scope, Shape.of(3, 0, 1)); assertEquals(0, t.numBytes()); assertEquals(0, t.shape().size()); + + TInt32 t2 = TInt32.tensorOf(scope, StdArrays.ndCopyOf(new int[3][0][1])); + assertEquals(0, t2.numBytes()); + assertEquals(0, t2.shape().size()); } } @Test public void allocateTensorWithSize() { - try (TInt32 t = Tensor.of(TInt32.class, Shape.of(2, 2, 2), 8 * 4)) { + try (TensorScope scope = new TensorScope()) { // ok - } - try (TInt32 t = Tensor.of(TInt32.class, Shape.of(2, 2, 2), 9 * 4)) { + Tensor.of(scope, TInt32.class, Shape.of(2, 2, 2), 8 * 4); + // ok (size requested is larger that minimum space required) - } - try { - Tensor.of(TInt32.class, Shape.of(2, 2, 2), 8 * 4 - 1); - fail(); - } catch (IllegalArgumentException e) { - // as expected + Tensor.of(scope, TInt32.class, Shape.of(2, 2, 2), 9 * 4); + + try { + Tensor.of(scope, TInt32.class, Shape.of(2, 2, 2), 8 * 4 - 1); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } } } @Test public void useAfterClose() { - int n = 4; - TInt32 t = TInt32.scalarOf(n); - t.close(); - try { - t.numBytes(); - } catch (IllegalStateException e) { - // The expected exception. + try (TensorScope scope = new TensorScope()) { + int n = 4; + TInt32 t = TInt32.scalarOf(scope, n); + t.close(); + try { + t.numBytes(); + } catch (IllegalStateException e) { + // The expected exception. + } } } @Test public void fromHandle() { - // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been - // created independently of the Java code. In practice, two Tensor instances MUST NOT have the - // same native handle. - // - // An exception is made for this test, where the pitfalls of this is avoided by not calling - // close() on both Tensors. - final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}}); - try (TFloat32 src = TFloat32.tensorOf(matrix)) { - TFloat32 cpy = (TFloat32)RawTensor.fromHandle(src.asRawTensor().nativeHandle()).asTypedTensor(); + try (TensorScope scope = new TensorScope()) { + // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been + // created independently of the Java code. In practice, two Tensor instances MUST NOT have the + // same native handle. + // + // An exception is made for this test, where the pitfalls of this is avoided by not calling + // close() on both Tensors. + final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}}); + TFloat32 src = TFloat32.tensorOf(scope, matrix); + TFloat32 cpy = (TFloat32) RawTensor.fromHandle(scope, src.asRawTensor().nativeHandle()).asTypedTensor(); assertEquals(src.type(), cpy.type()); assertEquals(src.dataType(), cpy.dataType()); assertEquals(src.shape().numDimensions(), cpy.shape().numDimensions()); assertEquals(src.shape(), cpy.shape()); assertEquals(matrix, cpy); + + // don't want to call close + TensorScope.detach(cpy); } } @Test public void gracefullyFailCreationFromNullArrayForStringTensor() { - // Motivated by: https://github.com/tensorflow/tensorflow/issues/17130 - byte[][] array = new byte[1][]; - try { - TUint8.tensorOf(StdArrays.ndCopyOf(array)); - } catch (IllegalStateException e) { - // expected. - } - byte[][][] array2 = new byte[2][2][2]; - array2[1] = null; - try { - TUint8.tensorOf(StdArrays.ndCopyOf(array)); - } catch (IllegalStateException e) { - // expected. + try (TensorScope scope = new TensorScope()) { + // Motivated by: https://github.com/tensorflow/tensorflow/issues/17130 + byte[][] array = new byte[1][]; + try { + TUint8.tensorOf(scope, StdArrays.ndCopyOf(array)); + } catch (IllegalStateException e) { + // expected. + } + byte[][][] array2 = new byte[2][2][2]; + array2[1] = null; + try { + TUint8.tensorOf(scope, StdArrays.ndCopyOf(array)); + } catch (IllegalStateException e) { + // expected. + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/benchmark/TensorBenchmark.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/benchmark/TensorBenchmark.java index 053738d2dd6..a3b23fa4abf 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/benchmark/TensorBenchmark.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/benchmark/TensorBenchmark.java @@ -12,10 +12,11 @@ import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.annotations.Warmup; import org.openjdk.jmh.runner.RunnerException; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.ndarray.buffer.IntDataBuffer; -import org.tensorflow.ndarray.StdArrays; import org.tensorflow.types.TInt32; @Fork(value = 1, jvmArgs = {"-Xms4G", "-Xmx4G"}) @@ -32,108 +33,114 @@ public static void main(String[] args) throws IOException, RunnerException { @Benchmark @Measurement(batchSize = 1000) public void initTensorByStdArrays() { - int[][][][] data = new int[][][][] { - { - { - {0, 0, 0}, {0, 0, 1}, {0, 0, 2} - }, - { - {0, 1, 0}, {0, 1, 1}, {0, 1, 2} - }, - { - {0, 2, 0}, {0, 2, 1}, {0, 2, 2} - } - }, { - { - {1, 0, 0}, {1, 0, 1}, {1, 0, 2} - }, - { - {1, 1, 0}, {1, 1, 1}, {1, 1, 2} - }, - { - {1, 2, 0}, {1, 2, 1}, {1, 2, 2} - } - }, { - { - {2, 0, 0}, {2, 0, 1}, {2, 0, 2} - }, - { - {2, 1, 0}, {2, 1, 1}, {2, 1, 2} - }, - { - {2, 2, 0}, {2, 2, 1}, {2, 2, 2} - } - } - }; - TInt32.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d)); + try (TensorScope scope = new TensorScope()) { + int[][][][] data = new int[][][][]{ + { + { + {0, 0, 0}, {0, 0, 1}, {0, 0, 2} + }, + { + {0, 1, 0}, {0, 1, 1}, {0, 1, 2} + }, + { + {0, 2, 0}, {0, 2, 1}, {0, 2, 2} + } + }, { + { + {1, 0, 0}, {1, 0, 1}, {1, 0, 2} + }, + { + {1, 1, 0}, {1, 1, 1}, {1, 1, 2} + }, + { + {1, 2, 0}, {1, 2, 1}, {1, 2, 2} + } + }, { + { + {2, 0, 0}, {2, 0, 1}, {2, 0, 2} + }, + { + {2, 1, 0}, {2, 1, 1}, {2, 1, 2} + }, + { + {2, 2, 0}, {2, 2, 1}, {2, 2, 2} + } + } + }; + TInt32.tensorOf(scope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d)); + } } @Benchmark @Measurement(batchSize = 1000) public void initTensorByVectors() { - TInt32.tensorOf(Shape.of(3, 3, 3, 3), d -> d - .set(vectorOf(0, 0, 0), 0, 0, 0) - .set(vectorOf(0, 0, 1), 0, 0, 1) - .set(vectorOf(0, 0, 2), 0, 0, 2) - .set(vectorOf(0, 1, 0), 0, 1, 0) - .set(vectorOf(0, 1, 1), 0, 1, 1) - .set(vectorOf(0, 1, 2), 0, 1, 2) - .set(vectorOf(0, 2, 0), 0, 2, 0) - .set(vectorOf(0, 2, 1), 0, 2, 1) - .set(vectorOf(0, 2, 2), 0, 2, 2) - .set(vectorOf(1, 0, 0), 1, 0, 0) - .set(vectorOf(1, 0, 1), 1, 0, 1) - .set(vectorOf(1, 0, 2), 1, 0, 2) - .set(vectorOf(1, 1, 0), 1, 1, 0) - .set(vectorOf(1, 1, 1), 1, 1, 1) - .set(vectorOf(1, 1, 2), 1, 1, 2) - .set(vectorOf(1, 2, 0), 1, 2, 0) - .set(vectorOf(1, 2, 1), 1, 2, 1) - .set(vectorOf(1, 2, 2), 1, 2, 2) - .set(vectorOf(2, 0, 0), 2, 0, 0) - .set(vectorOf(2, 0, 1), 2, 0, 1) - .set(vectorOf(2, 0, 2), 2, 0, 2) - .set(vectorOf(2, 1, 0), 2, 1, 0) - .set(vectorOf(2, 1, 1), 2, 1, 1) - .set(vectorOf(2, 1, 2), 2, 1, 2) - .set(vectorOf(2, 2, 0), 2, 2, 0) - .set(vectorOf(2, 2, 1), 2, 2, 1) - .set(vectorOf(2, 2, 2), 2, 2, 2) - ); + try (TensorScope scope = new TensorScope()) { + TInt32.tensorOf(scope, Shape.of(3, 3, 3, 3), d -> d + .set(vectorOf(0, 0, 0), 0, 0, 0) + .set(vectorOf(0, 0, 1), 0, 0, 1) + .set(vectorOf(0, 0, 2), 0, 0, 2) + .set(vectorOf(0, 1, 0), 0, 1, 0) + .set(vectorOf(0, 1, 1), 0, 1, 1) + .set(vectorOf(0, 1, 2), 0, 1, 2) + .set(vectorOf(0, 2, 0), 0, 2, 0) + .set(vectorOf(0, 2, 1), 0, 2, 1) + .set(vectorOf(0, 2, 2), 0, 2, 2) + .set(vectorOf(1, 0, 0), 1, 0, 0) + .set(vectorOf(1, 0, 1), 1, 0, 1) + .set(vectorOf(1, 0, 2), 1, 0, 2) + .set(vectorOf(1, 1, 0), 1, 1, 0) + .set(vectorOf(1, 1, 1), 1, 1, 1) + .set(vectorOf(1, 1, 2), 1, 1, 2) + .set(vectorOf(1, 2, 0), 1, 2, 0) + .set(vectorOf(1, 2, 1), 1, 2, 1) + .set(vectorOf(1, 2, 2), 1, 2, 2) + .set(vectorOf(2, 0, 0), 2, 0, 0) + .set(vectorOf(2, 0, 1), 2, 0, 1) + .set(vectorOf(2, 0, 2), 2, 0, 2) + .set(vectorOf(2, 1, 0), 2, 1, 0) + .set(vectorOf(2, 1, 1), 2, 1, 1) + .set(vectorOf(2, 1, 2), 2, 1, 2) + .set(vectorOf(2, 2, 0), 2, 2, 0) + .set(vectorOf(2, 2, 1), 2, 2, 1) + .set(vectorOf(2, 2, 2), 2, 2, 2) + ); + } } @Benchmark @Measurement(batchSize = 1000) public void initTensorByFlatArray() { - IntDataBuffer data = DataBuffers.of( - 0, 0, 0, - 0, 0, 1, - 0, 0, 2, - 0, 1, 0, - 0, 1, 1, - 0, 1, 2, - 0, 2, 0, - 0, 2, 1, - 0, 2, 2, - 1, 0, 0, - 1, 0, 1, - 1, 0, 2, - 1, 1, 0, - 1, 1, 1, - 1, 1, 2, - 1, 2, 0, - 1, 2, 1, - 1, 2, 2, - 2, 0, 0, - 2, 0, 1, - 2, 0, 2, - 2, 1, 0, - 2, 1, 1, - 2, 1, 2, - 2, 2, 0, - 2, 2, 1, - 2, 2, 2 - ); - TInt32.tensorOf(Shape.of(3, 3, 3, 3), data); + try (TensorScope scope = new TensorScope()) { + IntDataBuffer data = DataBuffers.of( + 0, 0, 0, + 0, 0, 1, + 0, 0, 2, + 0, 1, 0, + 0, 1, 1, + 0, 1, 2, + 0, 2, 0, + 0, 2, 1, + 0, 2, 2, + 1, 0, 0, + 1, 0, 1, + 1, 0, 2, + 1, 1, 0, + 1, 1, 1, + 1, 1, 2, + 1, 2, 0, + 1, 2, 1, + 1, 2, 2, + 2, 0, 0, + 2, 0, 1, + 2, 0, 2, + 2, 1, 0, + 2, 1, 1, + 2, 1, 2, + 2, 2, 0, + 2, 2, 1, + 2, 2, 2 + ); + TInt32.tensorOf(scope, Shape.of(3, 3, 3, 3), data); + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java index 62881dcee8c..874680a4015 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java @@ -23,10 +23,13 @@ import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TType; -/** Unit tests for {@link org.tensorflow.op.Scope}. */ +/** + * Unit tests for {@link org.tensorflow.op.Scope}. + */ public class ScopeTest { @Test @@ -85,9 +88,9 @@ public void validateNames() { Scope root = new Scope(g); final String[] invalid_names = { - "_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.] - null, "", "a$", // Invalid characters - "a/b", // slashes not allowed + "_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.] + null, "", "a$", // Invalid characters + "a/b", // slashes not allowed }; for (String name : invalid_names) { @@ -144,10 +147,11 @@ public void hierarchy() { @Test public void composite() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Scope s = new Scope(g); Output data = - Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output(); + Const.create(s.withName("data"), new int[]{600, 470, 170, 430, 300}).output(); // Create a composite op with a customized name Variance var1 = Variance.create(s.withName("example"), data); @@ -168,23 +172,28 @@ public void composite() { // assertNotNull(g.operation("variance/zero")); // Verify correct results as well. - TInt32 result = (TInt32)sess.runner().fetch(var1.output()).run().get(0); + TInt32 result = (TInt32) sess.runner().fetch(var1.output()).run(scope).get(0); assertEquals(21704, result.getInt()); - result = (TInt32)sess.runner().fetch(var2.output()).run().get(0); + result = (TInt32) sess.runner().fetch(var2.output()).run(scope).get(0); assertEquals(21704, result.getInt()); } } // "handwritten" sample operator classes private static final class Const { + private final Output output; static Const create(Scope s, int v) { - return create(s, TInt32.scalarOf(v)); + try (TensorScope scope = new TensorScope()) { + return create(s, TInt32.scalarOf(scope, v)); + } } static Const create(Scope s, int[] v) { - return create(s, TInt32.vectorOf(v)); + try (TensorScope scope = new TensorScope()) { + return create(s, TInt32.vectorOf(scope, v)); + } } static Const create(Scope s, T value) { @@ -207,6 +216,7 @@ Output output() { } private static final class Mean { + private final Output output; static Mean create(Scope s, Output input, Output reductionIndices) { @@ -229,6 +239,7 @@ Output output() { } private static final class SquaredDifference { + private final Output output; static SquaredDifference create(Scope s, Output x, Output y) { @@ -251,17 +262,20 @@ Output output() { } private static final class Variance { + private final Output output; static Variance create(Scope base, Output x) { - Scope s = base.withSubScope("variance"); - Output zero = Const.create(base, TInt32.scalarOf(0)).output(); - Output sqdiff = - SquaredDifference.create( - s.withName("squared_deviation"), x, Mean.create(s, x, zero).output()) - .output(); - - return new Variance<>(Mean.create(s.withName("variance"), sqdiff, zero).output()); + try (TensorScope scope = new TensorScope()) { + Scope s = base.withSubScope("variance"); + Output zero = Const.create(base, TInt32.scalarOf(scope, 0)).output(); + Output sqdiff = + SquaredDifference.create( + s.withName("squared_deviation"), x, Mean.create(s, x, zero).output()) + .output(); + + return new Variance<>(Mean.create(s.withName("variance"), sqdiff, zero).output()); + } } Variance(Output o) { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 6df73261867..18e6e900ac1 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -18,13 +18,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.io.IOException; +import java.util.List; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.EagerSession; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.IntNdArray; @@ -38,8 +39,12 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; -import org.tensorflow.op.Ops; -import org.tensorflow.op.Scope; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.IntNdArray; +import org.tensorflow.ndarray.LongNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TFloat16; import org.tensorflow.types.TFloat32; @@ -61,15 +66,14 @@ public void createInts() { IntNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @@ -80,15 +84,14 @@ public void createFloats() { FloatNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @@ -99,15 +102,14 @@ public void createDoubles() { DoubleNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @@ -118,15 +120,14 @@ public void createLongs() { LongNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @@ -137,36 +138,36 @@ public void createStrings() throws IOException { NdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @Test public void createFromTensorsInEagerMode() throws IOException { try (EagerSession s = EagerSession.create(); - TInt32 t = TInt32.vectorOf(1, 2, 3, 4)) { + TensorScope tensorScope = new TensorScope(); + TInt32 t = TInt32.vectorOf(tensorScope, 1, 2, 3, 4)) { Ops tf = Ops.create(s); Constant c1 = tf.constant(t); - assertEquals(c1.asTensor(), t); + assertEquals(c1.asTensor(tensorScope), t); // A different endpoint for capturing a tensor as a constant, which supports all data types Constant c2 = tf.constantOf(t); - assertEquals(c2.asTensor(), t); - assertEquals(c1.asTensor(), c2.asTensor()); + assertEquals(c2.asTensor(tensorScope), t); + assertEquals(c1.asTensor(tensorScope), c2.asTensor(tensorScope)); // Permute data in the tensor to make sure that constant copies are independent t.setInt(10); assertEquals(NdArrays.vectorOf(10, 2, 3, 4), t); - assertEquals(NdArrays.vectorOf(1, 2, 3, 4), c1.asTensor()); + assertEquals(NdArrays.vectorOf(1, 2, 3, 4), c1.asTensor(tensorScope)); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java index b1ebd469eb3..d686dcae81b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java @@ -22,6 +22,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -32,28 +33,28 @@ public final class GeneratedOperationsTest { @Test public void tensorInputTensorOutput() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops ops = Ops.create(g); Operand x = ops.math.add(ops.constant(1), ops.constant(2)); - try (TInt32 result = (TInt32)sess.runner().fetch(x).run().get(0)) { - assertEquals(3, result.getInt()); - } + TInt32 result = (TInt32) sess.runner().fetch(x).run(scope).get(0); + assertEquals(3, result.getInt()); } } @Test public void testListInputTensorOutput() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops ops = Ops.create(g); ArrayList> inputs = new ArrayList<>(); inputs.add(ops.constant(1)); inputs.add(ops.constant(2)); inputs.add(ops.constant(3)); Operand x = ops.math.addN(inputs); - try (TInt32 result = (TInt32)sess.runner().fetch(x).run().get(0)) { - assertEquals(6, result.getInt()); - } + TInt32 result = (TInt32) sess.runner().fetch(x).run(scope).get(0); + assertEquals(6, result.getInt()); } } @@ -61,13 +62,14 @@ public void testListInputTensorOutput() { * Test for Ops.withControlDependencies. * *

Creates an add node with a control dependency to an assign node. In other words, the assign - * node is a control input to the add node. When the add node is run, the assign node is expected - * to have run beforehand due to the control dependency. + * node is a control input to the add node. When the add node is run, the assign node is expected to have run + * beforehand due to the control dependency. */ @Test public void testControlDependencies() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops ops = Ops.create(g); Operand variable = ops.variable(Shape.scalar(), TInt32.class); Operand initVariable = ops.assign(variable, ops.constant(0)); @@ -75,10 +77,9 @@ public void testControlDependencies() { controls.add(ops.assign(variable, ops.constant(3))); Operand x = ops.withControlDependencies(controls).math.add(variable, ops.constant(0)); - sess.runner().addTarget(initVariable).run(); - try (TInt32 result = (TInt32)sess.runner().fetch(x).run().get(0)) { - assertEquals(3, result.getInt()); - } + sess.runner().addTarget(initVariable).run(scope); + TInt32 result = (TInt32) sess.runner().fetch(x).run(scope).get(0); + assertEquals(3, result.getInt()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 80150b64bb6..3b6f0ada75d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -20,12 +20,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; @@ -34,7 +35,8 @@ public class GradientsTest { @Test public void createGradients() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); @@ -47,21 +49,19 @@ public void createGradients() { assertNotNull(grads.dy()); assertEquals(2, grads.dy().size()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + List outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run(scope); - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); - } + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + assertEquals(18.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); } } @Test public void createGradientsWithSum() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); @@ -74,19 +74,18 @@ public void createGradientsWithSum() { assertNotNull(grads.dy()); assertEquals(1, grads.dy().size()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + List outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run(scope); - assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - } + assertEquals(114.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); } } @Test public void createGradientsWithInitialValues() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); @@ -100,13 +99,10 @@ public void createGradientsWithInitialValues() { assertNotNull(grads1.dy()); assertEquals(1, grads1.dy().size()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + List outputs = sess.runner().feed(x, c).fetch(grads1.dy(0)).run(scope); - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - } + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java index 39c04c942af..46604a53ddb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java @@ -22,6 +22,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.op.Scope; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -29,153 +30,165 @@ public class ShapesTest { - /** Test of flatten method, of class Shapes. */ + /** + * Test of flatten method, of class Shapes. + */ @Test public void testFlatten_Operand() { - try (Graph g = new Graph(); + try (TensorScope tensorScope = new TensorScope(); + Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Shape expResult = Shape.create(scope, operand, TInt64.class); Operand reshaped = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Operand actual = Shapes.flatten(scope, reshaped); Shape tfshape = Shape.create(scope, actual, TInt64.class); AtomicInteger index = new AtomicInteger(); - try (TInt64 result1 = (TInt64)session.runner().fetch(tfshape.asOutput()).run().get(0); - TInt64 result2 = (TInt64)session.runner().fetch(expResult.asOutput()).run().get(0)) { - result1 - .scalars() - .forEach( - s -> assertEquals(result2.getLong(index.getAndIncrement()), s.getLong())); - } + TInt64 result1 = (TInt64) session.runner().fetch(tfshape.asOutput()).run(tensorScope).get(0); + TInt64 result2 = (TInt64) session.runner().fetch(expResult.asOutput()).run(tensorScope).get(0); + result1 + .scalars() + .forEach( + s -> assertEquals(result2.getLong(index.getAndIncrement()), s.getLong())); } } - /** Test of flatten method, of class Shapes. */ + /** + * Test of flatten method, of class Shapes. + */ @Test public void testFlatten_Shape() { - try (EagerSession session = EagerSession.create()) { + try (TensorScope tensorScope = new TensorScope(); + EagerSession session = EagerSession.create()) { Scope scope = new Scope(session); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Shape expShape = Shape.create(scope, operand, TInt64.class); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand flattened = Shapes.flatten(scope, tfshape, TInt64.class); AtomicInteger index = new AtomicInteger(); flattened - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> assertEquals( - expShape.asTensor().getLong(index.getAndIncrement()), s.getLong())); + expShape.asTensor(tensorScope).getLong(index.getAndIncrement()), s.getLong())); } } - /** Test of size method, of class Shapes. */ + /** + * Test of size method, of class Shapes. + */ @Test public void testSize_Shape() { - try (Graph g = new Graph(); + try (TensorScope tensorScope = new TensorScope(); + Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand size = Shapes.size(scope, tfshape, TInt64.class); AtomicInteger index = new AtomicInteger(); - try (TInt64 result1 = (TInt64)session.runner().fetch(size.asOutput()).run().get(0)) { - result1.scalars().forEach(s -> assertEquals(8, s.getLong())); - } + TInt64 result1 = (TInt64) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result1.scalars().forEach(s -> assertEquals(8, s.getLong())); } } - /** Test of size method, of class Shapes. */ + /** + * Test of size method, of class Shapes. + */ @Test public void testSize_Shape_Operand() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 0)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(4, s.getInt())); - } + TInt32 result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(4, s.getInt())); size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 1)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(2, s.getInt())); - } + result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(2, s.getInt())); size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 2)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(1, s.getInt())); - } + result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(1, s.getInt())); } } - /** Test of size method, of class Shapes. */ + /** + * Test of size method, of class Shapes. + */ @Test public void testSize_Operand_Operand() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Operand size = Shapes.size(scope, actual, Constant.scalarOf(scope, 0)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(4, s.getInt())); - } + TInt32 result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(4, s.getInt())); size = Shapes.size(scope, actual, Constant.scalarOf(scope, 1)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(2, s.getInt())); - } + result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(2, s.getInt())); size = Shapes.size(scope, actual, Constant.scalarOf(scope, 2)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(1, s.getInt())); - } + result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(1, s.getInt())); } } - /** Test of numDimensions method, of class Shapes. */ + /** + * Test of numDimensions method, of class Shapes. + */ @Test public void testNumDimensions() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand nDims = Shapes.numDimensions(scope, tfshape); - try (TInt32 result = (TInt32)session.runner().fetch(nDims.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(3, s.getInt())); - } + TInt32 result = (TInt32) session.runner().fetch(nDims.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(3, s.getInt())); } } - /** Test of reduceDims method, of class Shapes. */ + /** + * Test of reduceDims method, of class Shapes. + */ @Test public void testReduceDims_Operand_Operand() { - try (EagerSession session = EagerSession.create()) { + try (EagerSession session = EagerSession.create(); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(session); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {2, 2, 2})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{2, 2, 2})); Shape tfshape = Shape.create(scope, actual); Operand reduced = Shapes.reduceDims(scope, actual, Constant.scalarOf(scope, 0)); @@ -183,7 +196,7 @@ public void testReduceDims_Operand_Operand() { AtomicInteger index = new AtomicInteger(); int[] expected = {8}; reducedShape - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> { @@ -193,14 +206,17 @@ public void testReduceDims_Operand_Operand() { } } - /** Test of reduceDims method, of class Shapes. */ + /** + * Test of reduceDims method, of class Shapes. + */ @Test public void testReduceDims_Shape_Operand() { - try (EagerSession session = EagerSession.create()) { + try (EagerSession session = EagerSession.create(); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(session); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {2, 2, 2})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{2, 2, 2})); Shape tfshape = Shape.create(scope, actual); Operand reduced = Shapes.reduceDims(scope, actual, Constant.scalarOf(scope, 0)); @@ -208,7 +224,7 @@ public void testReduceDims_Shape_Operand() { AtomicInteger index = new AtomicInteger(); int[] expected1 = {8}; reducedShape - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> { @@ -221,7 +237,7 @@ public void testReduceDims_Shape_Operand() { index.set(0); int[] expected2 = {2, 4}; reducedShape - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> { @@ -234,7 +250,7 @@ public void testReduceDims_Shape_Operand() { index.set(0); int[] expected3 = {2, 2, 2}; reducedShape - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> { @@ -244,28 +260,30 @@ public void testReduceDims_Shape_Operand() { } } - /** Test of squeeze method, of class Shapes. */ + /** + * Test of squeeze method, of class Shapes. + */ @Test public void testSqueeze() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand squeezed = Shapes.squeeze(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2}; - try (TInt32 result = (TInt32)session.runner().fetch(squeezed.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(squeezed.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -273,24 +291,24 @@ public void testSqueeze() { @Test public void testHead() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand head = Shapes.head(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {4}; - try (TInt32 result = (TInt32)session.runner().fetch(head.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(head.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -298,24 +316,24 @@ public void testHead() { @Test public void testTake() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand take = Shapes.take(scope, tfshape, Constant.scalarOf(scope, 2)); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 1}; - try (TInt32 result = (TInt32)session.runner().fetch(take.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(take.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -323,24 +341,24 @@ public void testTake() { @Test public void testTail() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand tail = Shapes.tail(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {1}; - try (TInt32 result = (TInt32)session.runner().fetch(tail.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(tail.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -348,24 +366,24 @@ public void testTail() { @Test public void testTakeLast() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand takeLast = Shapes.takeLast(scope, tfshape, Constant.scalarOf(scope, 3)); AtomicInteger index = new AtomicInteger(); int[] expected = {1, 2, 1}; - try (TInt32 result = (TInt32)session.runner().fetch(takeLast.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(takeLast.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -373,23 +391,23 @@ public void testTakeLast() { @Test public void testPrependInt() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2})); Shape tfshape = Shape.create(scope, actual); Operand prepend = Shapes.prepend(scope, tfshape, 3); AtomicInteger index = new AtomicInteger(); int[] expected = {3, 4, 2}; - try (TInt32 result = (TInt32)session.runner().fetch(prepend.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(prepend.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -397,23 +415,23 @@ public void testPrependInt() { @Test public void testPrependLong() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2})); Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand prepend = Shapes.prepend(scope, tfshape, 1L); AtomicInteger index = new AtomicInteger(); long[] expected = {1, 4, 2}; - try (TInt64 result = (TInt64)session.runner().fetch(prepend.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getLong()); - }); - } + TInt64 result = (TInt64) session.runner().fetch(prepend.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getLong()); + }); assertEquals(expected.length, index.get()); } } @@ -421,28 +439,28 @@ public void testPrependLong() { @Test public void testPrependShapeTInt32() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand1 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual1 = - Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[]{4, 2})); + Operand operand2 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = - Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); + Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[]{2, 4})); Shape tfshape1 = Shape.create(scope, actual1); Shape tfshape2 = Shape.create(scope, actual2); Operand prepend = Shapes.prepend(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); int[] expected = {2, 4, 4, 2}; - try (TInt32 result = (TInt32)session.runner().fetch(prepend.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(prepend.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -450,28 +468,28 @@ public void testPrependShapeTInt32() { @Test public void testPrependShapeTInt64() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand1 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual1 = - Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[]{4, 2})); + Operand operand2 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = - Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); + Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[]{2, 4})); Shape tfshape1 = Shape.create(scope, actual1, TInt64.class); Shape tfshape2 = Shape.create(scope, actual2, TInt64.class); Operand prepend = Shapes.prepend(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); long[] expected = {2, 4, 4, 2}; - try (TInt64 result = (TInt64)session.runner().fetch(prepend.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getLong()); - }); - } + TInt64 result = (TInt64) session.runner().fetch(prepend.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getLong()); + }); assertEquals(expected.length, index.get()); } } @@ -479,23 +497,23 @@ public void testPrependShapeTInt64() { @Test public void testAppendLong() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2})); Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand append = Shapes.append(scope, tfshape, 2L); AtomicInteger index = new AtomicInteger(); long[] expected = {4L, 2L, 2L}; - try (TInt64 result = (TInt64)session.runner().fetch(append.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getLong()); - }); - } + TInt64 result = (TInt64) session.runner().fetch(append.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getLong()); + }); assertEquals(expected.length, index.get()); } } @@ -503,23 +521,23 @@ public void testAppendLong() { @Test public void testAppendInt() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2})); Shape tfshape = Shape.create(scope, actual); Operand append = Shapes.append(scope, tfshape, 2); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2, 2}; - try (TInt32 result = (TInt32)session.runner().fetch(append.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(append.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -527,28 +545,28 @@ public void testAppendInt() { @Test public void testAppendShapeTInt32() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand1 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual1 = - Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[]{4, 2})); + Operand operand2 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = - Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); + Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[]{2, 4})); Shape tfshape1 = Shape.create(scope, actual1); Shape tfshape2 = Shape.create(scope, actual2); Operand append = Shapes.append(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2, 2, 4}; - try (TInt32 result = (TInt32)session.runner().fetch(append.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(append.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -556,28 +574,28 @@ public void testAppendShapeTInt32() { @Test public void testAppendShapeTInt64() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand1 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual1 = - Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[]{4, 2})); + Operand operand2 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = - Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); + Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[]{2, 4})); Shape tfshape1 = Shape.create(scope, actual1, TInt64.class); Shape tfshape2 = Shape.create(scope, actual2, TInt64.class); Operand append = Shapes.append(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); long[] expected = {4, 2, 2, 4}; - try (TInt64 result = (TInt64)session.runner().fetch(append.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getLong()); - }); - } + TInt64 result = (TInt64) session.runner().fetch(append.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getLong()); + }); assertEquals(expected.length, index.get()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index 4121baf3af1..4e379efecbc 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.op.Scope; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; @@ -37,102 +38,104 @@ public class ZerosTest { @Test public void createIntZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt32.class); - try (TInt32 result = (TInt32)sess.runner().fetch(op).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0, s.getInt())); - } + TInt32 result = (TInt32) sess.runner().fetch(op).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0, s.getInt())); } } @Test public void createFloatZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class); - try (TFloat32 result = (TFloat32)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0.0f, s.getFloat(), 0)); - } + TFloat32 result = (TFloat32) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0.0f, s.getFloat(), 0)); + } } @Test public void createDoubleZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat64.class); - try (TFloat64 result = (TFloat64)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0.0f, s.getDouble(), 0)); - } + TFloat64 result = (TFloat64) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0.0f, s.getDouble(), 0)); } } @Test public void createLongZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt64.class); - try (TInt64 result = (TInt64)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0L, s.getLong())); - } + TInt64 result = (TInt64) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0L, s.getLong())); } } @Test public void createBooleanZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TBool.class); - try (TBool result = (TBool)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertFalse(s.getBoolean())); - } - } + TBool result = (TBool) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertFalse(s.getBoolean())); + } } @Test public void createUint8Zeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TUint8.class); - try (TUint8 result = (TUint8)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0, s.getByte())); - } + TUint8 result = (TUint8) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0, s.getByte())); } } @Test public void createStringZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TString.class); - try (TString result = (TString)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertTrue(s.getObject().isEmpty())); - } + TString result = (TString) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertTrue(s.getObject().isEmpty())); } } @Test public void operationsComposingZerosAreCorrectlyNamed() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); - List results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); + List results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(tensorScope); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java index faddc7c5826..03254e65196 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.EagerSession; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.index.Indices; @@ -34,78 +35,82 @@ abstract class NumericTypesTestBase { @Test public void initializeTensorsWithZeros() { - // Allocate a tensor of 32-bits integer of the shape (2, 3, 2) - T tensor = allocateTensor(Shape.of(2, 3, 2)); + try (TensorScope scope = new TensorScope()) { + // Allocate a tensor of 32-bits integer of the shape (2, 3, 2) + T tensor = allocateTensor(scope, Shape.of(2, 3, 2)); - assertEquals(3, tensor.rank()); - assertEquals(12, tensor.size()); - NdArray data = (NdArray)tensor; + assertEquals(3, tensor.rank()); + assertEquals(12, tensor.size()); + NdArray data = (NdArray) tensor; - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); + try (EagerSession session = EagerSession.create()) { + Ops tf = Ops.create(session); - // Initialize tensor memory with zeros and take a snapshot - data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(0))); - Constant x = tf.constantOf(tensor); + // Initialize tensor memory with zeros and take a snapshot + data.scalars().forEach(scalar -> ((NdArray) scalar).setObject(valueOf(0))); + Constant x = tf.constantOf(tensor); - // Initialize the same tensor memory with ones and take a snapshot - data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(1))); - Constant y = tf.constantOf(tensor); + // Initialize the same tensor memory with ones and take a snapshot + data.scalars().forEach(scalar -> ((NdArray) scalar).setObject(valueOf(1))); + Constant y = tf.constantOf(tensor); - // Subtract y from x and validate the result - Sub sub = tf.math.sub(x, y); - ((NdArray)sub.asTensor()).scalars().forEach(scalar -> - assertEquals(valueOf(-1), scalar.getObject()) - ); + // Subtract y from x and validate the result + Sub sub = tf.math.sub(x, y); + ((NdArray) sub.asTensor(scope)).scalars().forEach(scalar -> + assertEquals(valueOf(-1), scalar.getObject()) + ); + } } } @Test public void setAndCompute() { - NdArray heapData = allocateNdArray(Shape.of(4)) - .setObject(valueOf(0), 0) - .setObject(valueOf(1), 1) - .setObject(valueOf(2), 2) - .setObject(valueOf(3), 3); - - // Creates a 2x2 matrix - try (T tensor = allocateTensor(Shape.of(2, 2))) { - NdArray data = (NdArray)tensor; - - // Copy first 2 values of the vector to the first row of the matrix - data.set(heapData.slice(Indices.range(0, 2)), 0); - - // Copy values at an odd position in the vector as the second row of the matrix - data.set(heapData.slice(Indices.odd()), 1); - - assertEquals(valueOf(0), data.getObject(0, 0)); - assertEquals(valueOf(1), data.getObject(0, 1)); - assertEquals(valueOf(1), data.getObject(1, 0)); - assertEquals(valueOf(3), data.getObject(1, 1)); - - // Read rows of the tensor in reverse order - NdArray flippedData = data.slice(Indices.flip(), Indices.flip()); - - assertEquals(valueOf(3), flippedData.getObject(0, 0)); - assertEquals(valueOf(1), flippedData.getObject(0, 1)); - assertEquals(valueOf(1), flippedData.getObject(1, 0)); - assertEquals(valueOf(0), flippedData.getObject(1, 1)); - - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); - - Add add = tf.math.add(tf.constantOf(tensor), tf.constantOf(tensor)); - NdArray result = (NdArray)add.asTensor(); - - assertEquals(valueOf(0), result.getObject(0, 0)); - assertEquals(valueOf(2), result.getObject(0, 1)); - assertEquals(valueOf(2), result.getObject(1, 0)); - assertEquals(valueOf(6), result.getObject(1, 1)); + try (TensorScope scope = new TensorScope()) { + NdArray heapData = allocateNdArray(Shape.of(4)) + .setObject(valueOf(0), 0) + .setObject(valueOf(1), 1) + .setObject(valueOf(2), 2) + .setObject(valueOf(3), 3); + + // Creates a 2x2 matrix + try (T tensor = allocateTensor(scope, Shape.of(2, 2))) { + NdArray data = (NdArray) tensor; + + // Copy first 2 values of the vector to the first row of the matrix + data.set(heapData.slice(Indices.range(0, 2)), 0); + + // Copy values at an odd position in the vector as the second row of the matrix + data.set(heapData.slice(Indices.odd()), 1); + + assertEquals(valueOf(0), data.getObject(0, 0)); + assertEquals(valueOf(1), data.getObject(0, 1)); + assertEquals(valueOf(1), data.getObject(1, 0)); + assertEquals(valueOf(3), data.getObject(1, 1)); + + // Read rows of the tensor in reverse order + NdArray flippedData = data.slice(Indices.flip(), Indices.flip()); + + assertEquals(valueOf(3), flippedData.getObject(0, 0)); + assertEquals(valueOf(1), flippedData.getObject(0, 1)); + assertEquals(valueOf(1), flippedData.getObject(1, 0)); + assertEquals(valueOf(0), flippedData.getObject(1, 1)); + + try (EagerSession session = EagerSession.create()) { + Ops tf = Ops.create(session); + + Add add = tf.math.add(tf.constantOf(tensor), tf.constantOf(tensor)); + NdArray result = (NdArray) add.asTensor(scope); + + assertEquals(valueOf(0), result.getObject(0, 0)); + assertEquals(valueOf(2), result.getObject(0, 1)); + assertEquals(valueOf(2), result.getObject(1, 0)); + assertEquals(valueOf(6), result.getObject(1, 1)); + } } } } - abstract T allocateTensor(Shape shape); + abstract T allocateTensor(TensorScope scope, Shape shape); abstract NdArray allocateNdArray(Shape shape); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java index 17a6e0dd2b5..9d2890ebdce 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TBfloat16Test extends NumericTypesTestBase { @Override - TBfloat16 allocateTensor(Shape shape) { - return TBfloat16.tensorOf(shape); + TBfloat16 allocateTensor(TensorScope scope, Shape shape) { + return TBfloat16.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java index c1ae8ad3b6d..4339be8a2c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TFloat16Test extends NumericTypesTestBase { @Override - TFloat16 allocateTensor(Shape shape) { - return TFloat16.tensorOf(shape); + TFloat16 allocateTensor(TensorScope scope, Shape shape) { + return TFloat16.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java index 8df96f2871a..45bf8278487 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TFloat32Test extends NumericTypesTestBase { @Override - TFloat32 allocateTensor(Shape shape) { - return TFloat32.tensorOf(shape); + TFloat32 allocateTensor(TensorScope scope, Shape shape) { + return TFloat32.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java index 47b4b6d936a..fd12b6f2666 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TFloat64Test extends NumericTypesTestBase { @Override - TFloat64 allocateTensor(Shape shape) { - return TFloat64.tensorOf(shape); + TFloat64 allocateTensor(TensorScope scope, Shape shape) { + return TFloat64.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java index a2ab28b6219..364ac21e8bf 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TInt32Test extends NumericTypesTestBase { @Override - TInt32 allocateTensor(Shape shape) { - return TInt32.tensorOf(shape); + TInt32 allocateTensor(TensorScope scope, Shape shape) { + return TInt32.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java index a88f3fb4d6d..2e04cb21688 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TInt64Test extends NumericTypesTestBase { @Override - TInt64 allocateTensor(Shape shape) { - return TInt64.tensorOf(shape); + TInt64 allocateTensor(TensorScope scope, Shape shape) { + return TInt64.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java index 015f93b70e7..6b4893baf2a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java @@ -23,6 +23,7 @@ import java.nio.charset.StandardCharsets; import org.junit.jupiter.api.Test; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -31,77 +32,94 @@ public class TStringTest { @Test public void createScalar() { - TString tensor = TString.scalarOf("Pretty vacant"); - assertNotNull(tensor); - assertEquals(Shape.scalar(), tensor.shape()); - assertEquals("Pretty vacant", tensor.getObject()); + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.scalarOf(scope, "Pretty vacant"); + assertNotNull(tensor); + assertEquals(Shape.scalar(), tensor.shape()); + assertEquals("Pretty vacant", tensor.getObject()); + } } - @Test - public void createrScalarLongerThan127() { - TString tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); - assertNotNull(tensor); - assertEquals(Shape.scalar(), tensor.shape()); - assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", tensor.getObject()); + @Test + public void createrScalarLongerThan127() { + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.scalarOf(scope, + "Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); + assertNotNull(tensor); + assertEquals(Shape.scalar(), tensor.shape()); + assertEquals( + "Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", + tensor.getObject()); } + } - @Test + @Test public void createVector() { - TString tensor = TString.vectorOf("Pretty", "vacant"); - assertNotNull(tensor); - assertEquals(Shape.of(2), tensor.shape()); - assertEquals("Pretty", tensor.getObject(0)); - assertEquals("vacant", tensor.getObject(1)); + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.vectorOf(scope, "Pretty", "vacant"); + assertNotNull(tensor); + assertEquals(Shape.of(2), tensor.shape()); + assertEquals("Pretty", tensor.getObject(0)); + assertEquals("vacant", tensor.getObject(1)); + } } @Test public void createCopy() { - NdArray strings = NdArrays.ofObjects(String.class, Shape.of(2, 2)) - .setObject("Pretty", 0, 0) - .setObject("vacant", 0, 1) - .setObject("New", 1, 0) - .setObject("York", 1, 1); + try (TensorScope scope = new TensorScope()) { + NdArray strings = NdArrays.ofObjects(String.class, Shape.of(2, 2)) + .setObject("Pretty", 0, 0) + .setObject("vacant", 0, 1) + .setObject("New", 1, 0) + .setObject("York", 1, 1); - TString tensor = TString.tensorOf(strings); - assertNotNull(tensor); - strings.scalars().forEachIndexed((idx, s) -> - assertEquals(s.getObject(), tensor.getObject(idx)) - ); + TString tensor = TString.tensorOf(scope, strings); + assertNotNull(tensor); + strings.scalars().forEachIndexed((idx, s) -> + assertEquals(s.getObject(), tensor.getObject(idx)) + ); + } } @Test public void defaultCharsetIsUtf8() { - TString tensor = TString.tensorOf(NdArrays.scalarOfObject(BABY_CHICK)); - byte[] bytes = tensor.asBytes().getObject(); - assertArrayEquals(new byte[] { (byte)0xF0, (byte)0x9F, (byte)0x90, (byte)0xA5 }, bytes); - assertEquals(BABY_CHICK, tensor.getObject()); + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.tensorOf(scope, NdArrays.scalarOfObject(BABY_CHICK)); + byte[] bytes = tensor.asBytes().getObject(); + assertArrayEquals(new byte[]{(byte) 0xF0, (byte) 0x9F, (byte) 0x90, (byte) 0xA5}, bytes); + assertEquals(BABY_CHICK, tensor.getObject()); + } } @Test public void usingDifferentCharset() { - TString tensor = TString.tensorOf(StandardCharsets.UTF_16LE, NdArrays.scalarOfObject(BABY_CHICK)); - byte[] bytes = tensor.asBytes().getObject(); - assertArrayEquals(new byte[] { (byte)0x3D, (byte)0xD8, (byte)0x25, (byte)0xDC }, bytes); - assertEquals(BABY_CHICK, tensor.using(StandardCharsets.UTF_16LE).getObject()); + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.tensorOf(scope, StandardCharsets.UTF_16LE, NdArrays.scalarOfObject(BABY_CHICK)); + byte[] bytes = tensor.asBytes().getObject(); + assertArrayEquals(new byte[]{(byte) 0x3D, (byte) 0xD8, (byte) 0x25, (byte) 0xDC}, bytes); + assertEquals(BABY_CHICK, tensor.using(StandardCharsets.UTF_16LE).getObject()); + } } @Test public void initializingTensorWithRawBytes() { - String[] strings = new String[] { "TensorFlow", "For", "Java", "Rocks", "!" }; - NdArray bytes = NdArrays.ofObjects(byte[].class, Shape.of(strings.length)); - for (int i = 0; i < strings.length; ++i) { - bytes.setObject(strings[i].getBytes(), i); - } - TString tensor = TString.tensorOfBytes(bytes); - assertNotNull(tensor); - assertEquals(bytes.shape(), tensor.shape()); + try (TensorScope scope = new TensorScope()) { + String[] strings = new String[]{"TensorFlow", "For", "Java", "Rocks", "!"}; + NdArray bytes = NdArrays.ofObjects(byte[].class, Shape.of(strings.length)); + for (int i = 0; i < strings.length; ++i) { + bytes.setObject(strings[i].getBytes(), i); + } + TString tensor = TString.tensorOfBytes(scope, bytes); + assertNotNull(tensor); + assertEquals(bytes.shape(), tensor.shape()); - NdArray tensorBytes = tensor.asBytes(); - for (int i = 0; i < strings.length; ++i) { - assertArrayEquals(bytes.getObject(i), tensorBytes.getObject(i)); + NdArray tensorBytes = tensor.asBytes(); + for (int i = 0; i < strings.length; ++i) { + assertArrayEquals(bytes.getObject(i), tensorBytes.getObject(i)); + } } } - private static final String BABY_CHICK = "\uD83D\uDC25"; + private static final String BABY_CHICK = "\uD83D\uDC25"; } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java index ce7397d5878..b7fc379b6fc 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TUint8Test extends NumericTypesTestBase { @Override - TUint8 allocateTensor(Shape shape) { - return TUint8.tensorOf(shape); + TUint8 allocateTensor(TensorScope scope, Shape shape) { + return TUint8.tensorOf(scope, shape); } @Override From dacd0c365a23f268e52d4075b3639ca80ed61bf4 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 28 Jan 2021 20:55:32 -0800 Subject: [PATCH 27/35] Add no-output run method Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Session.java | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index 7cf3f39e144..ecaa7ec57ab 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -305,10 +305,9 @@ public Runner setOptions(RunOptions options) { } /** - * Execute the graph fragments necessary to compute all requested fetches. + * Execute the graph fragments necessary to compute all requested fetches and complete all targets. * - *

WARNING: The caller assumes ownership of all returned {@link Tensor Tensors}, i.e., - * the caller must call {@link Tensor#close} on all elements of the returned list to free up resources. + *

The returned tensors will be part of the passed scope. * *

TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in SessionTest.java), and @@ -324,6 +323,17 @@ public List run(TensorScope scope) { return runHelper(scope, false).outputs; } + /** + * Execute the graph fragments necessary to compute all requested fetches and complete all targets. + * + * @see #run(TensorScope) + */ + public void runWithoutOutputs() { + try (TensorScope scope = new TensorScope()) { + run(scope); + } + } + /** * Execute graph fragments to compute requested fetches and return metadata about the run. * From 9ceeefd6a723c69a47314a7370790b74e3ef2c56 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 28 Jan 2021 21:04:01 -0800 Subject: [PATCH 28/35] Doc updates Signed-off-by: Ryan Nett --- .../java/org/tensorflow/ConcreteFunction.java | 8 ++---- .../src/main/java/org/tensorflow/Tensor.java | 7 ++--- .../java/org/tensorflow/TensorContainer.java | 5 ++-- .../main/java/org/tensorflow/TensorScope.java | 28 +++++++++++-------- 4 files changed, 23 insertions(+), 25 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index ea529d4c374..d3ba68d62a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -156,8 +156,6 @@ public Signature signature() { /** * Invokes a function. * - *

Caller is responsible for closing all Tensors. - * * @param scope the {@link TensorScope} to create the outputs in * @param arguments list of tensors to pass in input to the function, mapped by their signature name * @return output tensors resulting from the execution of the function, mapped by their signature name @@ -202,8 +200,6 @@ public Map call(TensorScope scope, Map arguments /** * Invokes a function with a single input and output. * - *

Caller is responsible for closing all Tensors. - * * @param scope the {@link TensorScope} to create the output in * @param tensor input tensor * @return output tensor @@ -244,8 +240,8 @@ public void save(String exportDir) throws IOException { * Returns the session used to execute the graph when calling this function * *

In general, a user does not need to handle directly the session of a function and rely - * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to the session might be - * necessary, as it allows more running options. + * on {@link #call(TensorScope, Map)} to execute the graph instead. But in some cases, direct access to the session + * might be necessary, as it allows more running options. * * @return the function session */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index 4c649079d2d..d5c93bf3c42 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -39,7 +39,7 @@ * doSomethingWith(t); * } * }

- *

This can be done automatically using {@link TensorScope}. + *

This can (and probably should) be done using {@link TensorScope}. *

Instances of a Tensor are not thread-safe. */ public interface Tensor extends Shaped, AutoCloseable { @@ -207,9 +207,8 @@ static T of(TensorScope scope, Class type, Shape shape, Byt /** * Release resources associated with the Tensor. - * - *

WARNING:This must be invoked for all tensors that were not been produced by an eager - * operation or memory will be leaked. May be done automatically via {@link TensorScope}. + *

All tensors should be closed using this method or {@link TensorScope}. + * Memory will not leak if they aren't, but relying on the garbage collector for cleanup is not efficient. * *

The Tensor object is no longer usable after {@code close} returns. */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java index ef1a368cd84..6052daab086 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java @@ -38,9 +38,8 @@ default void detach() { /** * Release resources associated with these tensors. - * - *

WARNING:This must be invoked for all tensors that were not been produced by an eager - * operation or memory will be leaked. May be done automatically via {@link TensorScope}. + *

All tensors should be closed using this method or {@link TensorScope}. + * Memory will not leak if they aren't, but relying on the garbage collector for cleanup is not efficient. * *

The Tensor objects are no longer usable after {@code close} returns. * diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 294e025e96a..99197c948e4 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -16,21 +16,22 @@ */ package org.tensorflow; +import java.util.Collections; import java.util.HashSet; import java.util.Set; +import java.util.WeakHashMap; /** - * A scope that can be used to manage tensor resources. Any tensors created between a scope's creation and calling - * {@code close()} that haven't been detached or attached to a different scope are guaranteed to be closed with the - * scope (even if they are created in a sub-scope). Tensors may be manually closed earlier without issue. + * A scope used to manage tensor resources. All tensor-creating methods take a scope as a parameter, and create their + * tensors in that scope. When a scope is closed, it closes all of it's attached tensors. Tensors may be manually + * closed earlier without issue, and being attached to a scope will not keep a tensor from being GC'd. + *

While tensors will be closed when GC'd, relying on the garbage collector for cleanup is not efficient. This + * class + * or manual management should be used. *

- * Tensors are automatically tracked on creation. A tensor can me manually added to a scope with {@link - * TensorScope#attach(Tensor)} or {@link Tensor#attachToCurrent()}. A tensor may only have one scope: if it currently - * has a scope when {@code attach} is called, it is removed from its original scope. - *

- * {@link Tensor#detach()} detaches the tensor from it's scope, requiring the user to close it manually or attach it to - * another scope. + * {@link TensorScope#detach(Tensor)} and {@link Tensor#detach()} detaches the tensor from it's scope, requiring the + * user to close it manually or attach it to another scope. *

* Like Tensors, TensorScope is not thread safe. */ @@ -38,7 +39,7 @@ public final class TensorScope implements AutoCloseable { /** - * Create a new tensor scope. If {@code autoAttach} is false, will not automatically manage tensors. + * Create a new tensor scope. * * @see TensorScope */ @@ -46,7 +47,10 @@ public TensorScope() { } /** - * Closes this scope and its tensors, and any inner scopes. + * Closes this scope and its tensors. + *

All tensors should be closed using this method or {@link Tensor#close()}. + * Memory will not leak if they aren't, but relying on the garbage collector for cleanup is + * not efficient. */ @Override public synchronized void close() { @@ -234,5 +238,5 @@ public synchronized boolean isClosed() { } private boolean closed = false; - private final Set tensors = new HashSet<>(); + private final Set tensors = Collections.newSetFromMap(new WeakHashMap<>()); } From 344fe2b07f74435b07e0dc0a1510f0b3e450274b Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 28 Jan 2021 21:30:36 -0800 Subject: [PATCH 29/35] Fix framework Signed-off-by: Ryan Nett --- .../framework/data/DatasetIterator.java | 14 +- .../framework/losses/impl/LossesHelper.java | 132 +- .../framework/utils/ShapeUtils.java | 30 +- .../framework/data/BatchDatasetTest.java | 34 +- .../framework/data/DatasetIteratorTest.java | 47 +- .../framework/data/MapDatasetTest.java | 47 +- .../framework/data/SkipDatasetTest.java | 22 +- .../framework/data/TakeDatasetTest.java | 21 +- .../framework/optimizers/AdamTest.java | 58 +- .../framework/optimizers/AdamaxTest.java | 65 +- .../framework/optimizers/NadamTest.java | 76 +- .../framework/utils/EagerTestSession.java | 1200 ++++++------- .../framework/utils/GraphTestSession.java | 1479 +++++++++-------- 13 files changed, 1691 insertions(+), 1534 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java index a3aa290a8c8..bc3c4eeaabf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java @@ -15,15 +15,15 @@ */ package org.tensorflow.framework.data; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.ndarray.Shape; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; import org.tensorflow.types.family.TType; /** @@ -272,7 +272,9 @@ public Iterator>> iterator() { @Override public boolean hasNext() { - return nextOptional.hasValue().asTensor().getBoolean(); + try (TensorScope scope = new TensorScope()) { + return nextOptional.hasValue().asTensor(scope).getBoolean(); + } } @Override diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index f6b0de71b0d..6a8521235fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -14,7 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.losses.impl; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.Arrays; +import java.util.Collections; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.framework.losses.Reduction; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -26,11 +31,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import java.util.Arrays; -import java.util.Collections; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * These are helper methods for Losses and Metrics and will be module private when Java modularity * is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics @@ -52,12 +52,11 @@ public class LossesHelper { * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction - * . + * . * @param the data type for the labels, predictions and result - * @return LossTuple of prediction, label,sampleWeight will - * be null. Each of them possibly has the last dimension squeezed, sampleWeight - * could be extended by one dimension. If sampleWeight is null, (prediction, - * label) is returned. + * @return LossTuple of prediction, label,sampleWeight will be null. Each of h + * of them possibly has the last dimension squeezed, sampleWeight could be extended by one dimension. If + * sampleWeight is null, (prediction, label) is returned. */ public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions) { @@ -141,8 +140,7 @@ public static LossTuple squeezeOrExpandDimensions( * Squeeze or expand the sampleWeight based on the rank difference * *

If the rank difference is +1, squeeze the last dimension of sampleWeight, If the rank - * difference is -1, expand the last dimension of sampleWeight. Otherwise, leave the shape of - * sampleWeight as is. + * difference is -1, expand the last dimension of sampleWeight. Otherwise, leave the shape of sampleWeight as is. * * @param tf the TensorFlow Ops * @param sampleWeight the sample weights @@ -180,7 +178,7 @@ private static Operand maybeExpandWeights( * * @param tf the TensorFlowOps * @param labels Label values, a Tensor whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. @@ -195,7 +193,7 @@ public static LossTuple removeSqueezableDimensions( * * @param tf the TensorFlowOps * @param labels Label values, a Operand whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). * @param the data type for the labels, predictions and result @@ -299,8 +297,8 @@ private static Operand reduceWeightedLoss( * @param losses Operand whose elements contain individual loss measurements. * @param numElements The number of measurable elements in losses. * @param the data type of the losses - * @return A scalar representing the mean of losses. If numElements is - * zero, then zero is returned. + * @return A scalar representing the mean of losses. If numElements is zero, then zero is + * returned. */ public static Operand safeMean( Ops tf, Operand losses, long numElements) { @@ -338,8 +336,7 @@ public static Operand allAxes(Ops tf, Operand op) * @param minValue the minimum value * @param maxValue the maximum value * @param the datatype for the values - * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph - * Session + * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph Session * @throws IllegalArgumentException if the TensorFlow Ops represents an Eager Session */ public static Operand rangeCheck( @@ -351,41 +348,44 @@ public static Operand rangeCheck( tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims)); // Graph and Eager mode need to be handled differently, control dependencies are not allowed in // Eager mode - if (tf.scope().env().isGraph()) { - AssertThat assertThat = - tf.assertThat( - cond, - Arrays.asList( - tf.constant(prefix), - tf.constant(": values out of range, "), - tf.constant("minimum = "), - minValue, - tf.constant(", maximum = "), - maxValue)); - Ops ltf = - tf.withSubScope("rangeCheck") - .withControlDependencies(Collections.singletonList(assertThat)); - return ltf.identity(values); - } else if (!cond.asTensor().getBoolean()) - throw new IllegalArgumentException(String.format("%s : values out of range", prefix)); - else return values; + try (TensorScope scope = new TensorScope()) { + if (tf.scope().env().isGraph()) { + AssertThat assertThat = + tf.assertThat( + cond, + Arrays.asList( + tf.constant(prefix), + tf.constant(": values out of range, "), + tf.constant("minimum = "), + minValue, + tf.constant(", maximum = "), + maxValue)); + Ops ltf = + tf.withSubScope("rangeCheck") + .withControlDependencies(Collections.singletonList(assertThat)); + return ltf.identity(values); + } else if (!cond.asTensor(scope).getBoolean()) { + throw new IllegalArgumentException(String.format("%s : values out of range", prefix)); + } else { + return values; + } + } } /** - * Checks to see if all the values are in the allowed values set. Running the operand in Graph - * mode will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException}, if at least one - * value is not in the allowed values set. In Eager mode, this method will throw an {@link - * IllegalArgumentException} if at least one value is not in the allowed values set. + * Checks to see if all the values are in the allowed values set. Running the operand in Graph mode will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException}, if at least one value is not in the allowed values set. In + * Eager mode, this method will throw an {@link IllegalArgumentException} if at least one value is not in the allowed + * values set. * * @param tf The TensorFlow Ops * @param prefix A String prefix to include in the error message * @param values the values to check * @param allowedValues the allowed values * @param the data type for values and allowed values - * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph - * Session - * @throws IllegalArgumentException if the Session is in Eager mode and at least one value is not - * in the allowed values set + * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph Session + * @throws IllegalArgumentException if the Session is in Eager mode and at least one value is not in the allowed + * values set */ public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { @@ -396,26 +396,32 @@ public static Operand valueCheck( if (diffSize != Shape.UNKNOWN_SIZE) { if (diffSize != 0) { // at least 1 value in the diff did not match the allowed values. throw new IllegalArgumentException(String.format("%s : values not in value set,", prefix)); - } else return values; + } else { + return values; + } } else { // use dynamic shape - Operand cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0)); - // Graph and Eager mode need to be handled differently, control dependencies are not allowed - // in Eager mode - if (tf.scope().env().isGraph()) { - AssertThat assertThat = - tf.assertThat( - cond, - Arrays.asList( - tf.constant(prefix), - tf.constant(": values not in value set, values = "), - values)); - Ops ltf = - tf.withSubScope("valueCheck") - .withControlDependencies(Collections.singletonList(assertThat)); - return ltf.identity(values); - } else if (!cond.asTensor().getBoolean()) - throw new IllegalArgumentException(String.format("%s : values not in value set", prefix)); - else return values; + try (TensorScope scope = new TensorScope()) { + Operand cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0)); + // Graph and Eager mode need to be handled differently, control dependencies are not allowed + // in Eager mode + if (tf.scope().env().isGraph()) { + AssertThat assertThat = + tf.assertThat( + cond, + Arrays.asList( + tf.constant(prefix), + tf.constant(": values not in value set, values = "), + values)); + Ops ltf = + tf.withSubScope("valueCheck") + .withControlDependencies(Collections.singletonList(assertThat)); + return ltf.identity(values); + } else if (!cond.asTensor(scope).getBoolean()) { + throw new IllegalArgumentException(String.format("%s : values not in value set", prefix)); + } else { + return values; + } + } } } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index e730c79cfbf..260ab963e01 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -24,11 +24,9 @@ import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TIntegral; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -/** Various methods for processing with Shapes and Operands */ +/** + * Various methods for processing with Shapes and Operands + */ public class ShapeUtils { /** @@ -66,12 +64,14 @@ public static int[] getIntArray(Scope scope, Operand dims) { * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ public static long[] getLongArray(Scope scope, Operand dims) { - if (scope.env().isEager()) { - return getLongArray(dims.asTensor()); - } - try (Session session = new Session((Graph) scope.env()); - TIntegral tensor = (TIntegral) session.runner().fetch(dims).run().get(0)) { - return getLongArray(tensor); + try (TensorScope tensorScope = new TensorScope()) { + if (scope.env().isEager()) { + return getLongArray(dims.asTensor(tensorScope)); + } + try (Session session = new Session((Graph) scope.env())) { + TIntegral tensor = (TIntegral) session.runner().fetch(dims).run(tensorScope).get(0); + return getLongArray(tensor); + } } } @@ -112,12 +112,16 @@ public static Shape reduce(Shape shape, int axis) { axis = shape.numDimensions() + axis; } long[] array = shape.asArray(); - if (array == null) return Shape.unknown(); + if (array == null) { + return Shape.unknown(); + } long[] newArray = new long[axis]; System.arraycopy(array, 0, newArray, 0, axis - 1); long prod = array[axis - 1]; for (int i = axis; i < array.length; i++) { - if (array[i] != Shape.UNKNOWN_SIZE) prod *= array[i]; + if (array[i] != Shape.UNKNOWN_SIZE) { + prod *= array[i]; + } } newArray[axis - 1] = prod; return Shape.of(newArray); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java index 2d282e5dcf7..0545f25b407 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java @@ -15,18 +15,18 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.tensorflow.ndarray.index.Indices.range; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.tensorflow.ndarray.index.Indices.range; - public class BatchDatasetTest extends DatasetTestBase { @Test @@ -44,9 +44,9 @@ public void testEagerBatchDataset() { int count = 0; for (List> components : dataset) { - try (TInt32 batch1 = - (TInt32)components.get(0).asTensor(); - TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + try (TensorScope scope = new TensorScope()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); assertEquals(testMatrix1.slice(range(count, count + 2)), batch1); assertEquals(testMatrix2.slice(range(count, count + 2)), batch2); @@ -68,15 +68,16 @@ public void testDropLastBatch() { int count = 0; for (List> components : dataset) { + try (TensorScope scope = new TensorScope()) { - try (TInt32 batch1 = - (TInt32)components.get(0).asTensor(); - TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); assertEquals(testMatrix1.slice(range(count, count + 3)), batch1); assertEquals(testMatrix2.slice(range(count, count + 3)), batch2); count += 3; } + } } @@ -95,10 +96,9 @@ public void testKeepLastBatch() { boolean foundLastBatch = false; for (List> components : dataset) { - try (TInt32 batch1 = - (TInt32)components.get(0).asTensor(); - TInt32 batch2 = - (TInt32)components.get(1).asTensor();) { + try (TensorScope scope = new TensorScope()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); if (count == 0) { assertEquals(testMatrix1.slice(range(count, count + 3)), batch1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 882a64ba54d..70da093bb7c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -15,25 +15,26 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.types.family.TType; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; - -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; +import org.tensorflow.types.family.TType; public class DatasetIteratorTest extends DatasetTestBase { @Test public void testGraphIteration() { - try (Graph graph = new Graph()) { + try (Graph graph = new Graph(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(graph); List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); @@ -53,10 +54,10 @@ public void testGraphIteration() { int batches = 0; while (true) { try { - List outputs = session.runner().fetch(x).fetch(y).run(); + List outputs = session.runner().fetch(x).fetch(y).run(scope); - try (TInt32 xBatch = (TInt32)outputs.get(0); - TInt32 yBatch = (TInt32)outputs.get(1)) { + try (TInt32 xBatch = (TInt32) outputs.get(0); + TInt32 yBatch = (TInt32) outputs.get(1)) { assertEquals(testMatrix1.get(batches), xBatch); assertEquals(testMatrix2.get(batches), yBatch); batches++; @@ -71,22 +72,24 @@ public void testGraphIteration() { @Test public void testEagerIteration() { + try (TensorScope scope = new TensorScope()) { - Ops tf = Ops.create(); - - List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); + Ops tf = Ops.create(); - List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); + List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); - int count = 0; - for (List> outputs : dataset) { - try (TInt32 batch1 = (TInt32)outputs.get(0).asTensor(); - TInt32 batch2 = (TInt32)outputs.get(1).asTensor()) { - assertEquals(testMatrix1.get(count), batch1); - assertEquals(testMatrix2.get(count), batch2); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); - count++; + Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); + int count = 0; + for (List> outputs : dataset) { + try (TInt32 batch1 = (TInt32) outputs.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) outputs.get(1).asTensor(scope)) { + assertEquals(testMatrix1.get(count), batch1); + assertEquals(testMatrix2.get(count), batch2); + + count++; + } } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index 5f203427563..16f7edff143 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -15,24 +15,25 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.types.family.TType; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TFOutOfRangeException; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; - -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; +import org.tensorflow.types.family.TType; public class MapDatasetTest extends DatasetTestBase { + IntNdArray mapped1; IntNdArray mapped2; @@ -41,14 +42,14 @@ public void setUp() { super.setUp(); mapped1 = StdArrays.ndCopyOf( - new int[][] { - {2, 4, 6, 8, 10}, - {4, 8, 12, 16, 20}, - {6, 12, 18, 24, 30}, - {8, 16, 24, 32, 40} + new int[][]{ + {2, 4, 6, 8, 10}, + {4, 8, 12, 16, 20}, + {6, 12, 18, 24, 30}, + {8, 16, 24, 32, 40} }); - mapped2 = StdArrays.ndCopyOf(new int[][] {{2}, {0}, {2}, {2}}); + mapped2 = StdArrays.ndCopyOf(new int[][]{{2}, {0}, {2}, {2}}); } @Test @@ -77,17 +78,16 @@ public void testGraphIteration() { int batches = 0; while (true) { - try { - List outputs = session.runner().fetch(X).fetch(y).run(); + try (TensorScope scope = new TensorScope()) { + List outputs = session.runner().fetch(X).fetch(y).run(scope); - try (TInt32 XBatch = (TInt32)outputs.get(0); - TInt32 yBatch = (TInt32)outputs.get(1)) { + TInt32 XBatch = (TInt32) outputs.get(0); + TInt32 yBatch = (TInt32) outputs.get(1); - assertEquals(mapped1.get(batches), XBatch); - assertEquals(mapped2.get(batches), yBatch); + assertEquals(mapped1.get(batches), XBatch); + assertEquals(mapped2.get(batches), yBatch); - batches++; - } + batches++; } catch (TFOutOfRangeException e) { break; } @@ -113,8 +113,9 @@ public void testEagerIteration() { int count = 0; for (List> outputs : dataset) { - try (TInt32 XBatch = (TInt32)outputs.get(0).asTensor(); - TInt32 yBatch = (TInt32)outputs.get(1).asTensor()) { + try (TensorScope scope = new TensorScope()) { + TInt32 XBatch = (TInt32) outputs.get(0).asTensor(scope); + TInt32 yBatch = (TInt32) outputs.get(1).asTensor(scope); assertEquals(mapped1.get(count), XBatch); assertEquals(mapped2.get(count), yBatch); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java index d0cdb4527a5..70e297fa1e9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java @@ -15,32 +15,34 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - public class SkipDatasetTest extends DatasetTestBase { + @Test public void testEagerSkipDataset() { Ops tf = Ops.create(); Dataset dataset = Dataset.fromTensorSlices( - tf, - Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.class, TInt32.class)) + tf, + Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), + Arrays.asList(TInt32.class, TInt32.class)) .skip(2); int count = 2; for (List> components : dataset) { - try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); - TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + try (TensorScope scope = new TensorScope()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); assertEquals(testMatrix1.get(count), batch1); assertEquals(testMatrix2.get(count), batch2); count++; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java index 79a2e79c72e..19fbaf9da61 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java @@ -15,16 +15,16 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - public class TakeDatasetTest extends DatasetTestBase { @Test @@ -33,15 +33,16 @@ public void testEagerTakeDataset() { Dataset dataset = Dataset.fromTensorSlices( - tf, - Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.class, TInt32.class)) + tf, + Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), + Arrays.asList(TInt32.class, TInt32.class)) .take(4); int count = 0; for (List> components : dataset) { - try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); - TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + try (TensorScope scope = new TensorScope()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); assertEquals(testMatrix1.get(count), batch1); assertEquals(testMatrix2.get(count), batch2); count++; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index 49154882a0f..c3121325a48 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -14,8 +14,19 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; +import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -29,13 +40,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; -import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; - /** Test cases for Adam Optimizer */ public class AdamTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -138,24 +142,26 @@ public void testBasic() { final float[] powers = { (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) }; - - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("beta1_power") - .run() - .get(0)) { - result.scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); - } - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("beta2_power") - .run() - .get(0)) { - result.scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + try (TensorScope scope = new TensorScope()) { + + try (TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run(scope) + .get(0)) { + result.scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); + } + try (TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("beta2_power") + .run(scope) + .get(0)) { + result.scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + } } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index 60c17674dfe..920b3a38212 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -14,8 +14,22 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Adamax.BETA_ONE_DEFAULT; +import static org.tensorflow.framework.optimizers.Adamax.BETA_TWO_DEFAULT; +import static org.tensorflow.framework.optimizers.Adamax.FIRST_MOMENT; +import static org.tensorflow.framework.optimizers.Adamax.GradAndVar; +import static org.tensorflow.framework.optimizers.Adamax.SECOND_MOMENT; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -29,35 +43,39 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.Adamax.*; - -/** Test cases for Adamax Optimizer */ +/** + * Test cases for Adamax Optimizer + */ public class AdamaxTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private static final int VAR = 0; private static final int M = 1; private static final int V = 2; - public AdamaxTest() {} + public AdamaxTest() { + } @BeforeAll - public static void setUpClass() {} + public static void setUpClass() { + } @AfterAll - public static void tearDownClass() {} + public static void tearDownClass() { + } @BeforeEach - public void setUp() {} + public void setUp() { + } @AfterEach - public void tearDown() {} + public void tearDown() { + } - /** Test of getOptimizerName method, of class Adamax. */ + /** + * Test of getOptimizerName method, of class Adamax. + */ @Test public void testGetOptimizerName() { try (TestSession session = TestSession.createTestSession(tfMode)) { @@ -69,7 +87,9 @@ public void testGetOptimizerName() { } } - /** Test of applyDense method, of class Adamax. */ + /** + * Test of applyDense method, of class Adamax. + */ @Test public void testBasic() { @@ -148,13 +168,14 @@ public void testBasic() { // Test powers final float beta1Power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("beta1_power") - .run() - .get(0)) { + try (TensorScope scope = new TensorScope()) { + TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run(scope) + .get(0); result.scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index 849f2fbfec1..7d96bc37446 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -14,8 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -29,13 +38,11 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -/** Test cases for Nadam Optimizer */ +/** + * Test cases for Nadam Optimizer + */ public class NadamTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private static final int VAR = 0; @@ -44,21 +51,28 @@ public class NadamTest { float momentum = 1; - public NadamTest() {} + public NadamTest() { + } @BeforeAll - public static void setUpClass() {} + public static void setUpClass() { + } @AfterAll - public static void tearDownClass() {} + public static void tearDownClass() { + } @BeforeEach - public void setUp() {} + public void setUp() { + } @AfterEach - public void tearDown() {} + public void tearDown() { + } - /** Test of getOptimizerName method, of class Nadam. */ + /** + * Test of getOptimizerName method, of class Nadam. + */ @Test public void testGetOptimizerName() { try (TestSession session = TestSession.createTestSession(tfMode)) { @@ -70,7 +84,9 @@ public void testGetOptimizerName() { } } - /** Test of applyDense method, of class Nadam. */ + /** + * Test of applyDense method, of class Nadam. + */ @Test public void testBasic() { @@ -146,13 +162,15 @@ public void testBasic() { session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0)) { + try (TensorScope scope = new TensorScope()) { + + TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("momentum") + .run(scope) + .get(0); result.scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; @@ -165,13 +183,14 @@ public void testBasic() { Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); momentum = momentum * mut; - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0)) { + try (TensorScope scope = new TensorScope()) { + TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("momentum") + .run(scope) + .get(0); result.scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); @@ -198,6 +217,7 @@ public void testBasic() { session.evaluate(var1Np, var1); } } + } private FloatNdArray[] nadamUpdateNdArray( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java index 7884308c9fb..0db5784082b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java @@ -14,35 +14,53 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Predicate; +import org.tensorflow.EagerSession; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Predicate; - -import static org.junit.jupiter.api.Assertions.*; - -/** Eager Mode Test Session */ +/** + * Eager Mode Test Session + */ public class EagerTestSession extends TestSession { private final EagerSession session; private final Ops tf; - /** Create an Eager mode test session. */ + /** + * Create an Eager mode test session. + */ public EagerTestSession() { this.session = EagerSession.create(); this.tf = Ops.create(session).withName("test"); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public Ops getTF() { return tf; @@ -57,702 +75,746 @@ public EagerSession getSession() { return session; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void close() { session.close(); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public boolean isEager() { return true; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public Session getGraphSession() { return null; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public EagerSession getEagerSession() { return this.session; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(double expected, Operand input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } else if (inputType == TInt64.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } else if (inputType == TUint8.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + try (TensorScope scope = new TensorScope()) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); + } else if (inputType == TFloat64.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); + } else if (inputType == TInt32.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals((int) expected, f.getInt())); + } else if (inputType == TInt64.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals((long) expected, f.getLong())); + } else if (inputType == TUint8.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Number[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + Class inputType = input.type(); + if (inputType == TFloat32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> + assertEquals( + expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); + } else if (inputType == TFloat64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> + assertEquals( + expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); + } else if (inputType == TInt32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); + } else if (inputType == TInt64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); + } else if (inputType == TUint8.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%x). %d\n", index.getAndIncrement(), f.getByte())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%x). %d\n", index.getAndIncrement(), f.getByte())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()].byteValue(), f.getByte())); } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].byteValue(), f.getByte())); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(FloatNdArray expected, Output input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicLong index = new AtomicLong(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + try (TensorScope scope = new TensorScope()) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + Output o = (Output) input; + AtomicLong index = new AtomicLong(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); + } else if (inputType == TFloat64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - for (IntNdArray f : o.asTensor().scalars()) { - assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); - } - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> + assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); + } else if (inputType == TInt32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + index.set(0); + for (IntNdArray f : o.asTensor(scope).scalars()) { + assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); + } + } else if (inputType == TInt64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); + } else if (inputType == TUint8.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + .forEach( + f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluateString(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %s\n", predicate.test(input.asTensor().getObject()), input.asTensor().getObject()); - } else { - input - .asTensor() - .scalars() - .forEachIndexed( - (idx, s) -> - System.out.printf( - "%d). %b <==> %s\n", - index.getAndIncrement(), predicate.test(s.getObject()), s.getObject())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(input.asTensor().getObject())); - } else { - input.asTensor().scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); - } - } - - /** {@inheritDoc} */ - @Override - public void evaluate(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - Output o = (Output) input; + try (TensorScope scope = new TensorScope()) { + AtomicInteger index = new AtomicInteger(); + boolean isScalar = input.shape().equals(Shape.scalar()); if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); + "0). %b <==> %s\n", predicate.test(input.asTensor(scope).getObject()), input.asTensor(scope).getObject()); } else { - o.asTensor() + input + .asTensor(scope) .scalars() .forEachIndexed( - (idx, f) -> + (idx, s) -> System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + "%d). %b <==> %s\n", + index.getAndIncrement(), predicate.test(s.getObject()), s.getObject())); } } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.asTensor().getFloat())); + assertTrue(predicate.test(input.asTensor(scope).getObject())); } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); + input.asTensor(scope).scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); } - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - if (debug) { + } + } + + /** + * {@inheritDoc} + */ + @Override + public void evaluate(Output input, Predicate predicate) { + try (TensorScope scope = new TensorScope()) { + AtomicInteger index = new AtomicInteger(); + Class inputType = input.type(); + boolean isScalar = input.shape().equals(Shape.scalar()); + if (inputType == TFloat32.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", predicate.test(o.asTensor(scope).getFloat()), o.asTensor(scope).getFloat()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getDouble()), o.asTensor().getDouble()); + assertTrue(predicate.test(o.asTensor(scope).getFloat())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getFloat()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getDouble())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getDouble()))); - } - } else if (inputType == TFloat16.class) { - Output o = (Output) input; - if (debug) { + } else if (inputType == TFloat64.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", predicate.test(o.asTensor(scope).getDouble()), o.asTensor(scope).getDouble()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); + assertTrue(predicate.test(o.asTensor(scope).getDouble())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getDouble()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getFloat())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); - } - } else if (inputType == TInt32.class) { - Output o = (Output) input; - if (debug) { + } else if (inputType == TFloat16.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", predicate.test(o.asTensor(scope).getFloat()), o.asTensor(scope).getFloat()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(o.asTensor().getInt()), o.asTensor().getInt()); + assertTrue(predicate.test(o.asTensor(scope).getFloat())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getFloat()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getInt())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getInt()))); - } - } else if (inputType == TInt64.class) { - Output o = (Output) input; - if (debug) { + } else if (inputType == TInt32.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", predicate.test(o.asTensor(scope).getInt()), o.asTensor(scope).getInt()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(o.asTensor().getLong()), o.asTensor().getLong()); + assertTrue(predicate.test(o.asTensor(scope).getInt())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getInt()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getLong())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getLong()))); - } - } else if (inputType == TUint8.class) { - Output o = (Output) input; - if (debug) { + } else if (inputType == TInt64.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", predicate.test(o.asTensor(scope).getLong()), o.asTensor(scope).getLong()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %x\n", predicate.test(o.asTensor().getByte()), o.asTensor().getByte()); + assertTrue(predicate.test(o.asTensor(scope).getLong())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %x\n", - index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getLong()))); + } + } else if (inputType == TUint8.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %x\n", predicate.test(o.asTensor(scope).getByte()), o.asTensor(scope).getByte()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %x\n", + index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); + } + } + index.set(0); + if (isScalar) { + assertTrue(predicate.test(o.asTensor(scope).getByte())); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getByte()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getByte())); } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getByte()))); + fail("Unexpected Class: " + inputType); } - } else { - fail("Unexpected Class: " + inputType); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(String[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + AtomicInteger index = new AtomicInteger(); + if (debug) { + input + .asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + } + index.set(0); input - .asTensor() + .asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } - index.set(0); - input - .asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Boolean[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected size (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected size (%d) != to input length (%d)", expected.length, size)); + AtomicInteger index = new AtomicInteger(); + if (debug) { + input + .asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + } + index.set(0); input - .asTensor() + .asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean())); } - index.set(0); - input - .asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean())); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Output expected, Output input) { - assert input.shape().equals(expected.shape()) - : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); - Class inputType = input.asOutput().type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + assert input.shape().equals(expected.shape()) + : String.format( + "expected shape (%s) != to input shape (%s)", + expected.shape().toString(), input.shape().toString()); + Class inputType = input.asOutput().type(); + boolean isScalar = input.shape().equals(Shape.scalar()); + if (inputType == TFloat32.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %f <==> %f\n", x.asTensor(scope).getFloat(), o.asTensor(scope).getFloat()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), x.asTensor(scope).getFloat(idx), f.getFloat())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.asTensor().getFloat(), o.asTensor().getFloat()); + assertEquals(x.asTensor(scope).getFloat(), o.asTensor(scope).getFloat(), epsilon); } else { - o.asTensor() + o.asTensor(scope) .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), x.asTensor().getFloat(idx), f.getFloat())); + (idx, f) -> assertEquals(x.asTensor(scope).getFloat(idx), f.getFloat(), epsilon)); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getFloat(), o.asTensor().getFloat(), epsilon); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(x.asTensor().getFloat(idx), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TFloat64.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %f <==> %f\n", x.asTensor(scope).getDouble(), o.asTensor(scope).getDouble()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), x.asTensor(scope).getDouble(idx), f.getDouble())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.asTensor().getDouble(), o.asTensor().getDouble()); + assertEquals(x.asTensor(scope).getDouble(), o.asTensor(scope).getDouble(), epsilon); } else { - o.asTensor() + o.asTensor(scope) .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), x.asTensor().getDouble(idx), f.getDouble())); + (idx, f) -> assertEquals(x.asTensor(scope).getDouble(idx), f.getDouble(), epsilon)); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getDouble(), o.asTensor().getDouble(), epsilon); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(x.asTensor().getDouble(idx), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt32.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %d <==> %d\n", x.asTensor(scope).getInt(), o.asTensor(scope).getInt()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), x.asTensor(scope).getInt(idx), f.getInt())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.asTensor().getInt(), o.asTensor().getInt()); + assertEquals(x.asTensor(scope).getInt(), o.asTensor(scope).getInt()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), x.asTensor().getInt(idx), f.getInt())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getInt(idx), f.getInt())); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getInt(), o.asTensor().getInt()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getInt(idx), f.getInt())); - } - } else if (inputType == TInt64.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt64.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %d <==> %d\n", x.asTensor(scope).getLong(), o.asTensor(scope).getLong()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), x.asTensor(scope).getLong(idx), f.getLong())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.asTensor().getLong(), o.asTensor().getLong()); + assertEquals(x.asTensor(scope).getLong(), o.asTensor(scope).getLong()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), x.asTensor().getLong(idx), f.getLong())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getLong(idx), f.getLong())); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getLong(), o.asTensor().getLong()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getLong(idx), f.getLong())); - } - } else if (inputType == TUint8.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TUint8.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %x <==> %x\n", x.asTensor(scope).getByte(), o.asTensor(scope).getByte()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %x <==> %x\n", + index.getAndIncrement(), x.asTensor(scope).getByte(idx), f.getByte())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %x <==> %x\n", x.asTensor().getByte(), o.asTensor().getByte()); + assertEquals(x.asTensor(scope).getByte(), o.asTensor(scope).getByte()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %x <==> %x\n", - index.getAndIncrement(), x.asTensor().getByte(idx), f.getByte())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getByte(idx), f.getByte())); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getByte(), o.asTensor().getByte()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getByte(idx), f.getByte())); - } - } else if (inputType == TString.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TString.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %s <==> %s\n", x.asTensor(scope).getObject(), o.asTensor(scope).getObject()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %s <==> %s\n", + index.getAndIncrement(), x.asTensor(scope).getObject(idx), f.getObject())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %s <==> %s\n", x.asTensor().getObject(), o.asTensor().getObject()); + assertEquals(x.asTensor(scope).getObject(), o.asTensor(scope).getObject()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %s <==> %s\n", - index.getAndIncrement(), x.asTensor().getObject(idx), f.getObject())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getObject(idx), f.getObject())); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getObject(), o.asTensor().getObject()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getObject(idx), f.getObject())); - } - } else if (inputType == TBool.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TBool.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %b <==> %b\n", x.asTensor(scope).getBoolean(), o.asTensor(scope).getBoolean()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %b\n", + index.getAndIncrement(), x.asTensor(scope).getBoolean(idx), f.getBoolean())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %b <==> %b\n", x.asTensor().getBoolean(), o.asTensor().getBoolean()); + assertEquals(x.asTensor(scope).getBoolean(), o.asTensor(scope).getBoolean()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %b\n", - index.getAndIncrement(), x.asTensor().getBoolean(idx), f.getBoolean())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getBoolean(idx), f.getBoolean())); } } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getBoolean(), o.asTensor().getBoolean()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getBoolean(idx), f.getBoolean())); - } } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void print(PrintWriter writer, Output input) { - Class inputType = input.asOutput().type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } else if (inputType == TString.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } else if (inputType == TBool.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); - } else { - writer.println("Unexpected Class: " + inputType); + try (TensorScope scope = new TensorScope()) { + Class inputType = input.asOutput().type(); + if (inputType == TFloat32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } else if (inputType == TFloat64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } else if (inputType == TInt32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } else if (inputType == TInt64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } else if (inputType == TUint8.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + } else if (inputType == TString.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + } else if (inputType == TBool.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + } else { + writer.println("Unexpected Class: " + inputType); + } + writer.flush(); } - writer.flush(); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 43c0642939e..70e2b92ba67 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -14,22 +14,35 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Predicate; +import org.tensorflow.EagerSession; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Predicate; - -import static org.junit.jupiter.api.Assertions.*; - /** * Graph Mode Test Session */ @@ -110,7 +123,9 @@ public EagerSession getEagerSession() { */ @Override public void initialize() { - graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run()); + try (TensorScope scope = new TensorScope()) { + graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run(scope)); + } } /** @@ -126,84 +141,86 @@ public void run(Op op) { */ @Override public void evaluate(double expected, Operand input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TFloat64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + } + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + } + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + } + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TUint8.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + } + } + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -212,108 +229,108 @@ public void evaluate(double expected, Operand input) { */ @Override public void evaluate(Number[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - if (size != Shape.UNKNOWN_SIZE) { - assertEquals( + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + if (size != Shape.UNKNOWN_SIZE) {assertEquals( expected.length, - size, - () -> - String.format("expected length (%d) != to input length (%d)", expected.length, size)); + size,() -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); } - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + .forEach( + f -> + assertEquals( + expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TFloat64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + } + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + .forEach( + f -> + assertEquals( + expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + } + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + } + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TUint8.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + } + } + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -322,103 +339,105 @@ public void evaluate(Number[] expected, Output input) { */ @Override public void evaluate(FloatNdArray expected, Output input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicLong index = new AtomicLong(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + AtomicLong index = new AtomicLong(); + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + .forEach( + f -> + assertEquals( + expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TFloat64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + } + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + .forEach( + f -> + assertEquals( + expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + } + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + .forEach( + f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + } + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + .forEach( + f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TUint8.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + } + } + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + .forEach( + f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -427,30 +446,30 @@ public void evaluate(FloatNdArray expected, Output input) { */ @Override public void evaluate(String[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - if (size != Shape.UNKNOWN_SIZE) { - assertEquals( + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + if (size != Shape.UNKNOWN_SIZE) {assertEquals( expected.length, - size, - () -> - String.format("expected length (%d) != to input length (%d)", expected.length, size)); + size,() -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); } - AtomicInteger index = new AtomicInteger(); - if (debug) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TString result = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + } + } + index.set(0); try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); - } } /** @@ -458,27 +477,29 @@ public void evaluate(String[] expected, Output input) { */ @Override public void evaluate(Boolean[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TBool result = + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getObject())); + } + } + index.set(0); try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getObject())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } } - index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); - } } /** @@ -486,320 +507,322 @@ public void evaluate(Boolean[] expected, Output input) { */ @Override public void evaluate(Output expected, Output input) { - assert input.shape().equals(expected.shape()) - : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - if (!inputType.equals(expected.type())) { - throw new IllegalArgumentException( - String.format( - "Both data type must be equal, inout = %s, expected = %s", - inputType, expected.dataType())); - } - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - final Output finalExpected = (Output) expected; - if (debug) { + try (TensorScope scope = new TensorScope()) { + assert input.shape().equals(expected.shape()) + : String.format( + "expected shape (%s) != to input shape (%s)", + expected.shape().toString(), input.shape().toString()); + AtomicInteger index = new AtomicInteger(); + Class inputType = input.type(); + if (!inputType.equals(expected.type())) { + throw new IllegalArgumentException( + String.format( + "Both data type must be equal, inout = %s, expected = %s", + inputType, expected.dataType())); + } + boolean isScalar = input.shape().equals(Shape.scalar()); + if (inputType == TFloat32.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TFloat32 expectedResult = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getFloat(idx), + f.getFloat())); + } + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0); TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getFloat(idx), - f.getFloat())); + assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); + } else if (inputType == TFloat64.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TFloat64 expectedResult = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getDouble(idx), + f.getDouble())); + } + } } - } - } else if (inputType == TFloat64.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0); TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); + assertEquals(expectedResult.getDouble(), result.getDouble(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getDouble(idx), - f.getDouble())); + assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); } } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getDouble(), result.getDouble(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); + } else if (inputType == TFloat16.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TFloat16 result = + (TFloat16) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TFloat16 expectedResult = + (TFloat16) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getFloat(idx), + f.getFloat())); + } + } } - } - } else if (inputType == TFloat16.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); + (TFloat16) this.getGraphSession().runner().fetch(input).run(scope).get(0); TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat16) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getFloat(idx), - f.getFloat())); + assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } - } - index.set(0); - try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); + } else if (inputType == TInt32.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TInt32 expectedResult = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), finalExpected.asTensor(scope).getInt(idx), f.getInt())); + } + } } - } - } else if (inputType == TInt32.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0); TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); + assertEquals(expectedResult.getInt(), result.getInt(), epsilon); } else { result .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), finalExpected.asTensor().getInt(idx), f.getInt())); + (idx, f) -> assertEquals(expectedResult.getInt(idx), f.getInt(), epsilon)); } } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); - TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getInt(), result.getInt(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getInt(idx), f.getInt(), epsilon)); + } else if (inputType == TInt64.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TInt64 expectedResult = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getLong(idx), + f.getLong())); + } + } } - } - } else if (inputType == TInt64.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0); TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); + assertEquals(expectedResult.getLong(), result.getLong(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), - finalExpected.asTensor().getLong(idx), - f.getLong())); + assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); } } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); - TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getLong(), result.getLong(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); + } else if (inputType == TUint8.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TUint8 expectedResult = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getByte(idx), + f.getByte())); + } + } } - } - } else if (inputType == TUint8.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0); TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); + assertEquals(expectedResult.getByte(), result.getByte(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), - finalExpected.asTensor().getByte(idx), - f.getByte())); + assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); } } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); - TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getByte(), result.getByte(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); + } else if (inputType == TBool.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TBool result = + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TBool expectedResult = + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %b\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getBoolean(idx), + f.getBoolean())); + } + } } - } - } else if (inputType == TBool.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0); TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); + assertEquals(expectedResult.getBoolean(), result.getBoolean()); } else { result .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %b\n", - index.getAndIncrement(), - finalExpected.asTensor().getBoolean(idx), - f.getBoolean())); + (idx, f) -> assertEquals(expectedResult.getBoolean(idx), f.getBoolean())); } } - } - index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); - TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getBoolean(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getBoolean(idx), f.getBoolean())); + } else if (inputType == TString.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TString result = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TString expectedResult = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %s <==> %s\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getObject(idx), + f.getObject())); + } + } } - } - } else if (inputType == TString.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0); TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); + assertEquals(expectedResult.getObject(), result.getObject()); } else { result .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %s <==> %s\n", - index.getAndIncrement(), - finalExpected.asTensor().getObject(idx), - f.getObject())); + (idx, f) -> assertEquals(expectedResult.getObject(idx), f.getObject())); } } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); - TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getObject(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getObject(idx), f.getObject())); - } - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -808,37 +831,39 @@ public void evaluate(Output expected, Output input) { */ @Override public void evaluateString(Output input, Predicate predicate) { - boolean isScalar = input.shape().equals(Shape.scalar()); - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + boolean isScalar = input.shape().equals(Shape.scalar()); + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TString result = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %s\n", + predicate.test(result.getObject()), result.getObject()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %s\n", + index.getAndIncrement(), predicate.test(f.getObject()), f.getObject())); + } + } + } + index.set(0); try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %s\n", - predicate.test(result.getObject()), result.getObject()); + assertTrue(predicate.test(result.getObject())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %s\n", - index.getAndIncrement(), predicate.test(f.getObject()), f.getObject())); + .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); } } } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getObject())); - } else { - result - .scalars() - .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); - } - } } /** @@ -846,160 +871,162 @@ public void evaluateString(Output input, Predicate predicate) { */ @Override public void evaluate(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - if (debug) { + try (TensorScope scope = new TensorScope()) { + AtomicInteger index = new AtomicInteger(); + Class inputType = input.type(); + boolean isScalar = input.shape().equals(Shape.scalar()); + if (inputType == TFloat32.class) { + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", + predicate.test(result.getFloat()), result.getFloat()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + } + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getFloat()), result.getFloat()); + assertTrue(predicate.test(result.getFloat())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getFloat()))); } } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getFloat())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getFloat()))); + } else if (inputType == TFloat64.class) { + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", + predicate.test(result.getDouble()), result.getDouble()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); + } + } } - } - } else if (inputType == TFloat64.class) { - if (debug) { + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getDouble()), result.getDouble()); + assertTrue(predicate.test(result.getDouble())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getDouble()))); } } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getDouble())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getDouble()))); + } else if (inputType == TInt32.class) { + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); + } + } } - } - } else if (inputType == TInt32.class) { - if (debug) { + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); + assertTrue(predicate.test(result.getInt())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); } } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getInt())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); + } else if (inputType == TInt64.class) { + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", + predicate.test(result.getLong()), result.getLong()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); + } + } } - } - } else if (inputType == TInt64.class) { - if (debug) { + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getLong()), result.getLong()); + assertTrue(predicate.test(result.getLong())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); } } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getLong())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); + } else if (inputType == TUint8.class) { + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", + predicate.test(result.getByte()), result.getByte()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); + } + } } - } - } else if (inputType == TUint8.class) { - if (debug) { + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getByte()), result.getByte()); + assertTrue(predicate.test(result.getByte())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); } } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getByte())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); - } - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -1008,115 +1035,117 @@ public void evaluate(Output input, Predicate predic */ @Override public void print(PrintWriter writer, Output input) { - boolean isScalar = input.shape().size() == 1; + try (TensorScope scope = new TensorScope()) { + boolean isScalar = input.shape().size() == 1; - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf("%d). %f\n", index.getAndIncrement(), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + Class inputType = input.type(); + if (inputType == TFloat32.class) { + AtomicInteger index = new AtomicInteger(); + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf("%d). %f\n", index.getAndIncrement(), result.getFloat()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } } - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TFloat64.class) { + AtomicInteger index = new AtomicInteger(); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %f\n", index.getAndIncrement(), result.getDouble()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %f\n", index.getAndIncrement(), result.getDouble()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } } - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TInt32.class) { + AtomicInteger index = new AtomicInteger(); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(),result.getInt()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %d\n", index.getAndIncrement(), result.getInt()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } } - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TInt64.class) { + AtomicInteger index = new AtomicInteger(); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(), result.getLong()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %d\n", index.getAndIncrement(), result.getLong()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } } - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TUint8.class) { + AtomicInteger index = new AtomicInteger(); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %x\n", index.getAndIncrement(), result.getByte()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %x\n", index.getAndIncrement(), result.getByte()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + } } - } - } else if (inputType == TBool.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TBool.class) { + AtomicInteger index = new AtomicInteger(); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %b\n", index.getAndIncrement(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + try (TBool result = + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %b\n", index.getAndIncrement(), result.getBoolean()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + } } - } - } else if (inputType == TString.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TString.class) { + AtomicInteger index = new AtomicInteger(); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %s\n", index.getAndIncrement(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + try (TString result = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %s\n", index.getAndIncrement(), result.getObject()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + } } + } else { + writer.println("Unexpected type class: " + inputType); } - } else { - writer.println("Unexpected type class: " + inputType); + writer.flush(); } - writer.flush(); } } From 9a846dc5a20138dc7d2582674faf78b06dc1a747 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 28 Jan 2021 21:47:41 -0800 Subject: [PATCH 30/35] Set TF_Tensor size properly, update docs Signed-off-by: Ryan Nett --- .../src/main/java/org/tensorflow/Tensor.java | 21 +++++++--- .../main/java/org/tensorflow/TensorScope.java | 22 ++++++++-- .../internal/c_api/AbstractTF_Tensor.java | 41 ++++++++++++++++--- 3 files changed, 70 insertions(+), 14 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index d5c93bf3c42..1f13d4e9f32 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -16,6 +16,7 @@ package org.tensorflow; import java.util.function.Consumer; +import org.bytedeco.javacpp.Pointer; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.Shaped; import org.tensorflow.ndarray.buffer.ByteDataBuffer; @@ -30,16 +31,23 @@ * allowing direct I/O operations from the JVM, while the latter is only a reference to a native tensor allowing basic * operations and flat data access.

* - *

WARNING: Resources consumed by the Tensor object must be explicitly freed by - * invoking the {@link #close()} method when the object is no longer needed. For example, using a try-with-resources - * block: + *

WARNING: Resources consumed by the Tensor object should be explicitly freed by + * invoking the {@link #close()} method on the tensor, or using {@link TensorScope}. For example, using a + * try-with-resources block: * *

{@code
- * try (Tensor t = Tensor.of(...)) {
+ * try (TensorScope scope = new TensorScope()) {
+ *   Tensor t = Tensor.of(scope, ...);
  *   doSomethingWith(t);
  * }
  * }
- *

This can (and probably should) be done using {@link TensorScope}. + * + * Dropped tensors will be closed when GC'd, but relying on the garbage collector for cleanup is inefficient. + * + * + *

JavaCPP properties are used to manage garbage collection, see {@link Pointer}. Specifically + * {@link Pointer#maxBytes} and {@link Pointer#maxPhysicalBytes}. + * *

Instances of a Tensor are not thread-safe. */ public interface Tensor extends Shaped, AutoCloseable { @@ -210,6 +218,9 @@ static T of(TensorScope scope, Class type, Shape shape, Byt *

All tensors should be closed using this method or {@link TensorScope}. * Memory will not leak if they aren't, but relying on the garbage collector for cleanup is not efficient. * + *

JavaCPP properties are used to manage garbage collection, see {@link Pointer}. Specifically + * {@link Pointer#maxBytes} and {@link Pointer#maxPhysicalBytes}. + * *

The Tensor object is no longer usable after {@code close} returns. */ @Override diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index 99197c948e4..be0eddcf382 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -20,18 +20,34 @@ import java.util.HashSet; import java.util.Set; import java.util.WeakHashMap; +import org.bytedeco.javacpp.Pointer; /** * A scope used to manage tensor resources. All tensor-creating methods take a scope as a parameter, and create their * tensors in that scope. When a scope is closed, it closes all of it's attached tensors. Tensors may be manually - * closed earlier without issue, and being attached to a scope will not keep a tensor from being GC'd. + * closed earlier without issue, and being attached to a scope will not keep a tensor from being GC'd. Using a {@code + * TensorScope} is recommended over manually closing every tensor. For example, using a try-with-resources block: + * + *

{@code
+ * try (TensorScope scope = new TensorScope()) {
+ *   Tensor t = Tensor.of(scope, ...);
+ *   doSomethingWith(t);
+ *   Tensor t2 = Tensor.of(scope, ...);
+ *   doSomething2(t);
+ * }
+ * }
+ * *

While tensors will be closed when GC'd, relying on the garbage collector for cleanup is not efficient. This - * class - * or manual management should be used. + * class or manual management should be used. + * + *

JavaCPP properties are used to manage garbage collection, see {@link Pointer}. Specifically + * {@link Pointer#maxBytes} and {@link Pointer#maxPhysicalBytes}. + * *

* {@link TensorScope#detach(Tensor)} and {@link Tensor#detach()} detaches the tensor from it's scope, requiring the * user to close it manually or attach it to another scope. + * *

* Like Tensors, TensorScope is not thread safe. */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java index fba056c6dcb..e1ceaeff5f5 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java @@ -20,29 +20,58 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_AllocateTensor; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteTensor; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewTensor; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorByteSize; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.annotation.Properties; @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_Tensor extends Pointer { + protected static class DeleteDeallocator extends TF_Tensor implements Pointer.Deallocator { - DeleteDeallocator(TF_Tensor s) { super(s); } - @Override public void deallocate() { if (!isNull()) TF_DeleteTensor(this); setNull(); } + + DeleteDeallocator(TF_Tensor s) { + super(s); + // ideally this would be TF_TensorElementCount and sizeof would be the datatype size, + // but datatype isn't stored anywhere and may be variably sized + s.capacity = TF_TensorByteSize(s); + } + + @Override + public void deallocate() { + if (!isNull()) { + TF_DeleteTensor(this); + } + setNull(); + } } - /** TensorFlow crashes if we don't pass it a deallocator, so... */ + /** + * TensorFlow crashes if we don't pass it a deallocator, so... + */ protected static Deallocator_Pointer_long_Pointer dummyDeallocator = new Deallocator_Pointer_long_Pointer() { - @Override public void call(Pointer data, long len, Pointer arg) { } + @Override + public void call(Pointer data, long len, Pointer arg) { + } }.retainReference(); - /** A reference to prevent deallocation. */ + /** + * A reference to prevent deallocation. + */ protected Pointer pointer; - public AbstractTF_Tensor(Pointer p) { super(p); } + public AbstractTF_Tensor(Pointer p) { + super(p); + } + + @Override + public int sizeof() { + return 1; + } /** * Calls TF_NewTensor(), and registers a deallocator. + * * @return TF_Tensor created. Do not call TF_DeleteTensor() on it. */ public static TF_Tensor newTensor(int dtype, long[] dims, Pointer data) { From a10043e64a1365a8b86c6763c2452100322e3dd3 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 28 Jan 2021 21:50:09 -0800 Subject: [PATCH 31/35] remove unneeded synchronizeds Signed-off-by: Ryan Nett --- .../main/java/org/tensorflow/TensorScope.java | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java index be0eddcf382..ecb8cadabae 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -65,11 +65,10 @@ public TensorScope() { /** * Closes this scope and its tensors. *

All tensors should be closed using this method or {@link Tensor#close()}. - * Memory will not leak if they aren't, but relying on the garbage collector for cleanup is - * not efficient. + * Memory will not leak if they aren't, but relying on the garbage collector for cleanup is not efficient. */ @Override - public synchronized void close() { + public void close() { if (closed) { return; } @@ -85,7 +84,7 @@ public synchronized void close() { * * @return All of this scope's now-detached tensors */ - public synchronized Set detachAll() { + public Set detachAll() { Set detachedTensors = new HashSet<>(this.tensors); detachedTensors.forEach(TensorScope::detach); closed = true; @@ -96,11 +95,9 @@ public synchronized Set detachAll() { public static T detach(T tensor) { // ensure that I'm not attaching or detaching at the same time in different threads RawTensor rt = tensor.asRawTensor(); - synchronized (rt) { - if (rt.tensorScope != null) { - rt.tensorScope.tensors.remove(rt); - rt.tensorScope = null; - } + if (rt.tensorScope != null) { + rt.tensorScope.tensors.remove(rt); + rt.tensorScope = null; } return tensor; } @@ -154,18 +151,15 @@ public static void detach(Iterable... tensors) { * * @return this */ - public synchronized T attach(T tensor) { + public T attach(T tensor) { if (this.closed) { throw new IllegalStateException("Scope has been closed, can not attach new tensor."); } RawTensor rt = tensor.asRawTensor(); - // ensure that I'm not attaching or detaching at the same time in different threads - synchronized (rt) { - detach(tensor); - rt.tensorScope = this; - tensors.add(rt); - } + detach(tensor); + rt.tensorScope = this; + tensors.add(rt); return tensor; } @@ -249,7 +243,7 @@ public final TensorScope withTensors(Iterable... tensors) { /** * Gets whether the scope is closed. */ - public synchronized boolean isClosed() { + public boolean isClosed() { return closed; } From eabf063ce890e4375b117fb80bf94e4310ddd5db Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Fri, 29 Jan 2021 16:42:12 -0800 Subject: [PATCH 32/35] Don't register memory for view tensors (i.e. from eager operations) Signed-off-by: Ryan Nett --- .../java/org/tensorflow/EagerOperation.java | 2 +- .../internal/c_api/AbstractTF_Tensor.java | 25 +++++++++++++------ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 407efe9bf32..c5f261215b0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -164,7 +164,7 @@ private static Tensor resolveTensorHandle(TFE_TensorHandle handle, TensorScope t requireTensorHandle(handle); try (PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); - TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(); + TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(true); status.throwExceptionIfNotOK(); return RawTensor.fromHandle(tensorScope, tensor).asTypedTensor(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java index e1ceaeff5f5..940f717806a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java @@ -30,11 +30,20 @@ public abstract class AbstractTF_Tensor extends Pointer { protected static class DeleteDeallocator extends TF_Tensor implements Pointer.Deallocator { - DeleteDeallocator(TF_Tensor s) { + DeleteDeallocator(TF_Tensor s, boolean registerMemory) { super(s); - // ideally this would be TF_TensorElementCount and sizeof would be the datatype size, - // but datatype isn't stored anywhere and may be variably sized - s.capacity = TF_TensorByteSize(s); + // some tensors, like those from eager operations, are just views + if (registerMemory) { + // ideally this would be TF_TensorElementCount and sizeof would be the datatype size, + // but datatype isn't stored anywhere and may be variably sized + s.capacity = TF_TensorByteSize(s); + } else { + s.capacity = 0; + } + } + + DeleteDeallocator(TF_Tensor s) { + this(s, true); } @Override @@ -95,9 +104,11 @@ public static TF_Tensor allocateTensor(int dtype, long[] dims, long length) { return t; } - /** Registers a deallocator and returns this. */ - public TF_Tensor withDeallocator() { - return (TF_Tensor)this.deallocator(new DeleteDeallocator((TF_Tensor)this)); + /** + * Registers a deallocator and returns this. + */ + public TF_Tensor withDeallocator(boolean isView) { + return (TF_Tensor) this.deallocator(new DeleteDeallocator((TF_Tensor) this, isView)); } /** From cf8b24ba6e0a2ea648dab27568da1ee9e626bf36 Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Mon, 15 Feb 2021 18:05:52 -0800 Subject: [PATCH 33/35] Initial Rebase fixes Signed-off-by: Ryan Nett --- .../java/org/tensorflow/op/core/Constant.java | 64 ++++++++++--------- .../org/tensorflow/AutoCloseableList.java | 27 -------- .../java/org/tensorflow/WrongEnvTest.java | 5 +- .../org/tensorflow/op/core/ConstantTest.java | 11 ++-- .../org/tensorflow/op/core/IndexingTest.java | 9 ++- .../framework/utils/ShapeUtils.java | 4 ++ 6 files changed, 50 insertions(+), 70 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 21f5794186b..efe6a06e80b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -38,13 +38,13 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; +import org.tensorflow.op.Ops; import org.tensorflow.op.RawOp; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TBool; import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat16; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; @@ -1360,36 +1360,38 @@ public static Constant tensorOf(Scope scope, Shape shape) { @SuppressWarnings("unchecked") @Endpoint public static Constant tensorOf(Scope scope, Class type, Number number) { - if (type.equals(TBfloat16.class)) { - try (TBfloat16 tensor = TBfloat16.scalarOf(number.floatValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TFloat64.class)) { - try (TFloat64 tensor = TFloat64.scalarOf(number.doubleValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TFloat32.class)) { - try (TFloat32 tensor = TFloat32.scalarOf(number.floatValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TFloat16.class)) { - try (TFloat16 tensor = TFloat16.scalarOf(number.floatValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TInt64.class)) { - try (TInt64 tensor = TInt64.scalarOf(number.longValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TInt32.class)) { - try (TInt32 tensor = TInt32.scalarOf(number.intValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TUint8.class)) { - try (TUint8 tensor = TUint8.scalarOf(number.byteValue())) { - return (Constant) create(scope, tensor); + try (TensorScope tensorScope = new TensorScope()) { + if (type.equals(TBfloat16.class)) { + try (TBfloat16 tensor = TBfloat16.scalarOf(tensorScope, number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat64.class)) { + try (TFloat64 tensor = TFloat64.scalarOf(tensorScope, number.doubleValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat32.class)) { + try (TFloat32 tensor = TFloat32.scalarOf(tensorScope, number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat16.class)) { + try (TFloat16 tensor = TFloat16.scalarOf(tensorScope, number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TInt64.class)) { + try (TInt64 tensor = TInt64.scalarOf(tensorScope, number.longValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TInt32.class)) { + try (TInt32 tensor = TInt32.scalarOf(tensorScope, number.intValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TUint8.class)) { + try (TUint8 tensor = TUint8.scalarOf(tensorScope, number.byteValue())) { + return (Constant) create(scope, tensor); + } + } else { + throw new IllegalArgumentException("Tensor type " + type + " is an abstract or unknown numeric type."); } - } else { - throw new IllegalArgumentException("Tensor type " + type + " is an abstract or unknown numeric type."); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java deleted file mode 100644 index 330a40bae6b..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.tensorflow; - -import java.util.ArrayList; -import java.util.Collection; - -public final class AutoCloseableList extends ArrayList - implements AutoCloseable { - - public AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java index b2fbc1e794a..9de1cc5ba8c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java @@ -35,7 +35,8 @@ public class WrongEnvTest { @Test public void testTwoEagers() { try (EagerSession e1 = EagerSession.create(); - EagerSession e2 = EagerSession.create()) { + EagerSession e2 = EagerSession.create(); + TensorScope scope = new TensorScope()) { Ops tf1 = Ops.create(e1); Ops tf2 = Ops.create(e2); @@ -44,7 +45,7 @@ public void testTwoEagers() { Operand c = tf2.math.add(a, b); - try (TInt32 tensor = c.asTensor()) { + try (TInt32 tensor = c.asTensor(scope)) { assertEquals(11, tensor.getInt()); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 18e6e900ac1..8bf54b27ec0 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -39,12 +39,8 @@ import org.tensorflow.ndarray.buffer.FloatDataBuffer; import org.tensorflow.ndarray.buffer.IntDataBuffer; import org.tensorflow.ndarray.buffer.LongDataBuffer; -import org.tensorflow.ndarray.DoubleNdArray; -import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.IntNdArray; -import org.tensorflow.ndarray.LongNdArray; -import org.tensorflow.ndarray.NdArray; -import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; import org.tensorflow.types.TBfloat16; import org.tensorflow.types.TFloat16; import org.tensorflow.types.TFloat32; @@ -175,7 +171,8 @@ private static void testCreateFromNumber(Ops tf, Class type) Operand constant = tf.constant(type, 10); assertEquals(type, constant.type()); - try (TFloat64 t = tf.dtypes.cast(constant, TFloat64.class).asTensor()) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 t = tf.dtypes.cast(constant, TFloat64.class).asTensor(tensorScope)) { assertEquals(10.0, t.getDouble()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index 6e86573b7cf..a78422664e3 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -20,9 +20,10 @@ import org.junit.Test; import org.tensorflow.Graph; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.index.Indices; import org.tensorflow.ndarray.index.Index; +import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Scope; import org.tensorflow.types.TFloat32; @@ -62,9 +63,11 @@ public void testStridedSliceIndex(){ long[] shape = {10, 10, 10, 10, 10, 10, 10, 10}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class); StridedSlice output = StridedSliceHelper.stridedSlice(scope, op, slice); - try (TFloat32 result = (TFloat32) sess.runner().fetch(output.asOutput()).run().get(0)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 result = (TFloat32) sess.runner().fetch(output.asOutput()).run(tensorScope).get(0)) { // expected shape from Python tensorflow - assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.shape(), "Slice index didn't match expected (Python)"); + assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.shape(), + "Slice index didn't match expected (Python)"); } } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index 260ab963e01..4c915054077 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -14,9 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.utils; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TInt32; From a50d06c678548f4950203297de84a7790e47942a Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 4 Mar 2021 14:25:31 -0800 Subject: [PATCH 34/35] More framework fixes Signed-off-by: Ryan Nett --- .../metrics/impl/AssertBroadcastableTest.java | 99 +++++++------- .../metrics/impl/BroadcastWeightsTest.java | 127 +++++++++--------- 2 files changed, 114 insertions(+), 112 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 63d666f8640..d83c50c725c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -14,9 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -26,37 +30,33 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertThrows; - public class AssertBroadcastableTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; int[][][] valueArrayI = - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new int[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; long[][][] valueArrayL = - new long[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new long[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; float[][][] valueArrayF = - new float[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new float[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; double[][][] valueArrayD = - new double[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new double[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; private void testValid( @@ -68,10 +68,11 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - List tensors = - testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); - try (Tensor weightsTensor = tensors.get(0); - Tensor valuesTensor = tensors.get(1)) { + try (TensorScope scope = new TensorScope()) { + List tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run(scope); + Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1); Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); testSession @@ -80,7 +81,7 @@ private void testValid( .feed(weightsPlaceholder, weightsTensor) .feed(valuesPlaceholder, valuesTensor) .addTarget(dynamicOp) - .run(); + .run(scope); } } @@ -103,7 +104,7 @@ public void test1x1x1() { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayD); - Operand weights = tf.constant(new double[][][] {{{5}}}); + Operand weights = tf.constant(new double[][][]{{{5}}}); testValid(testSession, tf, weights, values, TFloat64.class); } } @@ -114,7 +115,7 @@ public void test1x1xN() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayL); - Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); + Operand weights = tf.constant(new long[][][]{{{5, 7, 11, 3}}}); testValid(testSession, tf, weights, values, TInt64.class); } } @@ -125,7 +126,7 @@ public void test1xNx1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); + Operand weights = tf.constant(new int[][][]{{{5}, {11}}}); testValid(testSession, tf, weights, values, TInt32.class); } } @@ -137,7 +138,7 @@ public void test1xNxN() { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); + Operand weights = tf.constant(new int[][][]{{{5, 7, 11, 3}, {2, 13, 7, 5}}}); testValid(testSession, tf, weights, values, TInt32.class); } } @@ -149,7 +150,7 @@ public void testNx1x1() { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); + Operand weights = tf.constant(new int[][][]{{{5}}, {{7}}, {{11}}}); testValid(testSession, tf, weights, values, TInt32.class); } } @@ -162,7 +163,7 @@ public void testNx1xN() { Operand values = tf.constant(valueArrayI); Operand weights = - tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); + tf.constant(new int[][][]{{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); testValid(testSession, tf, weights, values, TInt32.class); } } @@ -176,10 +177,10 @@ public void testNxNxN() { Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( - new int[][][] { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} + new int[][][]{ + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} }); testValid(testSession, tf, weights, values, TInt32.class); } @@ -199,7 +200,7 @@ public void testInvalid1x1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5}}); + Operand weights = tf.constant(new int[][]{{5}}); testValid(testSession, tf, weights, values, TInt32.class); } }); @@ -213,7 +214,7 @@ public void testInvalidPrefixMatch() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); + Operand weights = tf.constant(new int[][]{{5, 7}, {11, 3}, {2, 12}}); testValid(testSession, tf, weights, values, TInt32.class); } }); @@ -227,7 +228,7 @@ public void testInvalidSuffixMatch() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); + Operand weights = tf.constant(new int[][]{{5, 7, 11, 3}, {2, 12, 7, 5}}); testValid(testSession, tf, weights, values, TInt32.class); } }); @@ -241,7 +242,7 @@ public void testInvalidOnesExtraDim() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][][] {{{{5}}}}); + Operand weights = tf.constant(new int[][][][]{{{{5}}}}); testValid(testSession, tf, weights, values, TInt32.class); } }); @@ -258,10 +259,10 @@ public void testInvalidPrefixMatchExtraDim() { Operand weights = tf.constant( - new int[][][][] { - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, - {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} + new int[][][][]{ + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, + {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} }); testValid(testSession, tf, weights, values, TInt32.class); } @@ -278,12 +279,12 @@ public void testInvalidSuffixMatchExtraDim() { Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( - new int[][][][] { - { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} - } + new int[][][][]{ + { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + } }); testValid(testSession, tf, weights, values, TInt32.class); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java index 3322a81fe5b..3fde6d1b204 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java @@ -14,9 +14,15 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; @@ -25,38 +31,33 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class BroadcastWeightsTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; int[][][] valueArrayI = - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new int[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; long[][][] valueArrayL = - new long[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new long[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; float[][][] valueArrayF = - new float[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new float[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; double[][][] valueArrayD = - new double[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new double[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; private void testValid( @@ -78,10 +79,10 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - List tensors = - testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); - try (Tensor weightsTensor = tensors.get(0); - Tensor valuesTensor = tensors.get(1)) { + try (TensorScope scope = new TensorScope()) { + List tensors = testSession.getGraphSession().runner().fetch(weights).fetch(values).run(scope); + Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1); Operand dynamicOp = MetricsHelper.broadcastWeights(tf, weightsPlaceholder, valuesPlaceholder); @@ -93,7 +94,7 @@ private void testValid( .feed(weightsPlaceholder, weightsTensor) .feed(valuesPlaceholder, valuesTensor) .fetch(dynamicOp) - .run(); + .run(scope); if (expected != null) { if (type.equals(TInt32.class)) { @@ -140,8 +141,8 @@ public void testValidScalar() { Operand values = tf.constant(valueArrayF); Operand weights = tf.constant(5f); Float[] expected = { - 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, - 5f + 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, + 5f }; testValid(testSession, tf, weights, values, expected, TFloat32.class); } @@ -153,10 +154,10 @@ public void test1x1x1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayD); - Operand weights = tf.constant(new double[][][] {{{5}}}); + Operand weights = tf.constant(new double[][][]{{{5}}}); Double[] expected = { - 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., - 5. + 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., + 5. }; testValid(testSession, tf, weights, values, expected, TFloat64.class); @@ -169,10 +170,10 @@ public void test1x1xN() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayL); - Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); + Operand weights = tf.constant(new long[][][]{{{5, 7, 11, 3}}}); Long[] expected = { - 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, - 11L, 3L, + 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, + 11L, 3L, }; testValid(testSession, tf, weights, values, expected, TInt64.class); } @@ -184,9 +185,9 @@ public void test1xNx1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); + Operand weights = tf.constant(new int[][][]{{{5}, {11}}}); Integer[] expected = { - 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11 + 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11 }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -198,9 +199,9 @@ public void test1xNxN() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); + Operand weights = tf.constant(new int[][][]{{{5, 7, 11, 3}, {2, 13, 7, 5}}}); Integer[] expected = { - 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, + 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -212,9 +213,9 @@ public void testNx1x1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); + Operand weights = tf.constant(new int[][][]{{{5}}, {{7}}, {{11}}}); Integer[] expected = { - 5, 5, 5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 11, 11, 11, 11, 11, 11, 11, 11 + 5, 5, 5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 11, 11, 11, 11, 11, 11, 11, 11 }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -227,9 +228,9 @@ public void testNx1xN() { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); Operand weights = - tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); + tf.constant(new int[][][]{{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); Integer[] expected = { - 5, 7, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3 + 5, 7, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3 }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -244,13 +245,13 @@ public void testNxNxN() { Operand weights = tf.constant( - new int[][][] { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} + new int[][][]{ + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} }); Integer[] expected = { - 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5 + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5 }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -270,7 +271,7 @@ public void testInvalid1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[] {5}); + Operand weights = tf.constant(new int[]{5}); testValid(testSession, tf, weights, values, null, TInt32.class); } @@ -286,7 +287,7 @@ public void testInvalid1x1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5}}); + Operand weights = tf.constant(new int[][]{{5}}); testValid(testSession, tf, weights, values, null, TInt32.class); } @@ -301,7 +302,7 @@ public void testInvalidPrefixMatch() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); + Operand weights = tf.constant(new int[][]{{5, 7}, {11, 3}, {2, 12}}); testValid(testSession, tf, weights, values, null, TInt32.class); } }); @@ -315,7 +316,7 @@ public void testInvalidSuffixMatch() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); + Operand weights = tf.constant(new int[][]{{5, 7, 11, 3}, {2, 12, 7, 5}}); testValid(testSession, tf, weights, values, null, TInt32.class); } }); @@ -329,7 +330,7 @@ public void testInvalidOnesExtraDim() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][][] {{{{5}}}}); + Operand weights = tf.constant(new int[][][][]{{{{5}}}}); testValid(testSession, tf, weights, values, null, TInt32.class); } }); @@ -346,10 +347,10 @@ public void testInvalidPrefixMatchExtraDim() { Operand weights = tf.constant( - new int[][][][] { - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, - {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} + new int[][][][]{ + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, + {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} }); testValid(testSession, tf, weights, values, null, TInt32.class); } @@ -366,12 +367,12 @@ public void testInvalidSuffixMatchExtraDim() { Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( - new int[][][][] { - { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} - } + new int[][][][]{ + { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + } }); testValid(testSession, tf, weights, values, null, TInt32.class); } From 46645eeeab53165d0a12814de06a1b29ffe3e79c Mon Sep 17 00:00:00 2001 From: Ryan Nett Date: Thu, 4 Mar 2021 14:41:10 -0800 Subject: [PATCH 35/35] Rebase fixes Signed-off-by: Ryan Nett --- .../annotations/org/tensorflow/op/Ops.java | 6 +- .../org/tensorflow/EagerOperationTest.java | 3 +- .../tensorflow/op/core/BooleanMaskTest.java | 41 +++---- .../op/core/BooleanMaskUpdateTest.java | 105 +++++++++--------- .../framework/constraints/MinMaxNormTest.java | 25 +++-- 5 files changed, 94 insertions(+), 86 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 47be3383364..b0dde6c1b6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -354,10 +354,10 @@ public final class Ops { public final SparseOps sparse; - public final TpuOps tpu; - public final BitwiseOps bitwise; + public final TpuOps tpu; + public final MathOps math; public final AudioOps audio; @@ -385,8 +385,8 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - tpu = new TpuOps(this); bitwise = new BitwiseOps(this); + tpu = new TpuOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java index 8ebb4789e8c..941b2073474 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java @@ -35,7 +35,8 @@ public class EagerOperationTest { public void failToCreateIfSessionIsClosed() { EagerSession session = EagerSession.create(); session.close(); - try (TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { + try (TensorScope scope = new TensorScope()) { + TInt32 t = TInt32.tensorOf(scope, Shape.of(2, 3)); EagerOperation op = opBuilder(session, "Const", "OutputAttrs") .setAttr("dtype", t.dataType()) diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java index a4d9293ccf8..22249d4d9ba 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java @@ -22,6 +22,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TBool; @@ -29,10 +30,12 @@ import org.tensorflow.types.TInt32; public class BooleanMaskTest { + @Test - public void testBooleanMask(){ + public void testBooleanMask() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Operand input = Constant.arrayOf(scope, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); @@ -43,25 +46,23 @@ public void testBooleanMask(){ Operand output1 = BooleanMask.create(scope, input, mask); Operand output2 = BooleanMask.create(scope, input2, mask, BooleanMask.axis(1)); - try (TFloat32 result = (TFloat32) sess.runner().fetch(output1).run().get(0)) { - // expected shape from Python tensorflow - assertEquals(Shape.of(5), result.shape()); - assertEquals(0, result.getFloat(0)); - assertEquals(1, result.getFloat(1)); - assertEquals(4, result.getFloat(2)); - assertEquals(5, result.getFloat(3)); - assertEquals(6, result.getFloat(4)); - } + TFloat32 result = (TFloat32) sess.runner().fetch(output1).run(tensorScope).get(0); + // expected shape from Python tensorflow + assertEquals(Shape.of(5), result.shape()); + assertEquals(0, result.getFloat(0)); + assertEquals(1, result.getFloat(1)); + assertEquals(4, result.getFloat(2)); + assertEquals(5, result.getFloat(3)); + assertEquals(6, result.getFloat(4)); - try (TFloat32 result = (TFloat32) sess.runner().fetch(output2).run().get(0)) { - // expected shape from Python tensorflow - assertEquals(Shape.of(5), result.shape()); - assertEquals(0, result.getFloat(0)); - assertEquals(1, result.getFloat(1)); - assertEquals(4, result.getFloat(2)); - assertEquals(5, result.getFloat(3)); - assertEquals(6, result.getFloat(4)); - } + TFloat32 result2 = (TFloat32) sess.runner().fetch(output2).run(tensorScope).get(0); + // expected shape from Python tensorflow + assertEquals(Shape.of(5), result2.shape()); + assertEquals(0, result2.getFloat(0)); + assertEquals(1, result2.getFloat(1)); + assertEquals(4, result2.getFloat(2)); + assertEquals(5, result2.getFloat(3)); + assertEquals(6, result2.getFloat(4)); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java index ab852bbffb2..ed3fd3614de 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -24,6 +24,7 @@ import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TBool; @@ -34,7 +35,8 @@ public class BooleanMaskUpdateTest { @Test public void testBooleanMaskUpdateSlice() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); @@ -47,31 +49,31 @@ public void testBooleanMaskUpdateSlice() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); - try (TInt32 result = (TInt32) results.get(0); - TInt32 bcastResult = (TInt32) results.get(1)) { + List results = sess.runner().fetch(output).fetch(bcastOutput).run(tensorScope); + TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1); - assertEquals(Shape.of(3, 3), result.shape()); + assertEquals(Shape.of(3, 3), result.shape()); - assertEquals(-1, result.getInt(0, 0)); - assertEquals(-1, result.getInt(0, 1)); - assertEquals(-1, result.getInt(0, 2)); - assertEquals(1, result.getInt(1, 0)); - assertEquals(1, result.getInt(1, 1)); - assertEquals(1, result.getInt(1, 2)); - assertEquals(2, result.getInt(2, 0)); - assertEquals(2, result.getInt(2, 1)); - assertEquals(2, result.getInt(2, 2)); + assertEquals(-1, result.getInt(0, 0)); + assertEquals(-1, result.getInt(0, 1)); + assertEquals(-1, result.getInt(0, 2)); + assertEquals(1, result.getInt(1, 0)); + assertEquals(1, result.getInt(1, 1)); + assertEquals(1, result.getInt(1, 2)); + assertEquals(2, result.getInt(2, 0)); + assertEquals(2, result.getInt(2, 1)); + assertEquals(2, result.getInt(2, 2)); - assertEquals(result, bcastResult); - } + assertEquals(result, bcastResult); } } @Test public void testBooleanMaskUpdateSliceWithBroadcast() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); @@ -84,31 +86,31 @@ public void testBooleanMaskUpdateSliceWithBroadcast() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); - try (TInt32 result = (TInt32) results.get(0); - TInt32 bcastResult = (TInt32) results.get(1)) { + List results = sess.runner().fetch(output).fetch(bcastOutput).run(tensorScope); + TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1); - assertEquals(Shape.of(3, 3), result.shape()); + assertEquals(Shape.of(3, 3), result.shape()); - assertEquals(-1, result.getInt(0, 0)); - assertEquals(-1, result.getInt(0, 1)); - assertEquals(-1, result.getInt(0, 2)); - assertEquals(1, result.getInt(1, 0)); - assertEquals(1, result.getInt(1, 1)); - assertEquals(1, result.getInt(1, 2)); - assertEquals(2, result.getInt(2, 0)); - assertEquals(2, result.getInt(2, 1)); - assertEquals(2, result.getInt(2, 2)); + assertEquals(-1, result.getInt(0, 0)); + assertEquals(-1, result.getInt(0, 1)); + assertEquals(-1, result.getInt(0, 2)); + assertEquals(1, result.getInt(1, 0)); + assertEquals(1, result.getInt(1, 1)); + assertEquals(1, result.getInt(1, 2)); + assertEquals(2, result.getInt(2, 0)); + assertEquals(2, result.getInt(2, 1)); + assertEquals(2, result.getInt(2, 2)); - assertEquals(result, bcastResult); - } + assertEquals(result, bcastResult); } } @Test public void testBooleanMaskUpdateAxis() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Operand input = Constant.tensorOf(scope, new int[][][]{{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}}); @@ -122,25 +124,24 @@ public void testBooleanMaskUpdateAxis() { Operand bcastOutput = BooleanMaskUpdate .create(scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); - try (TInt32 result = (TInt32) results.get(0); - TInt32 bcastResult = (TInt32) results.get(1)) { - - assertEquals(Shape.of(1, 1, 10), result.shape()); - - assertEquals(-1, result.getInt(0, 0, 0)); - assertEquals(-1, result.getInt(0, 0, 1)); - assertEquals(2, result.getInt(0, 0, 2)); - assertEquals(3, result.getInt(0, 0, 3)); - assertEquals(-1, result.getInt(0, 0, 4)); - assertEquals(-1, result.getInt(0, 0, 5)); - assertEquals(-1, result.getInt(0, 0, 6)); - assertEquals(7, result.getInt(0, 0, 7)); - assertEquals(8, result.getInt(0, 0, 8)); - assertEquals(9, result.getInt(0, 0, 9)); - - assertEquals(result, bcastResult); - } + List results = sess.runner().fetch(output).fetch(bcastOutput).run(tensorScope); + TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1); + + assertEquals(Shape.of(1, 1, 10), result.shape()); + + assertEquals(-1, result.getInt(0, 0, 0)); + assertEquals(-1, result.getInt(0, 0, 1)); + assertEquals(2, result.getInt(0, 0, 2)); + assertEquals(3, result.getInt(0, 0, 3)); + assertEquals(-1, result.getInt(0, 0, 4)); + assertEquals(-1, result.getInt(0, 0, 5)); + assertEquals(-1, result.getInt(0, 0, 6)); + assertEquals(7, result.getInt(0, 0, 7)); + assertEquals(8, result.getInt(0, 0, 8)); + assertEquals(9, result.getInt(0, 0, 9)); + + assertEquals(result, bcastResult); } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java index 8c2c3a54ff9..370512609c2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -1,7 +1,10 @@ package org.tensorflow.framework.constraints; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -10,9 +13,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; -import java.util.Random; -import java.util.concurrent.atomic.AtomicInteger; - class MinMaxNormTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -27,12 +27,15 @@ private float[] getSampleArray() { return result; } - /** Test of call method, of class MinMaxNorm. */ + /** + * Test of call method, of class MinMaxNorm. + */ @Test public void testCall() { float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { + for (TestSession.Mode tfMode : tfModes) { + try (TestSession session = TestSession.createTestSession(tfMode); + TensorScope scope = new TensorScope()) { Ops tf = session.getTF(); final float[] array = getSampleArray(); Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); @@ -41,15 +44,17 @@ public void testCall() { i.getAndIncrement()) { MinMaxNorm instance = new MinMaxNorm(tf, testValues[i.get()], testValues[i.get()] * 2); Operand result = instance.call(weights); - if (tfMode == TestSession.Mode.EAGER) - evaluate(session, result.asTensor(), testValues[i.get()]); - else + if (tfMode == TestSession.Mode.EAGER) { + evaluate(session, result.asTensor(scope), testValues[i.get()]); + } else { try (TFloat32 tensor = - (TFloat32) session.getGraphSession().runner().fetch(result).run().get(0)) { + (TFloat32) session.getGraphSession().runner().fetch(result).run(scope).get(0)) { evaluate(session, tensor, testValues[i.get()]); } + } } } + } } private void evaluate(TestSession session, TFloat32 tensor, float m) {