diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index ea3ef31313e..b0dde6c1b6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -354,10 +354,10 @@ public final class Ops { public final SparseOps sparse; - public final TpuOps tpu; - public final BitwiseOps bitwise; + public final TpuOps tpu; + public final MathOps math; public final AudioOps audio; @@ -385,8 +385,8 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - tpu = new TpuOps(this); bitwise = new BitwiseOps(this); + tpu = new TpuOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); @@ -527,8 +527,8 @@ public Constant array(byte... data) { * * @param scope is a scope used to add the underlying operation. * @param charset charset for encoding/decoding strings bytes. - * @param data An array containing the values to put into the new constant. String elements are - * sequences of bytes from the last array dimension. + * @param data An array containing the values to put into the new constant. String elements are sequences of bytes + * from the last array dimension. * @return the {@code String} constant */ public Constant array(Charset charset, String... data) { @@ -1189,62 +1189,62 @@ public Constant constant(LongNdArray data) { } /** - * Creates a rank-1 constant of {@code int} elements. + * Creates a rank-3 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant */ - public Constant constant(int[] data) { - return Constant.vectorOf(scope, data); + public Constant constant(double[][][] data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-3 constant of {@code int} elements. + * Creates a rank-2 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant */ - public Constant constant(int[][][] data) { + public Constant constant(double[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant containing a single {@code double} element. + * Creates a rank-5 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data The value to put into the new constant. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ - public Constant constant(double data) { - return Constant.scalarOf(scope, data); + public Constant constant(double[][][][][] data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code long} elements. + * Creates a rank-2 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ - public Constant constant(long[][][][][] data) { + public Constant constant(long[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code boolean} elements. + * Creates a constant containing a single {@code double} element. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a boolean constant + * @param data The value to put into the new constant. + * @return a double constant */ - public Constant constant(boolean[][][][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(double data) { + return Constant.scalarOf(scope, data); } /** @@ -1270,26 +1270,26 @@ public Constant constant(DoubleNdArray data) { } /** - * Creates a rank-4 constant of {@code int} elements. + * Creates a rank-3 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a boolean constant */ - public Constant constant(int[][][][] data) { + public Constant constant(boolean[][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code float} elements. + * Creates a rank-5 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a byte constant */ - public Constant constant(float[][][][][][] data) { + public Constant constant(byte[][][][][] data) { return Constant.tensorOf(scope, data); } @@ -1305,168 +1305,180 @@ public Constant constant(byte data) { } /** - * Creates a rank-3 constant of {@code boolean} elements. + * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, using the default + * UTF-8 encoding. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a boolean constant + * @param data an n-dimensional array of {@code String} elements. + * @return a string constant */ - public Constant constant(boolean[][][] data) { + public Constant constant(NdArray data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-4 constant of {@code float} elements. + * Creates a rank-1 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(float[][][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(int[] data) { + return Constant.vectorOf(scope, data); } /** - * Creates a rank-2 constant of {@code long} elements. + * Creates a rank-4 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a long constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a boolean constant */ - public Constant constant(long[][] data) { + public Constant constant(boolean[][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code byte} elements. + * Creates a rank-4 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a byte constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant */ - public Constant constant(byte[][][][][] data) { + public Constant constant(double[][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant of {@code boolean} elements that is a copy of a given n-dimensional array. + * Creates a rank-2 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data an n-dimensional array of {@code boolean} elements. - * @return a boolean constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a byte constant */ - public Constant constant(BooleanNdArray data) { + public Constant constant(byte[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-2 constant of {@code float} elements. + * Creates a rank-6 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a byte constant */ - public Constant constant(float[][] data) { + public Constant constant(byte[][][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant of {@code byte} elements that is a copy of a given n-dimensional array. + * Creates a constant of {@code boolean} elements that is a copy of a given n-dimensional array. * * @param scope is a scope used to add the underlying operation. - * @param data an n-dimensional array of {@code byte} elements. - * @return a byte constant + * @param data an n-dimensional array of {@code boolean} elements. + * @return a boolean constant */ - public Constant constant(ByteNdArray data) { + public Constant constant(BooleanNdArray data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-2 constant of {@code byte} elements. + * Creates a rank-3 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ - public Constant constant(byte[][] data) { + public Constant constant(byte[][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code double} elements. + * Creates a rank-5 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(double[][][][][] data) { + public Constant constant(int[][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-3 constant of {@code float} elements. + * Creates a rank-1 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ - public Constant constant(float[][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(float[] data) { + return Constant.vectorOf(scope, data); } /** - * Creates a rank-1 constant of {@code byte} elements. + * Creates a constant of {@code byte} elements that is a copy of a given n-dimensional array. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data an n-dimensional array of {@code byte} elements. * @return a byte constant */ - public Constant constant(byte[] data) { - return Constant.vectorOf(scope, data); + public Constant constant(ByteNdArray data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-1 constant of {@code float} elements. + * Creates a rank-4 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(float[] data) { + public Constant constant(int[][][][] data) { + return Constant.tensorOf(scope, data); + } + + /** + * Creates a rank-1 constant of {@code double} elements. + * + * @param scope is a scope used to add the underlying operation. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant + */ + public Constant constant(double[] data) { return Constant.vectorOf(scope, data); } /** - * Creates a rank-2 constant of {@code boolean} elements. + * Creates a rank-5 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a boolean constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a float constant */ - public Constant constant(boolean[][] data) { + public Constant constant(float[][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, - * using the default UTF-8 encoding. + * Creates a rank-3 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data an n-dimensional array of {@code String} elements. - * @return a string constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a float constant */ - public Constant constant(NdArray data) { + public Constant constant(float[][][] data) { return Constant.tensorOf(scope, data); } @@ -1482,46 +1494,46 @@ public Constant constant(String data) { } /** - * Creates a rank-4 constant of {@code double} elements. + * Creates a rank-2 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a boolean constant */ - public Constant constant(double[][][][] data) { + public Constant constant(boolean[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-2 constant of {@code double} elements. + * Creates a constant containing a single {@code int} element. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data The value to put into the new constant. + * @return an integer constant */ - public Constant constant(double[][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(int data) { + return Constant.scalarOf(scope, data); } /** - * Creates a constant containing a single {@code int} element. + * Creates a rank-1 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data The value to put into the new constant. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a long constant */ - public Constant constant(int data) { - return Constant.scalarOf(scope, data); + public Constant constant(long[] data) { + return Constant.vectorOf(scope, data); } /** * Creates a rank-4 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ public Constant constant(byte[][][][] data) { @@ -1529,14 +1541,14 @@ public Constant constant(byte[][][][] data) { } /** - * Creates a rank-6 constant of {@code int} elements. + * Creates a rank-6 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a double constant */ - public Constant constant(int[][][][][][] data) { + public Constant constant(double[][][][][][] data) { return Constant.tensorOf(scope, data); } @@ -1563,169 +1575,157 @@ public Constant constant(float data) { } /** - * Creates a rank-5 constant of {@code float} elements. + * Creates a rank-1 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a float constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a byte constant */ - public Constant constant(float[][][][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(byte[] data) { + return Constant.vectorOf(scope, data); } /** - * Creates a rank-3 constant of {@code double} elements. + * Creates a rank-2 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(double[][][] data) { + public Constant constant(int[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code long} elements. + * Creates a rank-4 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a long constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a float constant */ - public Constant constant(long[][][][][][] data) { + public Constant constant(float[][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-4 constant of {@code long} elements. + * Creates a rank-6 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a long constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a boolean constant */ - public Constant constant(long[][][][] data) { + public Constant constant(boolean[][][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-1 constant of {@code long} elements. + * Creates a constant of {@code float} elements that is a copy of a given n-dimensional array. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a long constant + * @param data an n-dimensional array of {@code float} elements. + * @return a float constant */ - public Constant constant(long[] data) { - return Constant.vectorOf(scope, data); + public Constant constant(FloatNdArray data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-1 constant of {@code boolean} elements. + * Creates a rank-5 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ - public Constant constant(boolean[] data) { - return Constant.vectorOf(scope, data); - } - - /** - * Creates a rank-3 constant of {@code byte} elements. - * - * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a byte constant - */ - public Constant constant(byte[][][] data) { + public Constant constant(boolean[][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code byte} elements. + * Creates a rank-6 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a byte constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a long constant */ - public Constant constant(byte[][][][][][] data) { + public Constant constant(long[][][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-2 constant of {@code int} elements. + * Creates a rank-4 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a long constant */ - public Constant constant(int[][] data) { + public Constant constant(long[][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a constant of {@code float} elements that is a copy of a given n-dimensional array. + * Creates a rank-2 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data an n-dimensional array of {@code float} elements. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ - public Constant constant(FloatNdArray data) { + public Constant constant(float[][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-5 constant of {@code int} elements. + * Creates a rank-5 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return an integer constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a long constant */ - public Constant constant(int[][][][][] data) { + public Constant constant(long[][][][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-1 constant of {@code double} elements. + * Creates a rank-6 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return a float constant */ - public Constant constant(double[] data) { - return Constant.vectorOf(scope, data); + public Constant constant(float[][][][][][] data) { + return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code boolean} elements. + * Creates a rank-3 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a boolean constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(boolean[][][][][][] data) { + public Constant constant(int[][][] data) { return Constant.tensorOf(scope, data); } /** - * Creates a rank-6 constant of {@code double} elements. + * Creates a rank-6 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. - * @return a double constant + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. + * @return an integer constant */ - public Constant constant(double[][][][][][] data) { + public Constant constant(int[][][][][][] data) { return Constant.tensorOf(scope, data); } @@ -1741,23 +1741,23 @@ public Constant constant(boolean data) { } /** - * Creates a rank-4 constant of {@code boolean} elements. + * Creates a rank-1 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ - public Constant constant(boolean[][][][] data) { - return Constant.tensorOf(scope, data); + public Constant constant(boolean[] data) { + return Constant.vectorOf(scope, data); } /** * Creates a rank-3 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ public Constant constant(long[][][] data) { @@ -1765,8 +1765,7 @@ public Constant constant(long[][][] data) { } /** - * Creates a rank-1 constant of {@code long} elements representing the size of each dimensions of - * the given shape. + * Creates a rank-1 constant of {@code long} elements representing the size of each dimensions of the given shape. * * @param scope is a scope used to add the underlying operation. * @param shape a shape @@ -1781,8 +1780,8 @@ public Constant constant(Shape shape) { * * @param scope is a scope used to add the underlying operation. * @param charset charset for encoding/decoding strings bytes. - * @param data An array containing the values to put into the new constant. String elements are - * sequences of bytes from the last array dimension. + * @param data An array containing the values to put into the new constant. String elements are sequences of bytes + * from the last array dimension. * @return the {@code String} constant */ public Constant constant(Charset charset, String[] data) { @@ -1802,8 +1801,8 @@ public Constant constant(Charset charset, String data) { } /** - * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, - * using the given encoding. + * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, using the given + * encoding. * * @param scope is a scope used to add the underlying operation. * @param charset charset used to encode/decode string bytes. @@ -1854,29 +1853,28 @@ public Constant constant(Shape shape, ByteDataBuffer data) { } /** - * Create a {@link TInt64} constant with data from the given buffer. + * Create a {@link TString} constant with data from the given buffer, using the default UTF-8 encoding. * * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. - * @return a long constant + * @return a string constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public Constant constant(Shape shape, LongDataBuffer data) { + public Constant constant(Shape shape, DataBuffer data) { return Constant.tensorOf(scope, shape, data); } /** - * Create a {@link TString} constant with data from the given buffer, using the default UTF-8 - * encoding. + * Create a {@link TInt64} constant with data from the given buffer. * * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. * @param data a buffer containing the tensor data. - * @return a string constant + * @return a long constant * @throws IllegalArgumentException If the tensor shape is not compatible with the buffer */ - public Constant constant(Shape shape, DataBuffer data) { + public Constant constant(Shape shape, LongDataBuffer data) { return Constant.tensorOf(scope, shape, data); } @@ -1943,8 +1941,7 @@ public Constant constant(Charset charset, Shape shape, DataBuffer Constant constant(Class type, Shape shape, ByteDataBuffer data) { return Constant.tensorOf(scope, type, shape, data); diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java index 0ffd6c2205e..eba6d755ce0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/AbstractOperation.java @@ -84,8 +84,9 @@ public String toString() { * *

This is only supported in an eager execution environment. * + * @param scope the {@link TensorScope} to create the tensor in * @param outputIdx index of the output of this operation * @return output tensor */ - abstract Tensor tensor(int outputIdx); + abstract Tensor tensor(TensorScope scope, int outputIdx); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java index 71dc0f7cefc..d3ba68d62a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/ConcreteFunction.java @@ -16,9 +16,9 @@ package org.tensorflow; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.ListIterator; -import java.util.HashMap; import java.util.Map; import java.util.function.Function; import org.tensorflow.op.Ops; @@ -43,12 +43,12 @@ public class ConcreteFunction implements AutoCloseable { * Creates a function by building a new graph. * *

The {@code functionBuilder} must initialize the function graph from the provided - * {@link Ops} instance and return a valid signature that will be used to feed the input tensors - * and fetch the output tensors on execution. + * {@link Ops} instance and return a valid signature that will be used to feed the input tensors and fetch the output + * tensors on execution. * *

The function will be the owner of the new graph and its resulting session. Therefore, - * the function must be enclosed properly with a try-with-resources block to guarantee that - * all native resources will be freed once the function is discarded. For example: + * the function must be enclosed properly with a try-with-resources block to guarantee that all native resources will + * be freed once the function is discarded. For example: * *

{@code
    * public class MyModel {
@@ -87,8 +87,8 @@ public static ConcreteFunction create(Function functionBuilder)
    * Create a function from a signature and an existing graph.
    *
    * 

The function will keep the ownership of the session used to run the graph but not - * the graph itself, meaning that the lifetime of the latter can extend beyond the scope - * of the function. For example: + * the graph itself, meaning that the lifetime of the latter can extend beyond the scope of the function. For + * example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -116,8 +116,8 @@ public static ConcreteFunction create(Signature signature, Graph graph) {
    * Create a function from a signature and a valid graph session.
    *
    * 

The function will not own the session nor its graph, meaning that their lifetime - * can extend beyond the scope of the function. Therefore the function does not need to be - * closed after its usage. For example: + * can extend beyond the scope of the function. Therefore the function does not need to be closed after its usage. For + * example: * *

{@code
    * try (Graph g = new Graph()) {
@@ -156,14 +156,11 @@ public Signature signature() {
   /**
    * Invokes a function.
    *
-   * 

Caller is responsible for closing all Tensors. - * - * @param arguments list of tensors to pass in input to the function, - * mapped by their signature name - * @return output tensors resulting from the execution of the function, - * mapped by their signature name + * @param scope the {@link TensorScope} to create the outputs in + * @param arguments list of tensors to pass in input to the function, mapped by their signature name + * @return output tensors resulting from the execution of the function, mapped by their signature name */ - public Map call(Map arguments) + public Map call(TensorScope scope, Map arguments) throws IllegalArgumentException { final SignatureDef signatureDef = signature.asSignatureDef(); @@ -180,13 +177,13 @@ public Map call(Map arguments) Map outputToNode = signatureDef.getOutputsMap(); outputToNode.values().forEach(t -> runner.fetch(t.getName())); - List resultTensors = runner.run(); + List resultTensors = runner.run(scope); try { ListIterator resultTensorIter = resultTensors.listIterator(); Map returnMap = new HashMap(); // Use the output names as present in the signature definition - for (String nodeName: outputToNode.keySet()) { + for (String nodeName : outputToNode.keySet()) { returnMap.put(nodeName, resultTensorIter.next()); } return returnMap; @@ -203,29 +200,27 @@ public Map call(Map arguments) /** * Invokes a function with a single input and output. * - *

Caller is responsible for closing all Tensors. - * + * @param scope the {@link TensorScope} to create the output in * @param tensor input tensor * @return output tensor - * @throws IllegalArgumentException if there are multiple input or output parameters defined - * in the function + * @throws IllegalArgumentException if there are multiple input or output parameters defined in the function */ - public Tensor call(Tensor tensor) throws IllegalArgumentException { + public Tensor call(TensorScope scope, Tensor tensor) throws IllegalArgumentException { final SignatureDef signatureDef = signature.asSignatureDef(); if (signatureDef.getInputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); + String.format("Function [%s] requires multiple inputs", signatureDef.getMethodName())); } String inputNodeName = signatureDef.getInputsMap().values().iterator().next().getName(); if (signatureDef.getOutputsCount() != 1) { throw new IllegalArgumentException( - String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); + String.format("Function [%s] has multiple outputs", signatureDef.getMethodName())); } String outputNodeName = signatureDef.getOutputsMap().values().iterator().next().getName(); - return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run().get(0); + return session.runner().feed(inputNodeName, tensor).fetch(outputNodeName).run(scope).get(0); } /** @@ -245,8 +240,8 @@ public void save(String exportDir) throws IOException { * Returns the session used to execute the graph when calling this function * *

In general, a user does not need to handle directly the session of a function and rely - * on {@link #call(Map)} to execute the graph instead. But in some cases, direct access to - * the session might be necessary, as it allows more running options. + * on {@link #call(TensorScope, Map)} to execute the graph instead. But in some cases, direct access to the session + * might be necessary, as it allows more running options. * * @return the function session */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java index 9f87fd8b95e..c5f261215b0 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/EagerOperation.java @@ -35,8 +35,8 @@ * Implementation of an {@link Operation} executed eagerly. * *

EagerOperation instances are valid only as long as the {@link EagerSession} they are a part of - * is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the - * EagerOperation instance may fail with an {@code IllegalStateException}. + * is valid. Thus, if {@link EagerSession#close()} has been invoked, then methods on the EagerOperation instance may + * fail with an {@code IllegalStateException}. * *

EagerOperation instances are thread-safe. */ @@ -120,10 +120,10 @@ DataType dtype(int outputIndex) { } @Override - Tensor tensor(int outputIndex) { + Tensor tensor(TensorScope scope, int outputIndex) { Tensor tensor = outputTensors.get(outputIndex); if (tensor == null) { - tensor = resolveTensor(outputIndex); + tensor = resolveTensor(scope, outputIndex); } return tensor; } @@ -133,11 +133,11 @@ Tensor tensor(int outputIndex) { private final String name; private final AtomicReferenceArray outputTensors; - private Tensor resolveTensor(int outputIndex) { + private Tensor resolveTensor(TensorScope scope, int outputIndex) { // Take an optimistic approach, where we attempt to resolve the output tensor without locking. // If another thread has resolved it meanwhile, release our copy and reuse the existing one // instead. - Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), session); + Tensor tensor = resolveTensorHandle(getUnsafeNativeHandle(outputIndex), scope); if (!outputTensors.compareAndSet(outputIndex, null, tensor)) { session.detach(tensor.asRawTensor().nativeHandle()); tensor = outputTensors.get(outputIndex); @@ -160,13 +160,13 @@ private static void requireTensorHandle(TFE_TensorHandle handle) { } } - private static Tensor resolveTensorHandle(TFE_TensorHandle handle, EagerSession session) { + private static Tensor resolveTensorHandle(TFE_TensorHandle handle, TensorScope tensorScope) { requireTensorHandle(handle); try (PointerScope scope = new PointerScope()) { TF_Status status = TF_Status.newStatus(); - TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(); + TF_Tensor tensor = TFE_TensorHandleResolve(handle, status).withDeallocator(true); status.throwExceptionIfNotOK(); - return RawTensor.fromHandle(tensor, session).asTypedTensor(); + return RawTensor.fromHandle(tensorScope, tensor).asTypedTensor(); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java index fbad92160a2..220abc8e590 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GraphOperation.java @@ -37,8 +37,8 @@ * Implementation for an {@link Operation} added as a node to a {@link Graph}. * *

GraphOperation instances are valid only as long as the {@link Graph} they are a part of is - * valid. Thus, if {@link Graph#close()} has been invoked, then methods on the GraphOperation - * instance may fail with an {@code IllegalStateException}. + * valid. Thus, if {@link Graph#close()} has been invoked, then methods on the GraphOperation instance may fail with an + * {@code IllegalStateException}. * *

GraphOperation instances are immutable and thread-safe. */ @@ -166,7 +166,7 @@ DataType dtype(int outputIdx) { } @Override - Tensor tensor(int outputIdx) { + Tensor tensor(TensorScope scope, int outputIdx) { throw new IllegalStateException("Graph tensors must be fetched by running a session"); } @@ -236,7 +236,9 @@ private static long[] shape(TF_Graph graphHandle, TF_Operation opHandle, int out TF_Status status = TF_Status.newStatus(); int numDims = TF_GraphGetTensorNumDims(graphHandle, output, status); status.throwExceptionIfNotOK(); - if (numDims < 0) return null; + if (numDims < 0) { + return null; + } long[] dims = new long[numDims]; TF_GraphGetTensorShape(graphHandle, output, dims, numDims, status); status.throwExceptionIfNotOK(); @@ -250,8 +252,8 @@ private static int dtype(TF_Graph graphHandle, TF_Operation opHandle, int output int numOutputs = TF_OperationNumOutputs(opHandle); if (outputIndex < 0 || outputIndex >= numOutputs) { - throw new IndexOutOfBoundsException("invalid output index (" + outputIndex - + ") for an operation that has " + numOutputs + " outputs"); + throw new IndexOutOfBoundsException("invalid output index (" + outputIndex + + ") for an operation that has " + numOutputs + " outputs"); } try (PointerScope scope = new PointerScope()) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java index 80f62eb5acc..5cdaae9e22e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Operand.java @@ -58,11 +58,12 @@ public interface Operand extends Op, Shaped { * * Only works when running in an eager execution * + * @param scope the {@link TensorScope} to create the tensor in * @return the tensor * @throws IllegalStateException if this is an operand of a graph */ - default T asTensor() { - return asOutput().asTensor(); + default T asTensor(TensorScope scope) { + return asOutput().asTensor(scope); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java index 9e7dedfdc75..5e40850acd8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Output.java @@ -33,31 +33,36 @@ */ public final class Output implements Operand { - /** Returns the index into the outputs of the Operation. */ + /** + * Returns the index into the outputs of the Operation. + */ public int index() { return index; } - /** Returns the DataType of the tensor referred to by this Output. */ + /** + * Returns the DataType of the tensor referred to by this Output. + */ @SuppressWarnings("unchecked") public DataType dataType() { return operation.dtype(index); } - /** Returns the type of the tensor referred to by this Output. */ + /** + * Returns the type of the tensor referred to by this Output. + */ @SuppressWarnings("unchecked") @Override public Class type() { - return (Class)TensorTypeRegistry.find(dataType()).type(); + return (Class) TensorTypeRegistry.find(dataType()).type(); } /** - * Returns this Output object with the type {@code Output}. This method is useful when given a - * value of type {@code Output}. + * Returns this Output object with the type {@code Output}. This method is useful when given a value of type {@code + * Output}. * * @param type any supported tensor type - * @throws IllegalArgumentException if the actual data type of this object does not match the type - * {@code U}. + * @throws IllegalArgumentException if the actual data type of this object does not match the type {@code U}. */ @SuppressWarnings("unchecked") public Output expect(Class type) { @@ -72,8 +77,7 @@ public Output expect(Class type) { * Returns the tensor at this output. * *

This operation is only supported on the outputs of an operation executed eagerly. For graph - * environments, output tensors must be fetched by running a session, using {@link - * Session.Runner#fetch(Output)}. + * environments, output tensors must be fetched by running a session, using {@link Session.Runner#fetch(Output)}. * *

It is recommended to close explicitly the returned tensor as soon as possible, since the * garbage collector is not aware of the amount of memory it consumes, which can be significant. @@ -84,8 +88,8 @@ public Output expect(Class type) { * @see EagerSession */ @SuppressWarnings("unchecked") - public T asTensor() { - return (T)operation.tensor(index); + public T asTensor(TensorScope scope) { + return (T) operation.tensor(scope, index); } /** @@ -130,7 +134,9 @@ public String toString() { operation.type(), operation.name(), index, shape().toString(), dataType()); } - /** Handle to the idx-th output of the Operation {@code op}. */ + /** + * Handle to the idx-th output of the Operation {@code op}. + */ Output(AbstractOperation op, int idx) { operation = op; index = idx; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java index c332fd7f1d1..c738aac3e4c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/RawTensor.java @@ -34,9 +34,9 @@ * A tensor which memory has not been mapped to a data space directly accessible from the JVM. * *

A raw tensor is a minimalist representation of a tensor allocated in native memory by the - * TensorFlow runtime library and it controls its lifetime within the current process. The data - * is represented by a flat {@link ByteDataBuffer buffer of bytes}, until it is mapped in a - * n-dimensional typed space by a {@link TType typed tensor}.

+ * TensorFlow runtime library and it controls its lifetime within the current process. The data is represented by a flat + * {@link ByteDataBuffer buffer of bytes}, until it is mapped in a n-dimensional typed space by a {@link TType typed + * tensor}.

* *

Instances of a RawTensor are not thread-safe and their resource must be released * by calling {@link #close()} explicitly or implicitly via try-with-resources.

@@ -65,7 +65,19 @@ public RawTensor asRawTensor() { @Override public void close() { - tensorScope.close(); + if (!isClosed()) { + pointerScope.close(); + } + } + + @Override + public boolean isClosed() { + return tensorHandle.isNull(); + } + + @Override + public boolean isAttached() { + return tensorScope != null; } /** @@ -93,22 +105,23 @@ public String toString() { * Allocates a new tensor in native memory of the given type, shape and size. * *

The size of the tensor must be at least large enough to contain all scalars for the - * given type and shape. More memory can also be allocated to store also metadata within the - * tensor itself, e.g. a lookup table in a string tensor. + * given type and shape. More memory can also be allocated to store also metadata within the tensor itself, e.g. a + * lookup table in a string tensor. * + * @param tensorScope the {@link TensorScope} to create the tensor in * @param type tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor, or -1 to compute the size from the shape * @return allocated tensor - * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to - * store the tensor data - * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given - * {@code type} are of variable length (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to store the tensor + * data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code type} are of + * variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension() + * unknown} * @throws IllegalStateException if tensor failed to be allocated */ - static RawTensor allocate(Class type, Shape shape, long size) { + static RawTensor allocate(TensorScope tensorScope, Class type, Shape shape, long size) { if (shape.hasUnknownDimension()) { throw new IllegalArgumentException( "Cannot allocate a tensor from a totally or partially unknown shape"); @@ -131,9 +144,9 @@ static RawTensor allocate(Class type, Shape shape, long size) { TF_Tensor nativeHandle = allocate(typeInfo.dataType().getNumber(), shape.asArray(), allocatedSize); try (PointerScope scope = new PointerScope()) { scope.attach(nativeHandle); - RawTensor t = new RawTensor(typeInfo, shape); + RawTensor t = new RawTensor(typeInfo, shape, tensorScope); t.tensorHandle = nativeHandle; - t.tensorScope = scope.extend(); + t.pointerScope = scope.extend(); return t; } } @@ -143,31 +156,20 @@ static RawTensor allocate(Class type, Shape shape, long size) { * *

Takes ownership of the handle. */ - static RawTensor fromHandle(TF_Tensor handle) { + static RawTensor fromHandle(TensorScope tensorScope, TF_Tensor handle) { TensorTypeInfo typeInfo = TensorTypeRegistry.find(DataType.forNumber(dtype(handle))); - RawTensor t = new RawTensor(typeInfo, Shape.of(shape(handle))); + RawTensor t = new RawTensor(typeInfo, Shape.of(shape(handle)), tensorScope); try (PointerScope scope = new PointerScope()) { - scope.attach(handle); - t.tensorHandle = handle; - t.tensorScope = scope.extend(); + scope.attach(handle); + t.tensorHandle = handle; + t.pointerScope = scope.extend(); } return t; } - /** - * Create an eager Tensor object from a handle to the C TF_Tensor object. - * - *

Takes ownership of the handle. - */ - static RawTensor fromHandle(TF_Tensor handle, EagerSession session) { - RawTensor t = fromHandle(handle); - session.attach(handle); - t.tensorScope.detach(handle); - return t; - } - /** * Returns the native handle to this tensor + * * @throws IllegalStateException if tensor has been closed */ TF_Tensor nativeHandle() { @@ -216,12 +218,19 @@ private static long[] shape(TF_Tensor handle) { return dims; } - RawTensor(TensorTypeInfo typeInfo, Shape shape) { + RawTensor(TensorTypeInfo typeInfo, Shape shape, TensorScope tensorScope) { this.typeInfo = typeInfo; this.shape = shape; + if (tensorScope == null) { + throw new NullPointerException("Can't create a tensor with a null TensorScope"); + } + + tensorScope.attach(this); + this.tensorScope = tensorScope; } - private PointerScope tensorScope; + private PointerScope pointerScope; + TensorScope tensorScope; private TF_Tensor tensorHandle; private final TensorTypeInfo typeInfo; private final Shape shape; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java index 0974cc94a24..dc213a227e7 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/SavedModelBundle.java @@ -51,19 +51,22 @@ * SavedModelBundle represents a model loaded from storage. * *

The model consists of a description of the computation (a {@link Graph}), a {@link Session} - * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, - * and a description of the model as a MetaGraphDef + * with tensors (e.g., parameters or variables in the graph) initialized to values saved in storage, and a description + * of the model as a MetaGraphDef * protocol buffer. */ public class SavedModelBundle implements AutoCloseable { public static final String DEFAULT_TAG = "serve"; - /** Options for loading a SavedModel. */ + /** + * Options for loading a SavedModel. + */ public static final class Loader { - /** Load a SavedModelBundle with the configured options. */ + /** + * Load a SavedModelBundle with the configured options. + */ public SavedModelBundle load() { return SavedModelBundle.load(exportDir, tags, configProto, runOptions); } @@ -71,9 +74,8 @@ public SavedModelBundle load() { /** * Sets options to use when executing model initialization operations. * - * @param options A RunOptions - * protocol buffer. + * @param options A RunOptions + * protocol buffer. * @return this object */ public Loader withRunOptions(RunOptions options) { @@ -84,9 +86,8 @@ public Loader withRunOptions(RunOptions options) { /** * Set configuration of the Session object created when loading the model. * - * @param configProto A ConfigProto - * protocol buffer. + * @param configProto A ConfigProto + * protocol buffer. * @return this object */ public Loader withConfigProto(ConfigProto configProto) { @@ -114,12 +115,14 @@ private Loader(String exportDir) { } private String exportDir = null; - private String[] tags = { DEFAULT_TAG }; + private String[] tags = {DEFAULT_TAG}; private ConfigProto configProto = null; private RunOptions runOptions = null; } - /** Options for exporting a SavedModel. */ + /** + * Options for exporting a SavedModel. + */ public static final class Exporter { /** @@ -144,9 +147,9 @@ public Exporter withTags(String... tags) { * names to a graph) and a valid session to a graph to be saved in the model. * *

Note:Eventually, TensorFlow for Java will support the export of functions objects like - * the Python API does but right now, only session-centric models are supported (i.e. models that - * has a single main graph and one or more signatures). These models are compatible with those - * exported by TensorFlow 1.x or by TensorFlow 2.x estimators. + * the Python API does but right now, only session-centric models are supported (i.e. models that has a single main + * graph and one or more signatures). These models are compatible with those exported by TensorFlow 1.x or by + * TensorFlow 2.x estimators. * *
Therefore, all functions exported in a model should share the same session at the moment * or an exception will be thrown.
@@ -154,8 +157,8 @@ public Exporter withTags(String... tags) { * @param function a function carrying a signature and a valid session to the graph to be saved * @return this object * @throws IllegalArgumentException if a function with the same name has already been added to the model - * @throws UnsupportedOperationException if this function does not share the same session with the other - * functions added to this model + * @throws UnsupportedOperationException if this function does not share the same session with the other functions + * added to this model */ public Exporter withFunction(ConcreteFunction function) { Signature signature = function.signature(); @@ -166,7 +169,8 @@ public Exporter withFunction(ConcreteFunction function) { if (session == null) { session = function.session(); } else if (session != function.session()) { - throw new UnsupportedOperationException("Saving multiple functions with different graphs/sessions is not supported yet."); + throw new UnsupportedOperationException( + "Saving multiple functions with different graphs/sessions is not supported yet."); } metaGraphDefBuilder.putSignatureDef(signature.key(), signature.asSignatureDef()); return this; @@ -213,16 +217,15 @@ public void export() throws IOException { } private final String exportDir; - private String[] tags = { DEFAULT_TAG }; + private String[] tags = {DEFAULT_TAG}; private final MetaGraphDef.Builder metaGraphDefBuilder = MetaGraphDef.newBuilder(); private final Map functions = new LinkedHashMap<>(); private Session session; } /** - * Load a saved model from an export directory. The model that is being loaded should be created - * using the Saved Model - * API. + * Load a saved model from an export directory. The model that is being loaded should be created using the Saved Model API. * *

This method is a shorthand for: * @@ -267,15 +270,16 @@ public static Exporter exporter(String exportDir) { } /** - * Returns the MetaGraphDef + * Returns the MetaGraphDef * protocol buffer associated with the saved model. */ public MetaGraphDef metaGraphDef() { return metaGraphDef; } - /** Returns the graph that describes the computation performed by the model. */ + /** + * Returns the graph that describes the computation performed by the model. + */ public Graph graph() { return graph; } @@ -306,8 +310,7 @@ public List signatures() { * * @param signatureKey name of the {@code SignatureDef} in the saved model. * @return object that can be used to make calls to a function - * @throws IllegalArgumentException if {@code signatureKey} is not found in this - * saved model. + * @throws IllegalArgumentException if {@code signatureKey} is not found in this saved model. */ public ConcreteFunction function(String signatureKey) { ConcreteFunction function = functions.get(signatureKey); @@ -330,11 +333,12 @@ public ConcreteFunction function(String signatureKey) { * *

Caller is responsible for closing all returned Tensors. * + * @param scope the {@link TensorScope} to create the outputs in * @param arguments list of input tensors, mapped by their signature name * @return list of output tensors, mapped by the signature name * @throws IllegalArgumentException if no function can be selected by default */ - public Map call(Map arguments) { + public Map call(TensorScope scope, Map arguments) { ConcreteFunction function = null; if (functions.size() == 1) { function = functions.values().iterator().next(); @@ -344,12 +348,11 @@ public Map call(Map arguments) { if (function == null) { throw new IllegalArgumentException("Cannot elect a default function for this model"); } - return function.call(arguments); + return function.call(scope, arguments); } /** - * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model - * bundle. + * Releases resources (the {@link Graph} and {@link Session}) associated with the saved model bundle. */ @Override public void close() { @@ -362,7 +365,8 @@ public void close() { private final MetaGraphDef metaGraphDef; private final Map functions; - private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, Map functions) { + private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef, + Map functions) { this.graph = graph; this.session = session; this.metaGraphDef = metaGraphDef; @@ -370,8 +374,8 @@ private SavedModelBundle(Graph graph, Session session, MetaGraphDef metaGraphDef } /** - * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session - * object, plus the MetaGraphDef. + * Create a SavedModelBundle object from a handle to the C TF_Graph object and to the C TF_Session object, plus the + * MetaGraphDef. * *

Invoked from the native load method. Takes ownership of the handles. */ diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java index b67f4a611e6..ecaa7ec57ab 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Session.java @@ -15,7 +15,16 @@ package org.tensorflow; +import static org.tensorflow.Graph.resolveOutputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_CloseSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewSession; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SessionRun; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_SetConfig; + import com.google.protobuf.InvalidProtocolBufferException; +import java.util.ArrayList; +import java.util.List; import org.bytedeco.javacpp.BytePointer; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.PointerPointer; @@ -33,15 +42,9 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.RunMetadata; import org.tensorflow.proto.framework.RunOptions; - -import java.util.ArrayList; -import java.util.List; import org.tensorflow.proto.util.SaverDef; import org.tensorflow.types.TString; -import static org.tensorflow.Graph.resolveOutputs; -import static org.tensorflow.internal.c_api.global.tensorflow.*; - /** * Driver for {@link Graph} execution. * @@ -84,11 +87,9 @@ public Session(Graph g) { * Construct a new session with the associated {@link Graph} and configuration options. * * @param g The {@link Graph} the created Session will operate on. - * @param config Configuration parameters for the session specified as a ConfigProto - * protocol buffer. - * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto - * protocol buffer. + * @param config Configuration parameters for the session specified as a ConfigProto + * protocol buffer. + * @throws IllegalArgumentException if the config is not a valid serialization of the ConfigProto protocol buffer. */ public Session(Graph g, ConfigProto config) { graph = g; @@ -101,7 +102,9 @@ public Session(Graph g, ConfigProto config) { } } - /** Wrap an existing session with the associated {@link Graph}. */ + /** + * Wrap an existing session with the associated {@link Graph}. + */ Session(Graph g, TF_Session nativeHandle) { graph = g; this.nativeHandle = nativeHandle; @@ -111,7 +114,7 @@ public Session(Graph g, ConfigProto config) { /** * Release resources associated with the Session. * - *

Blocks until there are no active executions ({@link Session.Runner#run()} calls). A Session + *

Blocks until there are no active executions ({@link Session.Runner#run(TensorScope)} calls). A Session * is not usable after close returns. */ @Override @@ -139,22 +142,20 @@ public void close() { * Run {@link Operation}s and evaluate {@link Tensor Tensors}. * *

A Runner runs the necessary graph fragments to execute every {@link Operation} required to - * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String,int,Tensor)} call allows - * callers to override the value of {@link Tensor Tensors} in the graph by substituting the - * provided {@link Tensor Tensors} for the outputs of the operations provided to {@link - * #feed(String,int,Tensor)}. + * evaluate the {@link Tensor Tensors} to fetch. The {@link #feed(String, int, Tensor)} call allows callers to override + * the value of {@link Tensor Tensors} in the graph by substituting the provided {@link Tensor Tensors} for the + * outputs of the operations provided to {@link #feed(String, int, Tensor)}. */ public final class Runner { /** * Avoid evaluating {@code operation} and substitute {@code t} for the value it produces. * - * @param operation Is either the string name of the operation, in which case this method is a - * shorthand for {@code feed(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * feed(operation_name, output_index)}. These colon-separated names are commonly used in the - * {@code SignatureDef} protocol buffer messages that are included in {@link - * SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code + * feed(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * feed(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} + * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. * @param t the tensor substituting the operation * @return this session runner */ @@ -163,8 +164,8 @@ public Runner feed(String operation, Tensor t) { } /** - * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} - * for the value it produces. + * Avoid evaluating the {@code index}-th output of {@code operation} by substituting {@code t} for the value it + * produces. * *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one {@code t} is being provided for. @@ -183,8 +184,7 @@ public Runner feed(String operation, int index, Tensor t) { } /** - * Use {@code t} instead of the Tensor referred to by executing the operation referred to by - * {@code operand}. + * Use {@code t} instead of the Tensor referred to by executing the operation referred to by {@code operand}. * * @param operand the node in the graph representing the operation to substitute * @param t the tensor substituting the operation @@ -197,14 +197,13 @@ public Runner feed(Operand operand, Tensor t) { } /** - * Make {@link #run()} return the output of {@code operation}. + * Make {@link #run(TensorScope)} return the output of {@code operation}. * - * @param operation Is either the string name of the operation, in which case this method is a - * shorthand for {@code fetch(operation, 0)}, or it is a string of the form - * operation_name:output_index , in which case this method acts like {@code - * fetch(operation_name, output_index)}. These colon-separated names are commonly used in - * the {@code SignatureDef} protocol buffer messages that are included in {@link - * SavedModelBundle#metaGraphDef()}. + * @param operation Is either the string name of the operation, in which case this method is a shorthand for {@code + * fetch(operation, 0)}, or it is a string of the form + * operation_name:output_index , in which case this method acts like {@code + * fetch(operation_name, output_index)}. These colon-separated names are commonly used in the {@code SignatureDef} + * protocol buffer messages that are included in {@link SavedModelBundle#metaGraphDef()}. * @return this session runner */ public Runner fetch(String operation) { @@ -212,7 +211,7 @@ public Runner fetch(String operation) { } /** - * Make {@link #run()} return the {@code index}-th output of {@code operation}. + * Make {@link #run(TensorScope)} return the {@code index}-th output of {@code operation}. * *

Operations in a {@link Graph} can have multiple outputs, {@code index} identifies which * one to return. @@ -229,7 +228,7 @@ public Runner fetch(String operation, int index) { } /** - * Makes {@link #run()} return the Tensor referred to by {@code output}. + * Makes {@link #run(TensorScope)} return the Tensor referred to by {@code output}. * * @param output the node to fetch the tensor from * @return this session runner @@ -240,7 +239,7 @@ public Runner fetch(Output output) { } /** - * Makes {@link #run()} return the Tensor referred to by the output of {@code operand}. + * Makes {@link #run(TensorScope)} return the Tensor referred to by the output of {@code operand}. * * @param operand the node to fetch the tensor from, as an operand * @return this session runner @@ -250,8 +249,7 @@ public Runner fetch(Operand operand) { } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor - * Tensors}. + * Make {@link #run(TensorScope)} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. * * @param operation the string name of the operation to execute * @return this session runner @@ -265,8 +263,7 @@ public Runner addTarget(String operation) { } /** - * Make {@link #run()} execute {@code operation}, but not return any evaluated {@link Tensor - * Tensors}. + * Make {@link #run(TensorScope)} execute {@code operation}, but not return any evaluated {@link Tensor Tensors}. * * @param operation the operation to execute * @return this session runner @@ -297,8 +294,7 @@ public Runner addTarget(Op op) { * Set options (typically for debugging) for this run. * *

The options are presented as a RunOptions - * protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunOptions protocol buffer. * * @param options a {@code RunOptions} proto * @return this session runner @@ -309,41 +305,50 @@ public Runner setOptions(RunOptions options) { } /** - * Execute the graph fragments necessary to compute all requested fetches. + * Execute the graph fragments necessary to compute all requested fetches and complete all targets. * - *

WARNING: The caller assumes ownership of all returned {@link Tensor Tensors}, i.e., - * the caller must call {@link Tensor#close} on all elements of the returned list to free up - * resources. + *

The returned tensors will be part of the passed scope. * *

TODO(ashankar): Reconsider the return type here. Two things in particular: (a) Make it - * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in - * SessionTest.java), and (b) Evaluate whether the return value should be a list, or maybe a - * {@code Map}? + * easier for the caller to cleanup (perhaps returning something like AutoCloseableList in SessionTest.java), and + * (b) Evaluate whether the return value should be a list, or maybe a {@code Map}? * *

TODO(andrewmyers): It would also be good if whatever is returned here made it easier to * extract output tensors in a type-safe way. * + * @param scope the {@link TensorScope} to create outputs in. May be null if there are no outputs. * @return list of resulting tensors fetched by this session runner */ - public List run() { - return runHelper(false).outputs; + public List run(TensorScope scope) { + return runHelper(scope, false).outputs; + } + + /** + * Execute the graph fragments necessary to compute all requested fetches and complete all targets. + * + * @see #run(TensorScope) + */ + public void runWithoutOutputs() { + try (TensorScope scope = new TensorScope()) { + run(scope); + } } /** * Execute graph fragments to compute requested fetches and return metadata about the run. * - *

This is exactly like {@link #run()}, but in addition to the requested Tensors, also - * returns metadata about the graph execution in the form of a RunMetadata + *

This is exactly like {@link #run(TensorScope)}, but in addition to the requested Tensors, also + * returns metadata about the graph execution in the form of a RunMetadata * protocol buffer. * + * @param scope the {@link TensorScope} to create outputs in. May be null if there are no outputs. * @return list of resulting tensors fetched by this session runner, with execution metadata */ - public Run runAndFetchMetadata() { - return runHelper(true); + public Run runAndFetchMetadata(TensorScope scope) { + return runHelper(scope, true); } - private Run runHelper(boolean wantMetadata) { + private Run runHelper(TensorScope tensorScope, boolean wantMetadata) { TF_Tensor[] inputTensorHandles = new TF_Tensor[inputTensors.size()]; TF_Operation[] inputOpHandles = new TF_Operation[inputs.size()]; int[] inputOpIndices = new int[inputs.size()]; @@ -351,6 +356,10 @@ private Run runHelper(boolean wantMetadata) { int[] outputOpIndices = new int[outputs.size()]; TF_Operation[] targetOpHandles = new TF_Operation[targets.size()]; + if (outputs.size() > 0 && tensorScope == null) { + throw new NullPointerException("tensorScope must not be null when outputs are requested"); + } + // It's okay to use Operation.getUnsafeNativeHandle() here since the safety depends on the // validity of the Graph and graphRef ensures that. int idx = 0; @@ -388,7 +397,8 @@ private Run runHelper(boolean wantMetadata) { outputOpIndices, targetOpHandles, wantMetadata, - outputs); + outputs, + tensorScope); } catch (Exception e) { for (Tensor t : outputs) { t.close(); @@ -405,6 +415,7 @@ private Run runHelper(boolean wantMetadata) { } private class Reference implements AutoCloseable { + public Reference() { synchronized (nativeHandleLock) { if (nativeHandle == null || nativeHandle.isNull()) { @@ -457,7 +468,9 @@ private Output parseOutput(String opName) { private RunOptions runOptions = null; } - /** Create a Runner to execute graph operations and evaluate Tensors. */ + /** + * Create a Runner to execute graph operations and evaluate Tensors. + */ public Runner runner() { return new Runner(); } @@ -476,7 +489,7 @@ public void run(String opName) { throw new IllegalArgumentException( "Operation named '" + opName + "' cannot be found in the graph"); } - runner().addTarget(operation).run(); + runner().addTarget(operation).run(null); } /** @@ -487,7 +500,7 @@ public void run(String opName) { * @param op the operation to run. */ public void run(Op op) { - runner().addTarget(op.op()).run(); + runner().addTarget(op.op()).run(null); } @@ -495,12 +508,11 @@ public void run(Op op) { * Execute the graph's initializers. * *

This method is equivalent to {@code session.run(Ops.create(session.graph).init())}. - * */ - public void runInit(){ + public void runInit() { Runner runner = runner(); graph.initializers().forEach(runner::addTarget); - runner.run(); + runner.run(null); } /** @@ -518,14 +530,16 @@ public void runInit(){ */ public void save(String prefix) { SaverDef saverDef = graph.saverDef(); - runner().addTarget(saverDef.getSaveTensorName()) - .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) - .run(); + try (TensorScope scope = new TensorScope()) { + runner().addTarget(saverDef.getSaveTensorName()) + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(scope, prefix)) + .run(scope); + } } /** * Restore the actual state of the variables of this session's graph. - * + * *

{@code prefix} is the path where the files containing the variables state live, * followed by the filename prefix. For example, if {@code prefix} is set to * mymodel/myvariables/variables, then the files are loaded from @@ -538,26 +552,30 @@ public void save(String prefix) { */ public void restore(String prefix) { SaverDef saverDef = graph.saverDef(); - runner().addTarget(saverDef.getRestoreOpName()) - .feed(saverDef.getFilenameTensorName(), TString.scalarOf(prefix)) - .run(); + try (TensorScope scope = new TensorScope()) { + runner().addTarget(saverDef.getRestoreOpName()) + .feed(saverDef.getFilenameTensorName(), TString.scalarOf(scope, prefix)) + .run(null); + } } /** * Output tensors and metadata obtained when executing a session. * - *

See {@link Runner#runAndFetchMetadata()} + *

See {@link Runner#runAndFetchMetadata(TensorScope)} */ public static final class Run { - /** Tensors from requested fetches. */ + + /** + * Tensors from requested fetches. + */ public List outputs; /** * Metadata about the run. * *

A RunMetadata - * protocol buffer. + * href="https://www.tensorflow.org/code/tensorflow/core/protobuf/config.proto">RunMetadata protocol buffer. */ public RunMetadata metadata; } @@ -633,20 +651,20 @@ private static void delete(TF_Session handle) { * @param runOptions A RunOptions protocol buffer, or null * @param inputOpHandles (see inputOpIndices) * @param inputOpIndices (see inputTensorHandles) - * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values - * that are being "fed" (do not need to be computed) during graph execution. - * inputTensorHandles[i] (which corresponds to a Tensor.nativeHandle) is considered to be the - * inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, it is required that - * inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. + * @param inputTensorHandles together with inputOpHandles and inputOpIndices specifies the values that are being "fed" + * (do not need to be computed) during graph execution. inputTensorHandles[i] (which corresponds to a + * Tensor.nativeHandle) is considered to be the inputOpIndices[i]-th output of the Operation inputOpHandles[i]. Thus, + * it is required that inputOpHandles.length == inputOpIndices.length == inputTensorHandles.length. * @param outputOpHandles (see outputOpIndices) - * @param outputOpIndices together with outputOpHandles identifies the set of values that should - * be computed. The outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is - * required that outputOpHandles.length == outputOpIndices.length. - * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose - * output will not be returned + * @param outputOpIndices together with outputOpHandles identifies the set of values that should be computed. The + * outputOpIndices[i]-th output of the Operation outputOpHandles[i], It is required that outputOpHandles.length == + * outputOpIndices.length. + * @param targetOpHandles is the set of Operations in the graph that are to be executed but whose output will not be + * returned * @param wantRunMetadata indicates whether metadata about this execution should be returned. - * @param outputTensors will be filled in with tensors to the outputs requested. It is required - * that outputs.length == outputOpHandles.length. + * @param outputTensors will be filled in with tensors to the outputs requested. It is required that outputs.length == + * outputOpHandles.length. + * @param tensorScope the {@link TensorScope} to create tensors in * @return if wantRunMetadata is true, a RunMetadata protocol buffer, false otherwise. */ private static RunMetadata run( @@ -659,7 +677,8 @@ private static RunMetadata run( int[] outputOpIndices, TF_Operation[] targetOpHandles, boolean wantRunMetadata, - List outputTensors) { + List outputTensors, + TensorScope tensorScope) { requireHandle(handle); int ninputs = inputTensorHandles.length; @@ -698,8 +717,8 @@ private static RunMetadata run( status.throwExceptionIfNotOK(); for (int i = 0; i < noutputs; ++i) { - TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(); - outputTensors.add(RawTensor.fromHandle(h).asTypedTensor()); + TF_Tensor h = outputValues.get(TF_Tensor.class, i).withDeallocator(false); + outputTensors.add(RawTensor.fromHandle(tensorScope, h).asTypedTensor()); } try { return runMetadata != null ? RunMetadata.parseFrom(runMetadata.dataAsByteBuffer()) : null; diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java index fc1275229bf..1f13d4e9f32 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/Tensor.java @@ -16,6 +16,7 @@ package org.tensorflow; import java.util.function.Consumer; +import org.bytedeco.javacpp.Pointer; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.Shaped; import org.tensorflow.ndarray.buffer.ByteDataBuffer; @@ -26,19 +27,27 @@ * A statically typed multi-dimensional array. * *

There are two categories of tensors in TensorFlow Java: {@link TType typed tensors} and - * {@link RawTensor raw tensors}. The former maps the tensor native memory to an - * n-dimensional typed data space, allowing direct I/O operations from the JVM, while the latter - * is only a reference to a native tensor allowing basic operations and flat data access.

+ * {@link RawTensor raw tensors}. The former maps the tensor native memory to an n-dimensional typed data space, + * allowing direct I/O operations from the JVM, while the latter is only a reference to a native tensor allowing basic + * operations and flat data access.

* - *

WARNING: Resources consumed by the Tensor object must be explicitly freed by - * invoking the {@link #close()} method when the object is no longer needed. For example, using a + *

WARNING: Resources consumed by the Tensor object should be explicitly freed by + * invoking the {@link #close()} method on the tensor, or using {@link TensorScope}. For example, using a * try-with-resources block: * *

{@code
- * try (Tensor t = Tensor.of(...)) {
+ * try (TensorScope scope = new TensorScope()) {
+ *   Tensor t = Tensor.of(scope, ...);
  *   doSomethingWith(t);
  * }
  * }
+ * + * Dropped tensors will be closed when GC'd, but relying on the garbage collector for cleanup is inefficient. + * + * + *

JavaCPP properties are used to manage garbage collection, see {@link Pointer}. Specifically + * {@link Pointer#maxBytes} and {@link Pointer#maxPhysicalBytes}. + * *

Instances of a Tensor are not thread-safe. */ public interface Tensor extends Shaped, AutoCloseable { @@ -50,44 +59,45 @@ public interface Tensor extends Shaped, AutoCloseable { * and is left uninitialized. * * @param the tensor type + * @param scope the {@link TensorScope} to create the tensor in * @param type the tensor type class * @param shape shape of the tensor * @return an allocated but uninitialized tensor - * @throws IllegalArgumentException if elements of the given {@code type} are of variable length - * (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * @throws IllegalArgumentException if elements of the given {@code type} are of variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension() + * unknown} * @throws IllegalStateException if tensor failed to be allocated */ - static T of(Class type, Shape shape) { - return of(type, shape, -1); + static T of(TensorScope scope, Class type, Shape shape) { + return of(scope, type, shape, -1); } /** * Allocates a tensor of a given datatype, shape and size. * - *

This method is identical to {@link #of(Class, Shape)}, except that the final size of the - * tensor can be explicitly set instead of computing it from the datatype and shape, which could be - * larger than the actual space required to store the data but not smaller. + *

This method is identical to {@link #of(TensorScope, Class, Shape)}, except that the final size of the + * tensor can be explicitly set instead of computing it from the datatype and shape, which could be larger than the + * actual space required to store the data but not smaller. * * @param the tensor type + * @param scope the {@link TensorScope} to create the tensor in * @param type the tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor or -1 to compute the size from the shape * @return an allocated but uninitialized tensor - * @see #of(Class, Shape) - * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to - * store the tensor data - * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given - * {@code type} are of variable length (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to store the tensor + * data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code type} are of + * variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension() + * unknown} * @throws IllegalStateException if tensor failed to be allocated + * @see #of(TensorScope, Class, Shape) */ - static T of(Class type, Shape shape, long size) { - RawTensor tensor = RawTensor.allocate(type, shape, size); + static T of(TensorScope scope, Class type, Shape shape, long size) { + RawTensor tensor = RawTensor.allocate(scope, type, shape, size); try { - return (T)tensor.asTypedTensor(); + return (T) tensor.asTypedTensor(); } catch (Exception e) { tensor.close(); throw e; @@ -98,8 +108,8 @@ static T of(Class type, Shape shape, long size) { * Allocates and initialize a tensor of a given datatype and shape. * *

The amount of memory to allocate is derived from the datatype and the shape of the tensor. - * Tensor data is initialized by calling the {@code dataInitializer}, which receives in argument - * the value returned by {@link #data()} on the allocated tensor. For example: + * Tensor data is initialized by calling the {@code dataInitializer}, which receives in argument the value returned by + * {@link #data()} on the allocated tensor. For example: * *

{@code
    * FloatNdArray data = ...
@@ -112,46 +122,47 @@ static  T of(Class type, Shape shape, long size) {
    * automatically released before rethrowing the same exception.
    *
    * @param  the tensor type
+   * @param scope the {@link TensorScope} to create the tensor in
    * @param type the tensor type class
    * @param shape shape of the tensor
    * @param dataInitializer method receiving accessor to the allocated tensor data for initialization
    * @return an allocated and initialized tensor
-   * @throws IllegalArgumentException if elements of the given {@code type} are of variable length
-   *                                  (e.g. strings)
-   * @throws IllegalArgumentException if {@code shape} is totally or partially
-   *                                  {@link Shape#hasUnknownDimension() unknown}
+   * @throws IllegalArgumentException if elements of the given {@code type} are of variable length (e.g. strings)
+   * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension()
+   * unknown}
    * @throws IllegalStateException if tensor failed to be allocated
    */
-  static  T of(Class type, Shape shape, Consumer dataInitializer) {
-    return of(type, shape, -1, dataInitializer);
+  static  T of(TensorScope scope, Class type, Shape shape, Consumer dataInitializer) {
+    return of(scope, type, shape, -1, dataInitializer);
   }
 
   /**
    * Allocates a tensor of a given datatype, shape and size.
    *
-   * 

This method is identical to {@link #of(Class, Shape, Consumer)}, except that the final + *

This method is identical to {@link #of(TensorScope, Class, Shape, Consumer)}, except that the final * size for the tensor can be explicitly set instead of being computed from the datatype and shape. * *

This could be useful for tensor types that stores data but also metadata in the tensor memory, * such as the lookup table in a tensor of strings. * * @param the tensor type + * @param scope the {@link TensorScope} to create the tensor in * @param type the tensor type class * @param shape shape of the tensor * @param size size in bytes of the tensor or -1 to compute the size from the shape * @param dataInitializer method receiving accessor to the allocated tensor data for initialization * @return an allocated and initialized tensor - * @see #of(Class, Shape, long, Consumer) - * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to - * store the tensor data - * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given - * {@code type} are of variable length (e.g. strings) - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * @throws IllegalArgumentException if {@code size} is smaller than the minimum space required to store the tensor + * data + * @throws IllegalArgumentException if {@code size} is set to -1 but elements of the given {@code type} are of + * variable length (e.g. strings) + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension() + * unknown} * @throws IllegalStateException if tensor failed to be allocated + * @see #of(TensorScope, Class, Shape, long, Consumer) */ - static T of(Class type, Shape shape, long size, Consumer dataInitializer) { - T tensor = of(type, shape, size); + static T of(TensorScope scope, Class type, Shape shape, long size, Consumer dataInitializer) { + T tensor = of(scope, type, shape, size); try { dataInitializer.accept(tensor); return tensor; @@ -168,17 +179,17 @@ static T of(Class type, Shape shape, long size, Consumer * href="https://www.tensorflow.org/code/tensorflow/c/c_api.h">C API. * * @param the tensor type + * @param scope the {@link TensorScope} to create the tensor in * @param type the tensor type class * @param shape the tensor shape. * @param rawData a buffer containing the tensor raw data. - * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor - * data - * @throws IllegalArgumentException if {@code shape} is totally or partially - * {@link Shape#hasUnknownDimension() unknown} + * @throws IllegalArgumentException if {@code rawData} is not large enough to contain the tensor data + * @throws IllegalArgumentException if {@code shape} is totally or partially {@link Shape#hasUnknownDimension() + * unknown} * @throws IllegalStateException if tensor failed to be allocated with the given parameters */ - static T of(Class type, Shape shape, ByteDataBuffer rawData) { - return of(type, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size())); + static T of(TensorScope scope, Class type, Shape shape, ByteDataBuffer rawData) { + return of(scope, type, shape, rawData.size(), t -> rawData.copyTo(t.asRawTensor().data(), rawData.size())); } /** @@ -204,12 +215,31 @@ static T of(Class type, Shape shape, ByteDataBuffer rawData /** * Release resources associated with the Tensor. + *

All tensors should be closed using this method or {@link TensorScope}. + * Memory will not leak if they aren't, but relying on the garbage collector for cleanup is not efficient. * - *

WARNING:This must be invoked for all tensors that were not been produced by an eager - * operation or memory will be leaked. + *

JavaCPP properties are used to manage garbage collection, see {@link Pointer}. Specifically + * {@link Pointer#maxBytes} and {@link Pointer#maxPhysicalBytes}. * *

The Tensor object is no longer usable after {@code close} returns. */ @Override void close(); + + /** + * Get whether this tensor has been closed. + */ + boolean isClosed(); + + /** + * Detach this tensor from any scopes managing it. It must be manually closed or attached to another scope. + */ + default void detach() { + TensorScope.detach(this); + } + + /** + * Returns true if this tensor is attached to a {@link TensorScope}. + */ + boolean isAttached(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java new file mode 100644 index 00000000000..6052daab086 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorContainer.java @@ -0,0 +1,52 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +/** + * An interface representing a collection or group of tensors. Provides methods for resource management. + */ +public interface TensorContainer extends AutoCloseable { + + /** + * Get the tensors held by this object. + */ + Iterable tensors(); + + /** + * Detach these tensors from any scopes managing them. They must be manually closed or attached to another scope. + * + * @see Tensor#detach() + */ + default void detach() { + tensors().forEach(Tensor::detach); + } + + + /** + * Release resources associated with these tensors. + *

All tensors should be closed using this method or {@link TensorScope}. + * Memory will not leak if they aren't, but relying on the garbage collector for cleanup is not efficient. + * + *

The Tensor objects are no longer usable after {@code close} returns. + * + * @see Tensor#close() + */ + @Override + default void close() { + tensors().forEach(Tensor::close); + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java new file mode 100644 index 00000000000..ecb8cadabae --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/TensorScope.java @@ -0,0 +1,252 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.WeakHashMap; +import org.bytedeco.javacpp.Pointer; + + +/** + * A scope used to manage tensor resources. All tensor-creating methods take a scope as a parameter, and create their + * tensors in that scope. When a scope is closed, it closes all of it's attached tensors. Tensors may be manually + * closed earlier without issue, and being attached to a scope will not keep a tensor from being GC'd. Using a {@code + * TensorScope} is recommended over manually closing every tensor. For example, using a try-with-resources block: + * + *

{@code
+ * try (TensorScope scope = new TensorScope()) {
+ *   Tensor t = Tensor.of(scope, ...);
+ *   doSomethingWith(t);
+ *   Tensor t2 = Tensor.of(scope, ...);
+ *   doSomething2(t);
+ * }
+ * }
+ * + *

While tensors will be closed when GC'd, relying on the garbage collector for cleanup is not efficient. This + * class or manual management should be used. + * + *

JavaCPP properties are used to manage garbage collection, see {@link Pointer}. Specifically + * {@link Pointer#maxBytes} and {@link Pointer#maxPhysicalBytes}. + * + *

+ * {@link TensorScope#detach(Tensor)} and {@link Tensor#detach()} detaches the tensor from it's scope, requiring the + * user to close it manually or attach it to another scope. + * + *

+ * Like Tensors, TensorScope is not thread safe. + */ +public final class TensorScope implements AutoCloseable { + + + /** + * Create a new tensor scope. + * + * @see TensorScope + */ + public TensorScope() { + } + + /** + * Closes this scope and its tensors. + *

All tensors should be closed using this method or {@link Tensor#close()}. + * Memory will not leak if they aren't, but relying on the garbage collector for cleanup is not efficient. + */ + @Override + public void close() { + if (closed) { + return; + } + tensors.forEach(Tensor::close); + + closed = true; + } + + /** + * Detach all of this scope's tensors, then close the scope. + *

+ * EXTREMELY DANGEROUS: this will close this scope, but does not close any of it's resources. + * + * @return All of this scope's now-detached tensors + */ + public Set detachAll() { + Set detachedTensors = new HashSet<>(this.tensors); + detachedTensors.forEach(TensorScope::detach); + closed = true; + tensors.clear(); + return detachedTensors; + } + + public static T detach(T tensor) { + // ensure that I'm not attaching or detaching at the same time in different threads + RawTensor rt = tensor.asRawTensor(); + if (rt.tensorScope != null) { + rt.tensorScope.tensors.remove(rt); + rt.tensorScope = null; + } + return tensor; + } + + /** + * @see #detach(Tensor) + */ + public static void detach(Tensor... tensors) { + for (Tensor t : tensors) { + detach(t); + } + } + + /** + * @see #detach(Tensor) + */ + public static T detach(T tensors) { + detach(tensors.tensors()); + return tensors; + } + + /** + * @see #detach(Tensor) + */ + public static void detach(TensorContainer... tensors) { + for (TensorContainer ht : tensors) { + detach(ht); + } + } + + /** + * @see #detach(Tensor) + */ + public static > T detach(T tensors) { + tensors.forEach(TensorScope::detach); + return tensors; + } + + /** + * @see #detach(Tensor) + */ + @SafeVarargs + public static void detach(Iterable... tensors) { + for (Iterable iterable : tensors) { + detach(iterable); + } + } + + /** + * Attach a tensor to this scope. This happens automatically to tensors that are created in the scope. + * + * @return this + */ + public T attach(T tensor) { + if (this.closed) { + throw new IllegalStateException("Scope has been closed, can not attach new tensor."); + } + + RawTensor rt = tensor.asRawTensor(); + detach(tensor); + rt.tensorScope = this; + tensors.add(rt); + + return tensor; + } + + /** + * @see #attach(Tensor) + */ + public void attach(Tensor... tensors) { + if (tensors != null) { + for (Tensor t : tensors) { + attach(t); + } + } + } + + /** + * @see #attach(Tensor) + */ + public T attach(T tensors) { + attach(tensors.tensors()); + return tensors; + } + + /** + * @see #attach(Tensor) + */ + public void attach(TensorContainer... tensors) { + if (tensors != null) { + for (TensorContainer ht : tensors) { + attach(ht); + } + } + } + + /** + * @see #attach(Tensor) + */ + public > T attach(T tensors) { + tensors.forEach(this::attach); + + return tensors; + } + + /** + * @see #attach(Tensor) + */ + @SafeVarargs + public final void attach(Iterable... tensors) { + if (tensors != null) { + for (Iterable ht : tensors) { + attach(ht); + } + } + } + + /** + * @see #attach(Tensor) + */ + public TensorScope withTensors(Tensor... tensors) { + attach(tensors); + return this; + } + + /** + * @see #attach(Tensor) + */ + public TensorScope withTensors(TensorContainer... tensors) { + attach(tensors); + return this; + } + + /** + * @see #attach(Tensor) + */ + @SafeVarargs + public final TensorScope withTensors(Iterable... tensors) { + attach(tensors); + return this; + } + + /** + * Gets whether the scope is closed. + */ + public boolean isClosed() { + return closed; + } + + private boolean closed = false; + private final Set tensors = Collections.newSetFromMap(new WeakHashMap<>()); +} diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java index fba056c6dcb..940f717806a 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/internal/c_api/AbstractTF_Tensor.java @@ -20,29 +20,67 @@ import static org.tensorflow.internal.c_api.global.tensorflow.TF_AllocateTensor; import static org.tensorflow.internal.c_api.global.tensorflow.TF_DeleteTensor; import static org.tensorflow.internal.c_api.global.tensorflow.TF_NewTensor; +import static org.tensorflow.internal.c_api.global.tensorflow.TF_TensorByteSize; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.annotation.Properties; @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) public abstract class AbstractTF_Tensor extends Pointer { + protected static class DeleteDeallocator extends TF_Tensor implements Pointer.Deallocator { - DeleteDeallocator(TF_Tensor s) { super(s); } - @Override public void deallocate() { if (!isNull()) TF_DeleteTensor(this); setNull(); } + + DeleteDeallocator(TF_Tensor s, boolean registerMemory) { + super(s); + // some tensors, like those from eager operations, are just views + if (registerMemory) { + // ideally this would be TF_TensorElementCount and sizeof would be the datatype size, + // but datatype isn't stored anywhere and may be variably sized + s.capacity = TF_TensorByteSize(s); + } else { + s.capacity = 0; + } + } + + DeleteDeallocator(TF_Tensor s) { + this(s, true); + } + + @Override + public void deallocate() { + if (!isNull()) { + TF_DeleteTensor(this); + } + setNull(); + } } - /** TensorFlow crashes if we don't pass it a deallocator, so... */ + /** + * TensorFlow crashes if we don't pass it a deallocator, so... + */ protected static Deallocator_Pointer_long_Pointer dummyDeallocator = new Deallocator_Pointer_long_Pointer() { - @Override public void call(Pointer data, long len, Pointer arg) { } + @Override + public void call(Pointer data, long len, Pointer arg) { + } }.retainReference(); - /** A reference to prevent deallocation. */ + /** + * A reference to prevent deallocation. + */ protected Pointer pointer; - public AbstractTF_Tensor(Pointer p) { super(p); } + public AbstractTF_Tensor(Pointer p) { + super(p); + } + + @Override + public int sizeof() { + return 1; + } /** * Calls TF_NewTensor(), and registers a deallocator. + * * @return TF_Tensor created. Do not call TF_DeleteTensor() on it. */ public static TF_Tensor newTensor(int dtype, long[] dims, Pointer data) { @@ -66,9 +104,11 @@ public static TF_Tensor allocateTensor(int dtype, long[] dims, long length) { return t; } - /** Registers a deallocator and returns this. */ - public TF_Tensor withDeallocator() { - return (TF_Tensor)this.deallocator(new DeleteDeallocator((TF_Tensor)this)); + /** + * Registers a deallocator and returns this. + */ + public TF_Tensor withDeallocator(boolean isView) { + return (TF_Tensor) this.deallocator(new DeleteDeallocator((TF_Tensor) this, isView)); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java index 497ee5f2d46..efe6a06e80b 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/core/Constant.java @@ -20,6 +20,7 @@ import org.tensorflow.Operation; import org.tensorflow.Output; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.BooleanNdArray; import org.tensorflow.ndarray.ByteNdArray; import org.tensorflow.ndarray.DoubleNdArray; @@ -80,7 +81,8 @@ public final class Constant extends RawOp implements Operand */ @Endpoint public static Constant scalarOf(Scope scope, int data) { - try (TInt32 value = TInt32.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -89,13 +91,14 @@ public static Constant scalarOf(Scope scope, int data) { * Creates a rank-1 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant vectorOf(Scope scope, int[] data) { - try (TInt32 value = TInt32.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -119,14 +122,15 @@ public static Constant arrayOf(Scope scope, int... data) { * Creates a rank-2 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -135,14 +139,15 @@ public static Constant tensorOf(Scope scope, int[][] data) { * Creates a rank-3 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -151,14 +156,15 @@ public static Constant tensorOf(Scope scope, int[][][] data) { * Creates a rank-4 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -167,14 +173,15 @@ public static Constant tensorOf(Scope scope, int[][][][] data) { * Creates a rank-5 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -183,14 +190,15 @@ public static Constant tensorOf(Scope scope, int[][][][][] data) { * Creates a rank-6 constant of {@code int} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return an integer constant */ @Endpoint public static Constant tensorOf(Scope scope, int[][][][][][] data) { - try (TInt32 value = TInt32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -207,7 +215,8 @@ public static Constant tensorOf(Scope scope, IntNdArray data) { if (data instanceof TInt32) { return create(scope, (TInt32) data); } - try (TInt32 value = TInt32.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -223,7 +232,8 @@ public static Constant tensorOf(Scope scope, IntNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, IntDataBuffer data) { - try (TInt32 value = TInt32.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt32 value = TInt32.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -237,7 +247,8 @@ public static Constant tensorOf(Scope scope, Shape shape, IntDataBuffer */ @Endpoint public static Constant scalarOf(Scope scope, float data) { - try (TFloat32 value = TFloat32.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -246,13 +257,14 @@ public static Constant scalarOf(Scope scope, float data) { * Creates a rank-1 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant vectorOf(Scope scope, float[] data) { - try (TFloat32 value = TFloat32.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -276,14 +288,15 @@ public static Constant arrayOf(Scope scope, float... data) { * Creates a rank-2 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -292,14 +305,15 @@ public static Constant tensorOf(Scope scope, float[][] data) { * Creates a rank-3 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -308,14 +322,15 @@ public static Constant tensorOf(Scope scope, float[][][] data) { * Creates a rank-4 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -324,14 +339,15 @@ public static Constant tensorOf(Scope scope, float[][][][] data) { * Creates a rank-5 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -340,14 +356,15 @@ public static Constant tensorOf(Scope scope, float[][][][][] data) { * Creates a rank-6 constant of {@code float} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a float constant */ @Endpoint public static Constant tensorOf(Scope scope, float[][][][][][] data) { - try (TFloat32 value = TFloat32.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -364,7 +381,8 @@ public static Constant tensorOf(Scope scope, FloatNdArray data) { if (data instanceof TFloat32) { return create(scope, (TFloat32) data); } - try (TFloat32 value = TFloat32.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -380,7 +398,8 @@ public static Constant tensorOf(Scope scope, FloatNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, FloatDataBuffer data) { - try (TFloat32 value = TFloat32.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 value = TFloat32.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -394,7 +413,8 @@ public static Constant tensorOf(Scope scope, Shape shape, FloatDataBuf */ @Endpoint public static Constant scalarOf(Scope scope, double data) { - try (TFloat64 value = TFloat64.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -403,13 +423,14 @@ public static Constant scalarOf(Scope scope, double data) { * Creates a rank-1 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant vectorOf(Scope scope, double[] data) { - try (TFloat64 value = TFloat64.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -433,14 +454,15 @@ public static Constant arrayOf(Scope scope, double... data) { * Creates a rank-2 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -449,14 +471,15 @@ public static Constant tensorOf(Scope scope, double[][] data) { * Creates a rank-3 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -465,14 +488,15 @@ public static Constant tensorOf(Scope scope, double[][][] data) { * Creates a rank-4 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -481,14 +505,15 @@ public static Constant tensorOf(Scope scope, double[][][][] data) { * Creates a rank-5 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -497,14 +522,15 @@ public static Constant tensorOf(Scope scope, double[][][][][] data) { * Creates a rank-6 constant of {@code double} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a double constant */ @Endpoint public static Constant tensorOf(Scope scope, double[][][][][][] data) { - try (TFloat64 value = TFloat64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo( - data, t))) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo( + data, t))) { return create(scope, value); } } @@ -521,7 +547,8 @@ public static Constant tensorOf(Scope scope, DoubleNdArray data) { if (data instanceof TFloat64) { return create(scope, (TFloat64) data); } - try (TFloat64 value = TFloat64.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -537,7 +564,8 @@ public static Constant tensorOf(Scope scope, DoubleNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, DoubleDataBuffer data) { - try (TFloat64 value = TFloat64.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 value = TFloat64.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -551,7 +579,8 @@ public static Constant tensorOf(Scope scope, Shape shape, DoubleDataBu */ @Endpoint public static Constant scalarOf(Scope scope, long data) { - try (TInt64 value = TInt64.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -560,13 +589,14 @@ public static Constant scalarOf(Scope scope, long data) { * Creates a rank-1 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant vectorOf(Scope scope, long[] data) { - try (TInt64 value = TInt64.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -575,14 +605,15 @@ public static Constant vectorOf(Scope scope, long[] data) { * Creates a rank-2 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -606,14 +637,15 @@ public static Constant arrayOf(Scope scope, long... data) { * Creates a rank-3 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -622,14 +654,15 @@ public static Constant tensorOf(Scope scope, long[][][] data) { * Creates a rank-4 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -638,14 +671,15 @@ public static Constant tensorOf(Scope scope, long[][][][] data) { * Creates a rank-5 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -654,14 +688,15 @@ public static Constant tensorOf(Scope scope, long[][][][][] data) { * Creates a rank-6 constant of {@code long} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a long constant */ @Endpoint public static Constant tensorOf(Scope scope, long[][][][][][] data) { - try (TInt64 value = TInt64.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -678,7 +713,8 @@ public static Constant tensorOf(Scope scope, LongNdArray data) { if (data instanceof TInt64) { return create(scope, (TInt64) data); } - try (TInt64 value = TInt64.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -694,7 +730,8 @@ public static Constant tensorOf(Scope scope, LongNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, LongDataBuffer data) { - try (TInt64 value = TInt64.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TInt64 value = TInt64.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -708,7 +745,8 @@ public static Constant tensorOf(Scope scope, Shape shape, LongDataBuffer */ @Endpoint public static Constant scalarOf(Scope scope, boolean data) { - try (TBool value = TBool.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -717,13 +755,14 @@ public static Constant scalarOf(Scope scope, boolean data) { * Creates a rank-1 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant vectorOf(Scope scope, boolean[] data) { - try (TBool value = TBool.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -747,14 +786,15 @@ public static Constant arrayOf(Scope scope, boolean... data) { * Creates a rank-2 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -763,14 +803,15 @@ public static Constant tensorOf(Scope scope, boolean[][] data) { * Creates a rank-3 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -779,14 +820,15 @@ public static Constant tensorOf(Scope scope, boolean[][][] data) { * Creates a rank-4 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -795,14 +837,15 @@ public static Constant tensorOf(Scope scope, boolean[][][][] data) { * Creates a rank-5 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -811,14 +854,15 @@ public static Constant tensorOf(Scope scope, boolean[][][][][] data) { * Creates a rank-6 constant of {@code boolean} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a boolean constant */ @Endpoint public static Constant tensorOf(Scope scope, boolean[][][][][][] data) { - try (TBool value = TBool.tensorOf(StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, - t))) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, StdArrays.shapeOf(data), t -> StdArrays.copyTo(data, + t))) { return create(scope, value); } } @@ -835,7 +879,8 @@ public static Constant tensorOf(Scope scope, BooleanNdArray data) { if (data instanceof TBool) { return create(scope, (TBool) data); } - try (TBool value = TBool.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -851,7 +896,8 @@ public static Constant tensorOf(Scope scope, BooleanNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, BooleanDataBuffer data) { - try (TBool value = TBool.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TBool value = TBool.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -865,7 +911,8 @@ public static Constant tensorOf(Scope scope, Shape shape, BooleanDataBuff */ @Endpoint public static Constant scalarOf(Scope scope, byte data) { - try (TUint8 value = TUint8.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -874,13 +921,14 @@ public static Constant scalarOf(Scope scope, byte data) { * Creates a rank-1 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant vectorOf(Scope scope, byte[] data) { - try (TUint8 value = TUint8.vectorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.vectorOf(tensorScope, data)) { return create(scope, value); } } @@ -904,14 +952,15 @@ public static Constant arrayOf(Scope scope, byte... data) { * Creates a rank-2 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -920,14 +969,15 @@ public static Constant tensorOf(Scope scope, byte[][] data) { * Creates a rank-3 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -936,14 +986,15 @@ public static Constant tensorOf(Scope scope, byte[][][] data) { * Creates a rank-4 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -952,14 +1003,15 @@ public static Constant tensorOf(Scope scope, byte[][][][] data) { * Creates a rank-5 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -968,14 +1020,15 @@ public static Constant tensorOf(Scope scope, byte[][][][][] data) { * Creates a rank-6 constant of {@code byte} elements. * * @param scope is a scope used to add the underlying operation. - * @param data An array containing the values to put into the new constant. The dimensions of the - * new constant will match those of the array. + * @param data An array containing the values to put into the new constant. The dimensions of the new constant will + * match those of the array. * @return a byte constant */ @Endpoint public static Constant tensorOf(Scope scope, byte[][][][][][] data) { - try (TUint8 value = TUint8.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, - d))) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, + d))) { return create(scope, value); } } @@ -992,7 +1045,8 @@ public static Constant tensorOf(Scope scope, ByteNdArray data) { if (data instanceof TUint8) { return create(scope, (TUint8) data); } - try (TUint8 value = TUint8.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, data)) { return create(scope, value); } } @@ -1008,7 +1062,8 @@ public static Constant tensorOf(Scope scope, ByteNdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer data) { - try (TUint8 value = TUint8.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TUint8 value = TUint8.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -1022,13 +1077,13 @@ public static Constant tensorOf(Scope scope, Shape shape, ByteDataBuffer * @param shape the tensor shape. * @param data a buffer containing the tensor data. * @return a constant of type `type` - * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the - * buffer + * @throws IllegalArgumentException If the tensor datatype or shape is not compatible with the buffer */ @Endpoint public static Constant tensorOf(Scope scope, Class type, Shape shape, ByteDataBuffer data) { - try (T value = Tensor.of(type, shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + T value = Tensor.of(tensorScope, type, shape, data)) { return create(scope, value); } } @@ -1042,7 +1097,8 @@ public static Constant tensorOf(Scope scope, Class type, */ @Endpoint public static Constant scalarOf(Scope scope, String data) { - try (TString value = TString.scalarOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.scalarOf(tensorScope, data)) { return create(scope, value); } } @@ -1057,7 +1113,8 @@ public static Constant scalarOf(Scope scope, String data) { */ @Endpoint public static Constant scalarOf(Scope scope, Charset charset, String data) { - try (TString value = TString.tensorOf(charset, NdArrays.scalarOfObject(data))) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, charset, NdArrays.scalarOfObject(data))) { return create(scope, value); } } @@ -1071,7 +1128,8 @@ public static Constant scalarOf(Scope scope, Charset charset, String da */ public static Constant vectorOf(Scope scope, String[] data) { NdArray src = NdArrays.vectorOfObjects(data); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1081,13 +1139,14 @@ public static Constant vectorOf(Scope scope, String[] data) { * * @param scope is a scope used to add the underlying operation. * @param charset charset for encoding/decoding strings bytes. - * @param data An array containing the values to put into the new constant. String elements are - * sequences of bytes from the last array dimension. + * @param data An array containing the values to put into the new constant. String elements are sequences of bytes + * from the last array dimension. * @return the {@code String} constant */ @Endpoint public static Constant vectorOf(Scope scope, Charset charset, String[] data) { - try (TString value = TString.tensorOf(charset, NdArrays.vectorOfObjects(data))) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, charset, NdArrays.vectorOfObjects(data))) { return Constant.create(scope, value); } } @@ -1112,8 +1171,8 @@ public static Constant arrayOf(Scope scope, String... data) { * * @param scope is a scope used to add the underlying operation. * @param charset charset for encoding/decoding strings bytes. - * @param data An array containing the values to put into the new constant. String elements are - * sequences of bytes from the last array dimension. + * @param data An array containing the values to put into the new constant. String elements are sequences of bytes + * from the last array dimension. * @return the {@code String} constant */ @Endpoint(name = "array") @@ -1134,7 +1193,8 @@ public static Constant arrayOf(Scope scope, Charset charset, String... public static Constant tensorOf(Scope scope, String[][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1149,7 +1209,8 @@ public static Constant tensorOf(Scope scope, String[][] data) { public static Constant tensorOf(Scope scope, String[][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1164,7 +1225,8 @@ public static Constant tensorOf(Scope scope, String[][][] data) { public static Constant tensorOf(Scope scope, String[][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1179,7 +1241,8 @@ public static Constant tensorOf(Scope scope, String[][][][] data) { public static Constant tensorOf(Scope scope, String[][][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } @@ -1194,14 +1257,15 @@ public static Constant tensorOf(Scope scope, String[][][][][] data) { public static Constant tensorOf(Scope scope, String[][][][][][] data) { NdArray src = NdArrays.ofObjects(String.class, StdArrays.shapeOf(data)); StdArrays.copyTo(data, src); - try (TString value = TString.tensorOf(src)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, src)) { return create(scope, value); } } /** - * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, - * using the default UTF-8 encoding. + * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, using the default + * UTF-8 encoding. * * @param scope is a scope used to add the underlying operation. * @param data an n-dimensional array of {@code String} elements. @@ -1212,14 +1276,15 @@ public static Constant tensorOf(Scope scope, NdArray data) { if (data instanceof TString) { return create(scope, (TString) data); } - try (TString value = TString.tensorOf(data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, data)) { return create(scope, value); } } /** - * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, - * using the given encoding. + * Creates a constant of {@code String} elements that is a copy of a given n-dimensional array, using the given + * encoding. * * @param scope is a scope used to add the underlying operation. * @param charset charset used to encode/decode string bytes. @@ -1228,14 +1293,14 @@ public static Constant tensorOf(Scope scope, NdArray data) { */ @Endpoint public static Constant tensorOf(Scope scope, Charset charset, NdArray data) { - try (TString value = TString.tensorOf(charset, data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, charset, data)) { return create(scope, value); } } /** - * Create a {@link TString} constant with data from the given buffer, using the default UTF-8 - * encoding. + * Create a {@link TString} constant with data from the given buffer, using the default UTF-8 encoding. * * @param scope is a scope used to add the underlying operation. * @param shape the tensor shape. @@ -1245,7 +1310,8 @@ public static Constant tensorOf(Scope scope, Charset charset, NdArray tensorOf(Scope scope, Shape shape, DataBuffer data) { - try (TString value = TString.tensorOf(shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, shape, data)) { return create(scope, value); } } @@ -1263,14 +1329,14 @@ public static Constant tensorOf(Scope scope, Shape shape, DataBuffer tensorOf(Scope scope, Charset charset, Shape shape, DataBuffer data) { - try (TString value = TString.tensorOf(charset, shape, data)) { + try (TensorScope tensorScope = new TensorScope(); + TString value = TString.tensorOf(tensorScope, charset, shape, data)) { return create(scope, value); } } /** - * Creates a rank-1 constant of {@code long} elements representing the size of each dimensions of - * the given shape. + * Creates a rank-1 constant of {@code long} elements representing the size of each dimensions of the given shape. * * @param scope is a scope used to add the underlying operation. * @param shape a shape @@ -1294,36 +1360,38 @@ public static Constant tensorOf(Scope scope, Shape shape) { @SuppressWarnings("unchecked") @Endpoint public static Constant tensorOf(Scope scope, Class type, Number number) { - if (type.equals(TBfloat16.class)) { - try (TBfloat16 tensor = TBfloat16.scalarOf(number.floatValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TFloat64.class)) { - try (TFloat64 tensor = TFloat64.scalarOf(number.doubleValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TFloat32.class)) { - try (TFloat32 tensor = TFloat32.scalarOf(number.floatValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TFloat16.class)) { - try (TFloat16 tensor = TFloat16.scalarOf(number.floatValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TInt64.class)) { - try (TInt64 tensor = TInt64.scalarOf(number.longValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TInt32.class)) { - try (TInt32 tensor = TInt32.scalarOf(number.intValue())) { - return (Constant) create(scope, tensor); - } - } else if (type.equals(TUint8.class)) { - try (TUint8 tensor = TUint8.scalarOf(number.byteValue())) { - return (Constant) create(scope, tensor); + try (TensorScope tensorScope = new TensorScope()) { + if (type.equals(TBfloat16.class)) { + try (TBfloat16 tensor = TBfloat16.scalarOf(tensorScope, number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat64.class)) { + try (TFloat64 tensor = TFloat64.scalarOf(tensorScope, number.doubleValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat32.class)) { + try (TFloat32 tensor = TFloat32.scalarOf(tensorScope, number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TFloat16.class)) { + try (TFloat16 tensor = TFloat16.scalarOf(tensorScope, number.floatValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TInt64.class)) { + try (TInt64 tensor = TInt64.scalarOf(tensorScope, number.longValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TInt32.class)) { + try (TInt32 tensor = TInt32.scalarOf(tensorScope, number.intValue())) { + return (Constant) create(scope, tensor); + } + } else if (type.equals(TUint8.class)) { + try (TUint8 tensor = TUint8.scalarOf(tensorScope, number.byteValue())) { + return (Constant) create(scope, tensor); + } + } else { + throw new IllegalArgumentException("Tensor type " + type + " is an abstract or unknown numeric type."); } - } else { - throw new IllegalArgumentException("Tensor type " + type + " is an abstract or unknown numeric type."); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java index ef20b5ec2b6..5203c0892bf 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBfloat16.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TBfloat16Mapper; import org.tensorflow.ndarray.FloatNdArray; @@ -34,14 +35,13 @@ * Brain 16-bit float tensor type. * *

This type differs from {@link TFloat16} as it truncates the mantissa of a 32-bit float and - * preserve all exponent bits for faster conversion, while the latter shrink the exponent and have a - * longer mantissa for more precision. + * preserve all exponent bits for faster conversion, while the latter shrink the exponent and have a longer mantissa for + * more precision. * *

Since there is no floating-point type that fits in 16 bits in Java, a conversion (with - * potentially a precision loss) is required for each 32 bits value written or read on a tensor of - * this type from the JVM. Therefore, if a lot of I/O operations are to be expected on a tensor, - * performances will be improved by working with {@link TFloat32} or {@link TFloat64} data types - * whenever possible. + * potentially a precision loss) is required for each 32 bits value written or read on a tensor of this type from the + * JVM. Therefore, if a lot of I/O operations are to be expected on a tensor, performances will be improved by working + * with {@link TFloat32} or {@link TFloat64} data types whenever possible. * *

Note that some CPUs support the bfloat16 format natively, which can result in faster * computation compared to {@link TFloat16} when GPUs are not used. @@ -52,24 +52,26 @@ public interface TBfloat16 extends FloatNdArray, TFloating { /** * Allocates a new tensor for storing a single float value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value float to store in the new tensor * @return the new tensor */ - static TBfloat16 scalarOf(float value) { - return Tensor.of(TBfloat16.class, Shape.scalar(), data -> data.setFloat(value)); + static TBfloat16 scalarOf(TensorScope scope, float value) { + return Tensor.of(scope, TBfloat16.class, Shape.scalar(), data -> data.setFloat(value)); } /** * Allocates a new tensor for storing a vector of floats. * + * @param scope the {@link TensorScope} to create the tensor in * @param values floats to store in the new tensor * @return the new tensor */ - static TBfloat16 vectorOf(float... values) { + static TBfloat16 vectorOf(TensorScope scope, float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TBfloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TBfloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -77,44 +79,48 @@ static TBfloat16 vectorOf(float... values) { * *

The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TBfloat16 tensorOf(NdArray src) { - return Tensor.of(TBfloat16.class, src.shape(), src::copyTo); + static TBfloat16 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TBfloat16.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TBfloat16 tensorOf(Shape shape) { - return Tensor.of(TBfloat16.class, shape); + static TBfloat16 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TBfloat16.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static TBfloat16 tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(TBfloat16.class, shape, d -> d.write(data)); + static TBfloat16 tensorOf(TensorScope scope, Shape shape, FloatDataBuffer data) { + return Tensor.of(scope, TBfloat16.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TBfloat16 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TBfloat16.class, shape, dataInit); + static TBfloat16 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TBfloat16.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java index 0158c12b910..47179faf045 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TBool.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TBoolMapper; import org.tensorflow.ndarray.BooleanNdArray; @@ -35,8 +36,8 @@ * Boolean tensor type. * *

If direct memory mapping is not available in the JVM, tensors of this type might require an - * explicit mapping between Java boolean values and byte buffers using the {@link DataLayouts#BOOL - * BOOL} layout, which may impact I/O performances. + * explicit mapping between Java boolean values and byte buffers using the {@link DataLayouts#BOOL BOOL} layout, which + * may impact I/O performances. */ @TensorType(dataType = DataType.DT_BOOL, byteSize = 1, mapperClass = TBoolMapper.class) public interface TBool extends BooleanNdArray, TType { @@ -44,24 +45,26 @@ public interface TBool extends BooleanNdArray, TType { /** * Allocates a new tensor for storing a single boolean value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value boolean to store in the new tensor * @return the new tensor */ - static TBool scalarOf(boolean value) { - return Tensor.of(TBool.class, Shape.scalar(), data -> data.setBoolean(value)); + static TBool scalarOf(TensorScope scope, boolean value) { + return Tensor.of(scope, TBool.class, Shape.scalar(), data -> data.setBoolean(value)); } /** * Allocates a new tensor for storing a vector of booleans. * + * @param scope the {@link TensorScope} to create the tensor in * @param values booleans to store in the new tensor * @return the new tensor */ - static TBool vectorOf(boolean... values) { + static TBool vectorOf(TensorScope scope, boolean... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TBool.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TBool.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -69,43 +72,47 @@ static TBool vectorOf(boolean... values) { * *

The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TBool tensorOf(NdArray src) { - return Tensor.of(TBool.class, src.shape(), src::copyTo); + static TBool tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TBool.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TBool tensorOf(Shape shape) { - return Tensor.of(TBool.class, shape); + static TBool tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TBool.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of booleans to initialize the tensor with * @return the new tensor */ - static TBool tensorOf(Shape shape, BooleanDataBuffer data) { - return Tensor.of(TBool.class, shape, d -> d.write(data)); + static TBool tensorOf(TensorScope scope, Shape shape, BooleanDataBuffer data) { + return Tensor.of(scope, TBool.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TBool tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TBool.class, shape, dataInit); + static TBool tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TBool.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java index a43a0831f10..9e907d7e77c 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat16.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TFloat16Mapper; import org.tensorflow.ndarray.FloatNdArray; @@ -34,14 +35,13 @@ * IEEE-754 half-precision 16-bit float tensor type. * *

Since there is no floating-point type that fits in 16 bits in Java, a conversion (with - * potentially a precision loss) is required for each 32 bits value written or read on a tensor of - * this type from the JVM. Therefore, if a lot of I/O operations are to be expected on a tensor, - * performances will be improved by working with {@link TFloat32} or {@link TFloat64} data types - * whenever possible. + * potentially a precision loss) is required for each 32 bits value written or read on a tensor of this type from the + * JVM. Therefore, if a lot of I/O operations are to be expected on a tensor, performances will be improved by working + * with {@link TFloat32} or {@link TFloat64} data types whenever possible. * *

Also, {@code TFloat16} tensors normally perform better if they are located in GPU memory since - * most CPUs do not support this format natively. For CPU computation on 16-bit floats, the {@link - * TBfloat16} tensor type might be a better option. + * most CPUs do not support this format natively. For CPU computation on 16-bit floats, the {@link TBfloat16} tensor + * type might be a better option. */ @TensorType(dataType = DataType.DT_HALF, byteSize = 2, mapperClass = TFloat16Mapper.class) public interface TFloat16 extends FloatNdArray, TFloating { @@ -49,24 +49,26 @@ public interface TFloat16 extends FloatNdArray, TFloating { /** * Allocates a new tensor for storing a single float value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value float to store in the new tensor * @return the new tensor */ - static TFloat16 scalarOf(float value) { - return Tensor.of(TFloat16.class, Shape.scalar(), data -> data.setFloat(value)); + static TFloat16 scalarOf(TensorScope scope, float value) { + return Tensor.of(scope, TFloat16.class, Shape.scalar(), data -> data.setFloat(value)); } /** * Allocates a new tensor for storing a vector of floats. * + * @param scope the {@link TensorScope} to create the tensor in * @param values floats to store in the new tensor * @return the new tensor */ - static TFloat16 vectorOf(float... values) { + static TFloat16 vectorOf(TensorScope scope, float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TFloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TFloat16.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -74,43 +76,47 @@ static TFloat16 vectorOf(float... values) { * *

The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TFloat16 tensorOf(NdArray src) { - return Tensor.of(TFloat16.class, src.shape(), src::copyTo); + static TFloat16 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TFloat16.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TFloat16 tensorOf(Shape shape) { - return Tensor.of(TFloat16.class, shape); + static TFloat16 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TFloat16.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static TFloat16 tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(TFloat16.class, shape, d -> d.write(data)); + static TFloat16 tensorOf(TensorScope scope, Shape shape, FloatDataBuffer data) { + return Tensor.of(scope, TFloat16.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TFloat16 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TFloat16.class, shape, dataInit); + static TFloat16 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TFloat16.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java index 35208f7de43..0e8496fd0a1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat32.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TFloat32Mapper; import org.tensorflow.ndarray.FloatNdArray; @@ -30,31 +31,35 @@ import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TFloating; -/** IEEE-754 single-precision 32-bit float tensor type. */ +/** + * IEEE-754 single-precision 32-bit float tensor type. + */ @TensorType(dataType = DataType.DT_FLOAT, byteSize = 4, mapperClass = TFloat32Mapper.class) public interface TFloat32 extends FloatNdArray, TFloating { /** * Allocates a new tensor for storing a single float value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value float to store in the new tensor * @return the new tensor */ - static TFloat32 scalarOf(float value) { - return Tensor.of(TFloat32.class, Shape.scalar(), data -> data.setFloat(value)); + static TFloat32 scalarOf(TensorScope scope, float value) { + return Tensor.of(scope, TFloat32.class, Shape.scalar(), data -> data.setFloat(value)); } /** * Allocates a new tensor for storing a vector of floats. * + * @param scope the {@link TensorScope} to create the tensor in * @param values floats to store in the new tensor * @return the new tensor */ - static TFloat32 vectorOf(float... values) { + static TFloat32 vectorOf(TensorScope scope, float... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TFloat32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TFloat32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -62,43 +67,47 @@ static TFloat32 vectorOf(float... values) { * *

The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TFloat32 tensorOf(NdArray src) { - return Tensor.of(TFloat32.class, src.shape(), src::copyTo); + static TFloat32 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TFloat32.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TFloat32 tensorOf(Shape shape) { - return Tensor.of(TFloat32.class, shape); + static TFloat32 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TFloat32.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of floats to initialize the tensor with * @return the new tensor */ - static TFloat32 tensorOf(Shape shape, FloatDataBuffer data) { - return Tensor.of(TFloat32.class, shape, d -> d.write(data)); + static TFloat32 tensorOf(TensorScope scope, Shape shape, FloatDataBuffer data) { + return Tensor.of(scope, TFloat32.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TFloat32 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TFloat32.class, shape, dataInit); + static TFloat32 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TFloat32.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java index 957612691e5..2c2f6f95f78 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TFloat64.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TFloat64Mapper; import org.tensorflow.ndarray.DoubleNdArray; @@ -31,31 +32,35 @@ import org.tensorflow.types.family.TFloating; -/** IEEE-754 double-precision 64-bit float tensor type. */ +/** + * IEEE-754 double-precision 64-bit float tensor type. + */ @TensorType(dataType = DataType.DT_DOUBLE, byteSize = 8, mapperClass = TFloat64Mapper.class) public interface TFloat64 extends DoubleNdArray, TFloating { /** * Allocates a new tensor for storing a single double value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value double to store in the new tensor * @return the new tensor */ - static TFloat64 scalarOf(double value) { - return Tensor.of(TFloat64.class, Shape.scalar(), data -> data.setDouble(value)); + static TFloat64 scalarOf(TensorScope scope, double value) { + return Tensor.of(scope, TFloat64.class, Shape.scalar(), data -> data.setDouble(value)); } /** * Allocates a new tensor for storing a vector of doubles. * + * @param scope the {@link TensorScope} to create the tensor in * @param values doubles to store in the new tensor * @return the new tensor */ - static TFloat64 vectorOf(double... values) { + static TFloat64 vectorOf(TensorScope scope, double... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TFloat64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TFloat64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -63,43 +68,47 @@ static TFloat64 vectorOf(double... values) { * *

The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TFloat64 tensorOf(NdArray src) { - return Tensor.of(TFloat64.class, src.shape(), src::copyTo); + static TFloat64 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TFloat64.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TFloat64 tensorOf(Shape shape) { - return Tensor.of(TFloat64.class, shape); + static TFloat64 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TFloat64.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of doubles to initialize the tensor with * @return the new tensor */ - static TFloat64 tensorOf(Shape shape, DoubleDataBuffer data) { - return Tensor.of(TFloat64.class, shape, d -> d.write(data)); + static TFloat64 tensorOf(TensorScope scope, Shape shape, DoubleDataBuffer data) { + return Tensor.of(scope, TFloat64.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TFloat64 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TFloat64.class, shape, dataInit); + static TFloat64 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TFloat64.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java index 8f6b587795b..6b005b25630 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt32.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.internal.types.TInt32Mapper; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.NdArray; @@ -29,32 +30,36 @@ import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TIntegral; -/** 32-bit signed integer tensor type. */ +/** + * 32-bit signed integer tensor type. + */ @TensorType(dataType = DataType.DT_INT32, byteSize = 4, mapperClass = TInt32Mapper.class) public interface TInt32 extends IntNdArray, TIntegral { /** * Allocates a new tensor for storing a single int value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value int to store in the new tensor * @return the new tensor */ - static TInt32 scalarOf(int value) { - return Tensor.of(TInt32.class, Shape.scalar(), data -> data.setInt(value)); + static TInt32 scalarOf(TensorScope scope, int value) { + return Tensor.of(scope, TInt32.class, Shape.scalar(), data -> data.setInt(value)); } /** * Allocates a new tensor for storing a vector of ints. * + * @param scope the {@link TensorScope} to create the tensor in * @param values ints to store in the new tensor * @return the new tensor * @throws IllegalArgumentException if no values are provided */ - static TInt32 vectorOf(int... values) { + static TInt32 vectorOf(TensorScope scope, int... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TInt32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TInt32.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -62,43 +67,47 @@ static TInt32 vectorOf(int... values) { * *

The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TInt32 tensorOf(NdArray src) { - return Tensor.of(TInt32.class, src.shape(), src::copyTo); + static TInt32 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TInt32.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TInt32 tensorOf(Shape shape) { - return Tensor.of(TInt32.class, shape); + static TInt32 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TInt32.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of ints to initialize the tensor with * @return the new tensor */ - static TInt32 tensorOf(Shape shape, IntDataBuffer data) { - return Tensor.of(TInt32.class, shape, d -> d.write(data)); + static TInt32 tensorOf(TensorScope scope, Shape shape, IntDataBuffer data) { + return Tensor.of(scope, TInt32.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor */ - static TInt32 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TInt32.class, shape, dataInit); + static TInt32 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TInt32.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java index 867248c5392..05a4434df65 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TInt64.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TInt64Mapper; import org.tensorflow.ndarray.LongNdArray; @@ -30,31 +31,35 @@ import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TIntegral; -/** 64-bit signed integer tensor type. */ +/** + * 64-bit signed integer tensor type. + */ @TensorType(dataType = DataType.DT_INT64, byteSize = 8, mapperClass = TInt64Mapper.class) public interface TInt64 extends LongNdArray, TIntegral { /** * Allocates a new tensor for storing a single long value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value long to store in the new tensor * @return the new tensor */ - static TInt64 scalarOf(long value) { - return Tensor.of(TInt64.class, Shape.scalar(), data -> data.setLong(value)); + static TInt64 scalarOf(TensorScope scope, long value) { + return Tensor.of(scope, TInt64.class, Shape.scalar(), data -> data.setLong(value)); } /** * Allocates a new tensor for storing a vector of longs. * + * @param scope the {@link TensorScope} to create the tensor in * @param values longs to store in the new tensor * @return the new tensor */ - static TInt64 vectorOf(long... values) { + static TInt64 vectorOf(TensorScope scope, long... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TInt64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TInt64.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -62,43 +67,47 @@ static TInt64 vectorOf(long... values) { * *

The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TInt64 tensorOf(NdArray src) { - return Tensor.of(TInt64.class, src.shape(), src::copyTo); + static TInt64 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TInt64.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TInt64 tensorOf(Shape shape) { - return Tensor.of(TInt64.class, shape); + static TInt64 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TInt64.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of longs to initialize the tensor with * @return the new tensor */ - static TInt64 tensorOf(Shape shape, LongDataBuffer data) { - return Tensor.of(TInt64.class, shape, d -> d.write(data)); + static TInt64 tensorOf(TensorScope scope, Shape shape, LongDataBuffer data) { + return Tensor.of(scope, TInt64.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TInt64 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TInt64.class, shape, dataInit); + static TInt64 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TInt64.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java index b3000cc2f8a..7cb098257a1 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TString.java @@ -21,6 +21,7 @@ import java.nio.charset.StandardCharsets; import java.util.function.Function; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.internal.types.TStringInitializer; import org.tensorflow.internal.types.TStringMapper; import org.tensorflow.ndarray.NdArray; @@ -37,8 +38,8 @@ *

This type can be used to store any arbitrary byte sequence of variable length. * *

Since the size of a tensor is fixed, creating a tensor of this type requires to provide all of - * its values initially, so TensorFlow can compute and allocate the right amount of memory. Then the - * data in the tensor is initialized once and cannot be modified afterwards. + * its values initially, so TensorFlow can compute and allocate the right amount of memory. Then the data in the tensor + * is initialized once and cannot be modified afterwards. */ @TensorType(dataType = DataType.DT_STRING, byteSize = -1, mapperClass = TStringMapper.class) public interface TString extends NdArray, TType { @@ -48,11 +49,12 @@ public interface TString extends NdArray, TType { * *

The string is encoded into bytes using the UTF-8 charset. * + * @param scope the {@link TensorScope} to create the tensor in * @param value scalar value to store in the new tensor * @return the new tensor */ - static TString scalarOf(String value) { - return tensorOf(NdArrays.scalarOfObject(value)); + static TString scalarOf(TensorScope scope, String value) { + return tensorOf(scope, NdArrays.scalarOfObject(value)); } /** @@ -60,14 +62,15 @@ static TString scalarOf(String value) { * *

The strings are encoded into bytes using the UTF-8 charset. * + * @param scope the {@link TensorScope} to create the tensor in * @param values values to store in the new tensor * @return the new tensor */ - static TString vectorOf(String... values) { + static TString vectorOf(TensorScope scope, String... values) { if (values == null) { throw new IllegalArgumentException(); } - return tensorOf(NdArrays.vectorOfObjects(values)); + return tensorOf(scope, NdArrays.vectorOfObjects(values)); } /** @@ -76,11 +79,12 @@ static TString vectorOf(String... values) { *

The tensor will have the same shape as the source array and its data will be copied. The * strings are encoded into bytes using the UTF-8 charset. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TString tensorOf(NdArray src) { - return tensorOf(StandardCharsets.UTF_8, src); + static TString tensorOf(TensorScope scope, NdArray src) { + return tensorOf(scope, StandardCharsets.UTF_8, src); } /** @@ -100,13 +104,14 @@ static TString tensorOf(NdArray src) { * assertEquals(originalStrings.getObject(0), tensorStrings.getObject(0)); * }

* + * @param scope the {@link TensorScope} to create the tensor in * @param charset charset to use for encoding the strings into bytes * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TString tensorOf(Charset charset, NdArray src) { + static TString tensorOf(TensorScope scope, Charset charset, NdArray src) { TStringInitializer initializer = new TStringInitializer<>(src, s -> s.getBytes(charset)); - return Tensor.of(TString.class, src.shape(), initializer.computeRequiredSize(), initializer); + return Tensor.of(scope, TString.class, src.shape(), initializer.computeRequiredSize(), initializer); } /** @@ -115,12 +120,13 @@ static TString tensorOf(Charset charset, NdArray src) { *

The data will be copied from the provided buffer to the tensor after it is allocated. The * strings are encoded into bytes using the UTF-8 charset. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor * @param data buffer of strings to initialize the tensor with * @return the new tensor */ - static TString tensorOf(Shape shape, DataBuffer data) { - return tensorOf(NdArrays.wrap(shape, data)); + static TString tensorOf(TensorScope scope, Shape shape, DataBuffer data) { + return tensorOf(scope, NdArrays.wrap(shape, data)); } /** @@ -141,13 +147,14 @@ static TString tensorOf(Shape shape, DataBuffer data) { * assertEquals(originalStrings.getObject(0), tensorStrings.getObject(0)); * }

* + * @param scope the {@link TensorScope} to create the tensor in * @param charset charset to use for encoding the strings into bytes * @param shape shape of the tensor * @param data buffer of strings to initialize the tensor with * @return the new tensor */ - static TString tensorOf(Charset charset, Shape shape, DataBuffer data) { - return tensorOf(charset, NdArrays.wrap(shape, data)); + static TString tensorOf(TensorScope scope, Charset charset, Shape shape, DataBuffer data) { + return tensorOf(scope, charset, NdArrays.wrap(shape, data)); } /** @@ -162,12 +169,13 @@ static TString tensorOf(Charset charset, Shape shape, DataBuffer data) { * byte[] bytes = tensor.data().asBytes().getObject(0); // returns first sequence of bytes in the tensor * }
* + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TString tensorOfBytes(NdArray src) { + static TString tensorOfBytes(TensorScope scope, NdArray src) { TStringInitializer initializer = new TStringInitializer<>(src, Function.identity()); - return Tensor.of(TString.class, src.shape(), initializer.computeRequiredSize(), initializer); + return Tensor.of(scope, TString.class, src.shape(), initializer.computeRequiredSize(), initializer); } /** @@ -182,12 +190,13 @@ static TString tensorOfBytes(NdArray src) { * byte[] bytes = tensor.data().asBytes().getObject(0); // returns first sequence of bytes in the tensor * }
* + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to create * @param data the source array giving the shape and data to the new tensor * @return the new tensor */ - static TString tensorOfBytes(Shape shape, DataBuffer data) { - return tensorOfBytes(NdArrays.wrap(shape, data)); + static TString tensorOfBytes(TensorScope scope, Shape shape, DataBuffer data) { + return tensorOfBytes(scope, NdArrays.wrap(shape, data)); } /** @@ -208,6 +217,8 @@ static TString tensorOfBytes(Shape shape, DataBuffer data) { */ TString using(Charset charset); - /** @return the tensor data as a n-dimensional array of raw byte sequences. */ + /** + * @return the tensor data as a n-dimensional array of raw byte sequences. + */ NdArray asBytes(); } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java index eae86414cb4..8744b8016a6 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/TUint8.java @@ -19,6 +19,7 @@ import java.util.function.Consumer; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.internal.types.TUint8Mapper; import org.tensorflow.ndarray.ByteNdArray; @@ -30,31 +31,35 @@ import org.tensorflow.types.annotation.TensorType; import org.tensorflow.types.family.TIntegral; -/** 8-bit unsigned integer tensor type. */ +/** + * 8-bit unsigned integer tensor type. + */ @TensorType(dataType = DataType.DT_UINT8, byteSize = 1, mapperClass = TUint8Mapper.class) public interface TUint8 extends ByteNdArray, TIntegral { /** * Allocates a new tensor for storing a single byte value. * + * @param scope the {@link TensorScope} to create the tensor in * @param value byte to store in the new tensor * @return the new tensor */ - static TUint8 scalarOf(byte value) { - return Tensor.of(TUint8.class, Shape.scalar(), data -> data.setByte(value)); + static TUint8 scalarOf(TensorScope scope, byte value) { + return Tensor.of(scope, TUint8.class, Shape.scalar(), data -> data.setByte(value)); } /** * Allocates a new tensor for storing a vector of bytes. * + * @param scope the {@link TensorScope} to create the tensor in * @param values bytes to store in the new tensor * @return the new tensor */ - static TUint8 vectorOf(byte... values) { + static TUint8 vectorOf(TensorScope scope, byte... values) { if (values == null) { throw new IllegalArgumentException(); } - return Tensor.of(TUint8.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); + return Tensor.of(scope, TUint8.class, Shape.of(values.length), data -> StdArrays.copyTo(values, data)); } /** @@ -62,43 +67,47 @@ static TUint8 vectorOf(byte... values) { * *

The tensor will have the same shape as the source array and its data will be copied. * + * @param scope the {@link TensorScope} to create the tensor in * @param src the source array giving the shape and data to the new tensor * @return the new tensor */ - static TUint8 tensorOf(NdArray src) { - return Tensor.of(TUint8.class, src.shape(), src::copyTo); + static TUint8 tensorOf(TensorScope scope, NdArray src) { + return Tensor.of(scope, TUint8.class, src.shape(), src::copyTo); } /** * Allocates a new tensor of the given shape. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @return the new tensor */ - static TUint8 tensorOf(Shape shape) { - return Tensor.of(TUint8.class, shape); + static TUint8 tensorOf(TensorScope scope, Shape shape) { + return Tensor.of(scope, TUint8.class, shape); } /** * Allocates a new tensor of the given shape, initialized with the provided data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param data buffer of bytes to initialize the tensor with * @return the new tensor */ - static TUint8 tensorOf(Shape shape, ByteDataBuffer data) { - return Tensor.of(TUint8.class, shape, d -> d.write(data)); + static TUint8 tensorOf(TensorScope scope, Shape shape, ByteDataBuffer data) { + return Tensor.of(scope, TUint8.class, shape, d -> d.write(data)); } /** * Allocates a new tensor of the given shape and initialize its data. * + * @param scope the {@link TensorScope} to create the tensor in * @param shape shape of the tensor to allocate * @param dataInit tensor data initializer * @return the new tensor * @throws TensorFlowException if the tensor cannot be allocated or initialized */ - static TUint8 tensorOf(Shape shape, Consumer dataInit) { - return Tensor.of(TUint8.class, shape, dataInit); + static TUint8 tensorOf(TensorScope scope, Shape shape, Consumer dataInit) { + return Tensor.of(scope, TUint8.class, shape, dataInit); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java index 2fc423b914e..afe10f685a8 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/types/family/TType.java @@ -27,10 +27,9 @@ * to a n-dimensional data space allowing direct I/O access from the JVM.

* *

Subinterfaces of {@code TType} are propagated as a generic parameter to various entities of - * TensorFlow to identify the type of the tensor they carry. For example, a - * {@link org.tensorflow.Operand Operand} is an operand which outputs a 32-bit floating - * point tensor. This parameter ensure type-compatibility between operands of a computation at - * compile-time. For example: + * TensorFlow to identify the type of the tensor they carry. For example, a {@link org.tensorflow.Operand + * Operand} is an operand which outputs a 32-bit floating point tensor. This parameter ensure + * type-compatibility between operands of a computation at compile-time. For example: * *

{@code
  * Ops tf = Ops.create();
@@ -44,8 +43,8 @@
  * }
* *

Even if all typed tensors implements somehow {@link org.tensorflow.ndarray.NdArray NdArray} - * to provide access to their data, {@code TType} deliberately does not extend directly from this - * interface, for the following reasons: + * to provide access to their data, {@code TType} deliberately does not extend directly from this interface, for the + * following reasons: *

    *
  • Implementing {@code NdArray} at this level could only expose boxed-type accessors, which * are less performant than their primitive equivalent, only exposed by subinterfaces of @@ -80,4 +79,14 @@ default long numBytes() { default void close() { asRawTensor().close(); } + + @Override + default boolean isClosed() { + return asRawTensor().isClosed(); + } + + @Override + default boolean isAttached() { + return asRawTensor().isAttached(); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java deleted file mode 100644 index 330a40bae6b..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/AutoCloseableList.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.tensorflow; - -import java.util.ArrayList; -import java.util.Collection; - -public final class AutoCloseableList extends ArrayList - implements AutoCloseable { - - public AutoCloseableList(Collection c) { - super(c); - } - - @Override - public void close() { - Exception toThrow = null; - for (AutoCloseable c : this) { - try { - c.close(); - } catch (Exception e) { - toThrow = e; - } - } - if (toThrow != null) { - throw new RuntimeException(toThrow); - } - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java index b2b2c34e223..515a5faa067 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/ConcreteFunctionTest.java @@ -43,32 +43,30 @@ private static Signature minusTwo(Ops tf) { @Test public void createFunction() { try (ConcreteFunction f = ConcreteFunction.create(ConcreteFunctionTest::plusFive); - TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); + TensorScope scope = new TensorScope()) { + TFloat32 x = TFloat32.scalarOf(scope, 3.0f); + assertEquals(8.0f, ((TFloat32) f.call(scope, x)).getFloat()); } } @Test public void createFunctionFromGraph() { - try (Graph g = new Graph()) { - Signature signature = plusFive(Ops.create(g)); - try (ConcreteFunction f = ConcreteFunction.create(signature, g); - TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); - } + try (Graph g = new Graph(); + TensorScope scope = new TensorScope(); + ConcreteFunction f = ConcreteFunction.create(plusFive(Ops.create(g)), g)) { + TFloat32 x = TFloat32.scalarOf(scope, 3.0f); + assertEquals(8.0f, ((TFloat32) f.call(scope, x)).getFloat()); } } @Test public void createFunctionFromSession() { - try (Graph g = new Graph()) { - Signature signature = plusFive(Ops.create(g)); - try (Session s = new Session(g)) { - try (ConcreteFunction f = ConcreteFunction.create(signature, s); - TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(8.0f, ((TFloat32)f.call(x)).getFloat()); - } - } + try (Graph g = new Graph(); + Session s = new Session(g); + TensorScope scope = new TensorScope(); + ConcreteFunction f = ConcreteFunction.create(plusFive(Ops.create(g)), s)) { + TFloat32 x = TFloat32.scalarOf(scope, 3.0f); + assertEquals(8.0f, ((TFloat32) f.call(scope, x)).getFloat()); } } @@ -76,8 +74,9 @@ public void createFunctionFromSession() { public void chainFunctions() { try (ConcreteFunction f1 = ConcreteFunction.create(ConcreteFunctionTest::plusFive); ConcreteFunction f2 = ConcreteFunction.create(ConcreteFunctionTest::minusTwo); - TFloat32 x = TFloat32.scalarOf(3.0f)) { - assertEquals(6.0f, ((TFloat32)f2.call(f1.call(x))).getFloat()); + TensorScope scope = new TensorScope()) { + TFloat32 x = TFloat32.scalarOf(scope, 3.0f); + assertEquals(6.0f, ((TFloat32) f2.call(scope, f1.call(scope, x))).getFloat()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java index e4340da3275..47e289efffd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/DeviceSpecTest.java @@ -14,6 +14,12 @@ ==============================================================================*/ package org.tensorflow; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; +import static org.tensorflow.DeviceSpec.DeviceType; + +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; @@ -21,200 +27,200 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.types.TInt32; -import static com.google.common.truth.Truth.assertThat; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.fail; -import static org.tensorflow.DeviceSpec.DeviceType; - -/** Tests for {@link DeviceSpec}. */ +/** + * Tests for {@link DeviceSpec}. + */ public class DeviceSpecTest { + @Test public void withDeviceMethod() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) - .build(); + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - Ops tf = Ops.create(g).withSubScope("testScope"); + try (Graph g = new Graph(); Session session = new Session(g, config)) { + Ops tf = Ops.create(g).withSubScope("testScope"); - Constant aOps = tf.constant(-1); + Constant aOps = tf.constant(-1); - DeviceSpec deviceSpec = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); + DeviceSpec deviceSpec = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) + .build(); - Output absOps = tf - .withName("absWithDevice") - .withDevice(deviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = tf + .withName("absWithDevice") + .withDevice(deviceSpec) + .math + .abs(aOps) + .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + List t = session.runner().fetch(absOps).run(scope); + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } @Test public void withEmptyDeviceSpec() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) - .build(); + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - Ops tf = Ops.create(g).withSubScope("testScope"); + try (Graph g = new Graph(); Session session = new Session(g, config)) { + Ops tf = Ops.create(g).withSubScope("testScope"); - Constant aOps = tf.constant(-1); + Constant aOps = tf.constant(-1); - DeviceSpec deviceSpec = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); + DeviceSpec deviceSpec = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) + .build(); - Output absOps = tf - .withName("absWithDevice") - .withDevice(deviceSpec) - .math - .abs(aOps) - .asOutput(); + Output absOps = tf + .withName("absWithDevice") + .withDevice(deviceSpec) + .math + .abs(aOps) + .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + List t = session.runner().fetch(absOps).run(scope); + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } @Test public void withTwoScopes() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); + + try (Graph g = new Graph(); Session session = new Session(g, config)) { + DeviceSpec deviceSpec1 = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) + .build(); + + DeviceSpec deviceSpec2 = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec deviceSpec1 = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); - - DeviceSpec deviceSpec2 = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); - - Ops tf1 = Ops.create(g).withSubScope("testScope1").withDevice(deviceSpec1); - Ops tf2 = Ops.create(g).withSubScope("testScope2").withDevice(deviceSpec2); - - Constant aOps = tf1.constant(-1); - Constant bOps = tf2.constant(10); - - Output absOps = tf1 - .withName("absWithDevice") - .math - .abs(aOps) - .asOutput(); - - Output mulOps = tf2 - .withName("mulWithDevice") - .math - .mul(absOps, bOps) - .asOutput(); - - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { - assertEquals(10, ((TInt32)t.get(0)).getInt()); + Ops tf1 = Ops.create(g).withSubScope("testScope1").withDevice(deviceSpec1); + Ops tf2 = Ops.create(g).withSubScope("testScope2").withDevice(deviceSpec2); + + Constant aOps = tf1.constant(-1); + Constant bOps = tf2.constant(10); + + Output absOps = tf1 + .withName("absWithDevice") + .math + .abs(aOps) + .asOutput(); + + Output mulOps = tf2 + .withName("mulWithDevice") + .math + .mul(absOps, bOps) + .asOutput(); + + List t = session.runner().fetch(mulOps).run(scope); + assertEquals(10, ((TInt32) t.get(0)).getInt()); } } } @Test public void withIncorrectDeviceSpec() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); + + try (Graph g = new Graph(); Session session = new Session(g, config)) { + DeviceSpec correctDeviceSpec = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec correctDeviceSpec = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); - - // Incorrect device spec, it will never be executed - DeviceSpec incorrectDeviceSpec = DeviceSpec.newBuilder() - .job("UNKNOWN") - .replica(1) - .task(1000) - .deviceType(DeviceType.TPU) - .build(); - - Ops tf = Ops.create(g); - - Constant aOps = tf.constant(-1); - Constant bOps = tf.constant(10); - - Output absOps = tf - .withName("absWithDevice") - .withDevice(incorrectDeviceSpec) - .math - .abs(aOps) - .asOutput(); - - Output mulOps = tf - .withName("mulWithDevice") - .withDevice(correctDeviceSpec) - .math - .mul(absOps, bOps) - .asOutput(); - - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(mulOps).run())) { - fail(); - } catch (TFInvalidArgumentException e) { - // ok + // Incorrect device spec, it will never be executed + DeviceSpec incorrectDeviceSpec = DeviceSpec.newBuilder() + .job("UNKNOWN") + .replica(1) + .task(1000) + .deviceType(DeviceType.TPU) + .build(); + + Ops tf = Ops.create(g); + + Constant aOps = tf.constant(-1); + Constant bOps = tf.constant(10); + + Output absOps = tf + .withName("absWithDevice") + .withDevice(incorrectDeviceSpec) + .math + .abs(aOps) + .asOutput(); + + Output mulOps = tf + .withName("mulWithDevice") + .withDevice(correctDeviceSpec) + .math + .mul(absOps, bOps) + .asOutput(); + + try { + List t = session.runner().fetch(mulOps).run(scope); + fail(); + } catch (TFInvalidArgumentException e) { + // ok + } } } } @Test public void withDeviceSpecInScope() { - ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) - .setLogDevicePlacement(true) + try (TensorScope scope = new TensorScope()) { + ConfigProto config = ConfigProto.newBuilder(ConfigProto.getDefaultInstance()) + .setLogDevicePlacement(true) + .build(); + + try (Graph g = new Graph(); Session session = new Session(g, config)) { + DeviceSpec deviceSpec = DeviceSpec.newBuilder() + .job("localhost") + .replica(0) + .task(0) + .deviceType(DeviceSpec.DeviceType.CPU) .build(); - try (Graph g = new Graph(); Session session = new Session(g, config)) { - DeviceSpec deviceSpec = DeviceSpec.newBuilder() - .job("localhost") - .replica(0) - .task(0) - .deviceType(DeviceSpec.DeviceType.CPU) - .build(); - - Ops tf = Ops.create(g).withSubScope("testScope").withDevice(deviceSpec); + Ops tf = Ops.create(g).withSubScope("testScope").withDevice(deviceSpec); - Constant aOps = tf.constant(-1); + Constant aOps = tf.constant(-1); - Output absOps = tf - .withName("absWithDevice") - .math - .abs(aOps) - .asOutput(); + Output absOps = tf + .withName("absWithDevice") + .math + .abs(aOps) + .asOutput(); - try (AutoCloseableList t = - new AutoCloseableList<>(session.runner().fetch(absOps).run())) { - assertEquals(1, ((TInt32)t.get(0)).getInt()); + List t = session.runner().fetch(absOps).run(scope); + assertEquals(1, ((TInt32) t.get(0)).getInt()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java index b39ecec9881..0b474ec795a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationBuilderTest.java @@ -23,7 +23,9 @@ import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TInt32; -/** Unit tests for {@link EagerOperationBuilder} class. */ +/** + * Unit tests for {@link EagerOperationBuilder} class. + */ public class EagerOperationBuilderTest { @Test @@ -59,7 +61,7 @@ public void addInputs() { Operation asrt = opBuilder(session, "Assert", "assert") .addInput(tf.constant(true).asOutput()) - .addInputList(new Output[] {tf.constant(-1).asOutput()}) + .addInputList(new Output[]{tf.constant(-1).asOutput()}) .build(); opBuilder(session, "Const", "var").addControlInput(asrt); } @@ -79,62 +81,63 @@ public void setDevice() { @Test public void setAttrs() { - // The effect of setting an attribute may not easily be visible from the other parts of this - // package's API. Thus, for now, the test simply executes the various setAttr variants to see - // that there are no exceptions. - // - // This is a bit of an awkward test since it has to find operations with attributes of specific - // types that aren't inferred from the input arguments. - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); - // dtype, tensor attributes. - try (TInt32 t = TInt32.scalarOf(1)) { + try (TensorScope scope = new TensorScope()) { + // The effect of setting an attribute may not easily be visible from the other parts of this + // package's API. Thus, for now, the test simply executes the various setAttr variants to see + // that there are no exceptions. + // + // This is a bit of an awkward test since it has to find operations with attributes of specific + // types that aren't inferred from the input arguments. + try (EagerSession session = EagerSession.create()) { + Ops tf = Ops.create(session); + // dtype, tensor attributes. + TInt32 t = TInt32.scalarOf(scope, 1); opBuilder(session, "Const", "DataTypeAndTensor") .setAttr("dtype", t.dataType()) .setAttr("value", t) .build(); + // type, int (TF "int" attributes are 64-bit signed, so a Java long). + opBuilder(session, "RandomUniform", "DataTypeAndInt") + .addInput(tf.array(1).asOutput()) + .setAttr("seed", 10) + .setAttr("dtype", DataType.DT_FLOAT) + .build(); + // list(int), string + opBuilder(session, "MaxPool", "IntListAndString") + .addInput(tf.constant(new float[2][2][2][2]).asOutput()) + .setAttr("ksize", new long[]{1, 1, 1, 1}) + .setAttr("strides", new long[]{1, 1, 1, 1}) + .setAttr("padding", "SAME") + .build(); + // list(float), device + opBuilder(session, "FractionalMaxPool", "FloatList") + .addInput(tf.constant(new float[2][2][2][2]).asOutput()) + .setAttr("pooling_ratio", new float[]{1.0f, 1.44f, 1.73f, 1.0f}) + .build(); + // shape + opBuilder(session, "EnsureShape", "ShapeAttr") + .addInput(tf.constant(new int[2][2]).asOutput()) + .setAttr("shape", Shape.of(2, 2)) + .build(); + // list(shape) + opBuilder(session, "FIFOQueue", "queue") + .setAttr("component_types", new DataType[]{DataType.DT_INT32, DataType.DT_INT32}) + .setAttr("shapes", new Shape[]{Shape.of(2, 2), Shape.of(2, 2, 2)}) + .build(); + // bool + opBuilder(session, "All", "Bool") + .addInput(tf.constant(new boolean[]{true, true, false}).asOutput()) + .addInput(tf.constant(0).asOutput()) + .setAttr("keep_dims", false) + .build(); + // float + opBuilder(session, "ApproximateEqual", "Float") + .addInput(tf.constant(10.00001f).asOutput()) + .addInput(tf.constant(10.00000f).asOutput()) + .setAttr("tolerance", 0.1f) + .build(); + // Missing tests: list(string), list(byte), list(bool), list(type) } - // type, int (TF "int" attributes are 64-bit signed, so a Java long). - opBuilder(session, "RandomUniform", "DataTypeAndInt") - .addInput(tf.array(1).asOutput()) - .setAttr("seed", 10) - .setAttr("dtype", DataType.DT_FLOAT) - .build(); - // list(int), string - opBuilder(session, "MaxPool", "IntListAndString") - .addInput(tf.constant(new float[2][2][2][2]).asOutput()) - .setAttr("ksize", new long[] {1, 1, 1, 1}) - .setAttr("strides", new long[] {1, 1, 1, 1}) - .setAttr("padding", "SAME") - .build(); - // list(float), device - opBuilder(session, "FractionalMaxPool", "FloatList") - .addInput(tf.constant(new float[2][2][2][2]).asOutput()) - .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) - .build(); - // shape - opBuilder(session, "EnsureShape", "ShapeAttr") - .addInput(tf.constant(new int[2][2]).asOutput()) - .setAttr("shape", Shape.of(2, 2)) - .build(); - // list(shape) - opBuilder(session, "FIFOQueue", "queue") - .setAttr("component_types", new DataType[] {DataType.DT_INT32, DataType.DT_INT32}) - .setAttr("shapes", new Shape[] {Shape.of(2, 2), Shape.of(2, 2, 2)}) - .build(); - // bool - opBuilder(session, "All", "Bool") - .addInput(tf.constant(new boolean[] {true, true, false}).asOutput()) - .addInput(tf.constant(0).asOutput()) - .setAttr("keep_dims", false) - .build(); - // float - opBuilder(session, "ApproximateEqual", "Float") - .addInput(tf.constant(10.00001f).asOutput()) - .addInput(tf.constant(10.00000f).asOutput()) - .setAttr("tolerance", 0.1f) - .build(); - // Missing tests: list(string), list(byte), list(bool), list(type) } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java index 38714b86599..941b2073474 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/EagerOperationTest.java @@ -35,7 +35,8 @@ public class EagerOperationTest { public void failToCreateIfSessionIsClosed() { EagerSession session = EagerSession.create(); session.close(); - try (TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { + try (TensorScope scope = new TensorScope()) { + TInt32 t = TInt32.tensorOf(scope, Shape.of(2, 3)); EagerOperation op = opBuilder(session, "Const", "OutputAttrs") .setAttr("dtype", t.dataType()) @@ -50,7 +51,8 @@ public void failToCreateIfSessionIsClosed() { @Test public void outputDataTypeAndShape() { try (EagerSession session = EagerSession.create(); - TInt32 t = TInt32.tensorOf(Shape.of(2, 3))) { + TensorScope scope = new TensorScope()) { + TInt32 t = TInt32.tensorOf(scope, Shape.of(2, 3)); EagerOperation op = opBuilder(session, "Const", "OutputAttrs") .setAttr("dtype", t.dataType()) @@ -64,14 +66,15 @@ public void outputDataTypeAndShape() { @Test public void outputTensor() { - try (EagerSession session = EagerSession.create()) { + try (EagerSession session = EagerSession.create(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(session); EagerOperation add = opBuilder(session, "Add", "CompareResult") .addInput(tf.constant(2).asOutput()) .addInput(tf.constant(4).asOutput()) .build(); - assertEquals(6, ((TInt32)add.tensor(0)).getInt()); + assertEquals(6, ((TInt32) add.tensor(scope, 0)).getInt()); // Validate that we retrieve the right shape and datatype from the tensor // that has been resolved @@ -154,7 +157,8 @@ public void opNotAccessibleIfSessionIsClosed() { @Test public void outputIndexOutOfBounds() { - try (EagerSession session = EagerSession.create()) { + try (EagerSession session = EagerSession.create(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(session); EagerOperation add = opBuilder(session, "Add", "OutOfRange") @@ -180,7 +184,7 @@ public void outputIndexOutOfBounds() { // expected } try { - add.tensor(1); + add.tensor(scope, 1); fail(); } catch (IndexOutOfBoundsException e) { // expected diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java index 33ae979ccbd..d19781d8fb1 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationBuilderTest.java @@ -27,13 +27,16 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.GraphOperationBuilder}. */ +/** + * Unit tests for {@link org.tensorflow.GraphOperationBuilder}. + */ public class GraphOperationBuilderTest { @Test public void failOnUseAfterBuild() { try (Graph g = new Graph(); - TInt32 t = TInt32.scalarOf(1)) { + TensorScope scope = new TensorScope()) { + TInt32 t = TInt32.scalarOf(scope, 1); OperationBuilder b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); b.build(); @@ -49,7 +52,8 @@ public void failOnUseAfterBuild() { public void failOnUseAfterGraphClose() { OperationBuilder b = null; try (Graph g = new Graph(); - TInt32 t = TInt32.scalarOf(1)) { + TensorScope scope = new TensorScope()) { + TInt32 t = TInt32.scalarOf(scope, 1); b = g.opBuilder("Const", "Const").setAttr("dtype", t.dataType()).setAttr("value", t); } try { @@ -68,17 +72,17 @@ public void setAttr() { // // This is a bit of an awkward test since it has to find operations with attributes of specific // types that aren't inferred from the input arguments. - try (Graph g = new Graph()) { + try (Graph g = new Graph(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); // dtype, tensor attributes. - try (TInt32 t = TInt32.scalarOf(1)) { - g.opBuilder("Const", "DataTypeAndTensor") - .setAttr("dtype", t.dataType()) - .setAttr("value", t) - .build() - .output(0); - assertTrue(hasNode(g, "DataTypeAndTensor")); - } + TInt32 t = TInt32.scalarOf(scope, 1); + g.opBuilder("Const", "DataTypeAndTensor") + .setAttr("dtype", t.dataType()) + .setAttr("value", t) + .build() + .output(0); + assertTrue(hasNode(g, "DataTypeAndTensor")); // string, bool attributes. g.opBuilder("Abort", "StringAndBool") .setAttr("error_msg", "SomeErrorMessage") @@ -95,15 +99,15 @@ public void setAttr() { // list(int) g.opBuilder("MaxPool", "IntList") .addInput(tf.constant(new float[2][2][2][2]).asOutput()) - .setAttr("ksize", new long[] {1, 1, 1, 1}) - .setAttr("strides", new long[] {1, 1, 1, 1}) + .setAttr("ksize", new long[]{1, 1, 1, 1}) + .setAttr("strides", new long[]{1, 1, 1, 1}) .setAttr("padding", "SAME") .build(); assertTrue(hasNode(g, "IntList")); // list(float) g.opBuilder("FractionalMaxPool", "FloatList") .addInput(tf.constant(new float[2][2][2][2]).asOutput()) - .setAttr("pooling_ratio", new float[] {1.0f, 1.44f, 1.73f, 1.0f}) + .setAttr("pooling_ratio", new float[]{1.0f, 1.44f, 1.73f, 1.0f}) .build(); assertTrue(hasNode(g, "FloatList")); // Missing tests: float, list(dtype), list(tensor), list(string), list(bool) @@ -138,10 +142,10 @@ public void setAttrShape() { @Test public void setAttrShapeList() { // Those shapes match tensors ones, so no exception is thrown - testSetAttrShapeList(new Shape[] {Shape.of(2, 2), Shape.of(2, 2, 2)}); + testSetAttrShapeList(new Shape[]{Shape.of(2, 2), Shape.of(2, 2, 2)}); try { // Those shapes do not match tensors ones, exception is thrown - testSetAttrShapeList(new Shape[] {Shape.of(2, 2), Shape.of(2, 2, 2, 2)}); + testSetAttrShapeList(new Shape[]{Shape.of(2, 2), Shape.of(2, 2, 2, 2)}); fail("Shapes are incompatible and an exception was expected"); } catch (TFInvalidArgumentException e) { // expected @@ -152,23 +156,24 @@ public void setAttrShapeList() { public void addControlInput() { try (Graph g = new Graph(); Session s = new Session(g); - TBool yes = TBool.scalarOf(true); - TBool no = TBool.scalarOf(false)) { + TensorScope scope = new TensorScope()) { + TBool yes = TBool.scalarOf(scope, true); + TBool no = TBool.scalarOf(scope, false); Ops tf = Ops.create(g); Output placeholder = tf.placeholder(TBool.class).asOutput(); GraphOperation check = g.opBuilder("Assert", "assert") .addInput(placeholder) - .addInputList(new Output[] {placeholder}) + .addInputList(new Output[]{placeholder}) .build(); Operation noop = g.opBuilder("NoOp", "noop").addControlInput(check).build(); // No problems when the Assert check succeeds - s.runner().feed(placeholder, yes).addTarget(noop).run(); + s.runner().feed(placeholder, yes).addTarget(noop).run(scope); // Exception thrown by the execution of the Assert node try { - s.runner().feed(placeholder, no).addTarget(noop).run(); + s.runner().feed(placeholder, no).addTarget(noop).run(scope); fail("Did not run control operation."); } catch (TFInvalidArgumentException e) { // expected @@ -178,26 +183,27 @@ public void addControlInput() { private static void testSetAttrShapeList(Shape[] shapes) { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); - int[][] matrix = new int[][] {{0, 0}, {0, 0}}; + int[][] matrix = new int[][]{{0, 0}, {0, 0}}; Output queue = g.opBuilder("FIFOQueue", "queue") - .setAttr("component_types", new DataType[] {DataType.DT_INT32, DataType.DT_INT32}) + .setAttr("component_types", new DataType[]{DataType.DT_INT32, DataType.DT_INT32}) .setAttr("shapes", shapes) .build() .output(0); assertTrue(hasNode(g, "queue")); Output c1 = tf.constant(matrix).asOutput(); - Output c2 = tf.constant(new int[][][] {matrix, matrix}).asOutput(); + Output c2 = tf.constant(new int[][][]{matrix, matrix}).asOutput(); Operation enqueue = g.opBuilder("QueueEnqueue", "enqueue") .addInput(queue) - .addInputList(new Output[] {c1, c2}) + .addInputList(new Output[]{c1, c2}) .build(); assertTrue(hasNode(g, "enqueue")); - s.runner().addTarget(enqueue).run(); + s.runner().addTarget(enqueue).run(scope); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java index b164c129745..7bc0e0953fd 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphOperationTest.java @@ -29,7 +29,9 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.GraphOperation}. */ +/** + * Unit tests for {@link org.tensorflow.GraphOperation}. + */ public class GraphOperationTest { @Test @@ -53,8 +55,8 @@ public void operationEquality() { GraphOperation op1; try (Graph g = new Graph()) { Ops tf = Ops.create(g); - op1 = (GraphOperation)tf.withName("op1").constant(1).op(); - GraphOperation op2 = (GraphOperation)tf.withName("op2").constant(2).op(); + op1 = (GraphOperation) tf.withName("op1").constant(1).op(); + GraphOperation op2 = (GraphOperation) tf.withName("op2").constant(2).op(); GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle()); GraphOperation op4 = g.operation("op1"); assertEquals(op1, op1); @@ -78,8 +80,8 @@ public void operationEquality() { public void operationCollection() { try (Graph g = new Graph()) { Ops tf = Ops.create(g); - GraphOperation op1 = (GraphOperation)tf.withName("op1").constant(1).op(); - GraphOperation op2 = (GraphOperation)tf.withName("op2").constant(2).op(); + GraphOperation op1 = (GraphOperation) tf.withName("op1").constant(1).op(); + GraphOperation op2 = (GraphOperation) tf.withName("op2").constant(2).op(); GraphOperation op3 = new GraphOperation(g, op1.getUnsafeNativeHandle()); GraphOperation op4 = g.operation("op1"); Set ops = new HashSet<>(); @@ -179,11 +181,12 @@ public void outputList() { @Test public void outputTensorNotSupported() { - try (Graph g = new Graph()) { + try (Graph g = new Graph(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Operation split = tf.split(tf.constant(0), tf.array(0, 1, 2), 3L).op(); try { - split.output(0).asTensor(); + split.output(0).asTensor(scope); fail(); } catch (IllegalStateException e) { } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java index d8ffc1a475b..838c6ea1279 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/GraphTest.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.HashSet; import java.util.Iterator; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TFInvalidArgumentException; import org.tensorflow.op.Ops; @@ -32,7 +33,9 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.Graph}. */ +/** + * Unit tests for {@link org.tensorflow.Graph}. + */ public class GraphTest { @Test @@ -138,7 +141,8 @@ public void failOnUseAfterClose() { @Test public void addGradientsToGraph() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x1 = tf.placeholder(TFloat32.class).output(); @@ -146,7 +150,7 @@ public void addGradientsToGraph() { Output y0 = tf.math.square(x1).y(); Output y1 = tf.math.square(y0).y(); Output y2 = tf.math.addN(Arrays.asList(y0, x2)).sum(); - + Output[] grads0 = g.addGradients(y1, toArray(x1)); assertNotNull(grads0); assertEquals(1, grads0.length); @@ -157,30 +161,29 @@ public void addGradientsToGraph() { assertEquals(2, grads1.length); assertEquals(DataType.DT_FLOAT, grads1[0].dataType()); assertEquals(DataType.DT_FLOAT, grads1[1].dataType()); - - try (TFloat32 c1 = TFloat32.scalarOf(3.0f); - TFloat32 c2 = TFloat32.scalarOf(2.0f); - AutoCloseableList outputs = new AutoCloseableList<>( - s.runner() - .feed(x1, c1) - .feed(x2, c2) - .fetch(grads0[0]) - .fetch(grads1[0]) - .fetch(grads1[1]) - .run())) { - - assertEquals(3, outputs.size()); - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - assertEquals(6.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); - assertEquals(1.0f, ((TFloat32)outputs.get(2)).getFloat(), 0.0f); - } + + TFloat32 c1 = TFloat32.scalarOf(scope, 3.0f); + TFloat32 c2 = TFloat32.scalarOf(scope, 2.0f); + List outputs = s.runner() + .feed(x1, c1) + .feed(x2, c2) + .fetch(grads0[0]) + .fetch(grads1[0]) + .fetch(grads1[1]) + .run(scope); + + assertEquals(3, outputs.size()); + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + assertEquals(6.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); + assertEquals(1.0f, ((TFloat32) outputs.get(2)).getFloat(), 0.0f); } } @Test public void addGradientSumsToGraph() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); @@ -192,27 +195,27 @@ public void addGradientSumsToGraph() { assertEquals(1, grad.length); assertEquals(DataType.DT_FLOAT, grad[0].dataType()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - TFloat32 output = (TFloat32)s.runner() - .feed(x, c) - .fetch(grad[0]) - .run() - .get(0)) { - assertEquals(114.0f, output.getFloat(), 0.0f); - } + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + TFloat32 output = (TFloat32) s.runner() + .feed(x, c) + .fetch(grad[0]) + .run(scope) + .get(0); + assertEquals(114.0f, output.getFloat(), 0.0f); } } @Test public void addGradientsWithInitialValuesToGraph() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); Output y0 = tf.math.square(x).y(); Output y1 = tf.math.square(y0).y(); - + Output[] grad0 = g.addGradients(y1, toArray(y0)); assertNotNull(grad0); assertEquals(1, grad0.length); @@ -223,14 +226,13 @@ public void addGradientsWithInitialValuesToGraph() { assertEquals(1, grad1.length); assertEquals(DataType.DT_FLOAT, grad1[0].dataType()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - TFloat32 output = (TFloat32)s.runner() - .feed(x, c) - .fetch(grad1[0]) - .run() - .get(0)) { - assertEquals(108.0f, output.getFloat(), 0.0f); - } + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + TFloat32 output = (TFloat32) s.runner() + .feed(x, c) + .fetch(grad1[0]) + .run(scope) + .get(0); + assertEquals(108.0f, output.getFloat(), 0.0f); } } @@ -265,7 +267,8 @@ public void validateGradientsNames() { @Test public void buildWhileLoopSingleInput() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output input = tf.placeholder(TInt32.class).output(); @@ -275,29 +278,29 @@ public void buildWhileLoopSingleInput() { toArray(input), (condGraph, condInputs, condOutputs) -> { Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); + condOutputs[0] = tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); }, (bodyGraph, bodyInputs, bodyOutputs) -> { Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); }, "test_loop"); - try (TInt32 c = TInt32.scalarOf(2); - TInt32 output = (TInt32)s.runner() - .feed(input, c) - .fetch(loopOutputs[0]) - .run() - .get(0)) { - assertEquals(16, output.getInt()); // ((2^2)^2) - } + TInt32 c = TInt32.scalarOf(scope, 2); + TInt32 output = (TInt32) s.runner() + .feed(input, c) + .fetch(loopOutputs[0]) + .run(scope) + .get(0); + assertEquals(16, output.getInt()); // ((2^2)^2) } } @Test public void buildWhileLoopMultipleInputs() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output input1 = tf.placeholder(TInt32.class).output(); @@ -309,29 +312,26 @@ public void buildWhileLoopMultipleInputs() { inputs, (condGraph, condInputs, condOutputs) -> { Ops tfc = Ops.create(condGraph); - condOutputs[0] = tfc.math.less((Output)condInputs[0], tfc.constant(16)).z(); + condOutputs[0] = tfc.math.less((Output) condInputs[0], tfc.constant(16)).z(); }, (bodyGraph, bodyInputs, bodyOutputs) -> { Ops tfb = Ops.create(bodyGraph); - bodyOutputs[0] = tfb.math.square((Output)bodyInputs[0]).y(); - bodyOutputs[1] = tfb.math.square((Output)bodyInputs[1]).y(); + bodyOutputs[0] = tfb.math.square((Output) bodyInputs[0]).y(); + bodyOutputs[1] = tfb.math.square((Output) bodyInputs[1]).y(); }, "test_loop"); - try (TInt32 c1 = TInt32.scalarOf(2); - TInt32 c2 = TInt32.scalarOf(5); - AutoCloseableList outputs = - new AutoCloseableList<>( - s.runner() - .feed(input1, c1) - .feed(input2, c2) - .fetch(loopOutputs[0]) - .fetch(loopOutputs[1]) - .run())) { - assertEquals(2, outputs.size()); - assertEquals(16, ((TInt32)outputs.get(0)).getInt()); // ((2^2)^2) - assertEquals(625, ((TInt32)outputs.get(1)).getInt()); // ((5^2)^2) - } + TInt32 c1 = TInt32.scalarOf(scope, 2); + TInt32 c2 = TInt32.scalarOf(scope, 5); + List outputs = s.runner() + .feed(input1, c1) + .feed(input2, c2) + .fetch(loopOutputs[0]) + .fetch(loopOutputs[1]) + .run(scope); + assertEquals(2, outputs.size()); + assertEquals(16, ((TInt32) outputs.get(0)).getInt()); // ((2^2)^2) + assertEquals(625, ((TInt32) outputs.get(1)).getInt()); // ((5^2)^2) } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java index 0d2d8af8b1c..7c365517aa3 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/RawTensorTest.java @@ -30,61 +30,66 @@ public class RawTensorTest { @Test public void rawToTypedTensor() { - RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), -1); - TFloat32 floatTensor = (TFloat32)rawTensor.asTypedTensor(); - assertSame(floatTensor.asRawTensor(), rawTensor); - try { - TInt32 intTensor = (TInt32)rawTensor.asTypedTensor(); - fail(); - } catch (ClassCastException e) { - // ok + try (TensorScope scope = new TensorScope()) { + RawTensor rawTensor = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), -1); + TFloat32 floatTensor = (TFloat32) rawTensor.asTypedTensor(); + assertSame(floatTensor.asRawTensor(), rawTensor); + try { + TInt32 intTensor = (TInt32) rawTensor.asTypedTensor(); + fail(); + } catch (ClassCastException e) { + // ok + } } } @Test public void allocateTensorWithSize() { - try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), 16)) { + try (TensorScope scope = new TensorScope()) { + RawTensor rawTensor = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), 16); assertEquals(16, rawTensor.numBytes()); - } - try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), 100)) { + rawTensor = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), 100); assertEquals(100, rawTensor.numBytes()); - } - try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), 10)) { - fail(); - } catch (IllegalArgumentException e) { - // ok - } - try (RawTensor rawTensor = RawTensor.allocate(TString.class, Shape.of(2, 2), 100)) { + try (RawTensor rawTensor2 = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), 10)) { + fail(); + } catch (IllegalArgumentException e) { + // ok + } + rawTensor = RawTensor.allocate(scope, TString.class, Shape.of(2, 2), 100); assertEquals(100, rawTensor.numBytes()); } } @Test public void allocateTensorWithoutSize() { - try (RawTensor rawTensor = RawTensor.allocate(TFloat32.class, Shape.of(2, 2), -1)) { - assertEquals(16, rawTensor.numBytes()); - // ok - } - try (RawTensor rawTensor = RawTensor.allocate(TString.class, Shape.of(2, 2), -1)) { - fail(); - } catch (IllegalArgumentException e) { - // ok + try (TensorScope scope = new TensorScope()) { + try (RawTensor rawTensor = RawTensor.allocate(scope, TFloat32.class, Shape.of(2, 2), -1)) { + assertEquals(16, rawTensor.numBytes()); + // ok + } + try (RawTensor rawTensor = RawTensor.allocate(scope, TString.class, Shape.of(2, 2), -1)) { + fail(); + } catch (IllegalArgumentException e) { + // ok + } } } @Test public void failToAllocateTensorFromUnknownShape() { - try { - RawTensor.allocate(TFloat32.class, Shape.of(3, -1, 3), -1); - fail(); - } catch (IllegalArgumentException e) { - // ok - } - try { - RawTensor.allocate(TString.class, Shape.unknown(), 100); - fail(); - } catch (IllegalArgumentException e) { - // ok + try (TensorScope scope = new TensorScope()) { + try { + RawTensor.allocate(scope, TFloat32.class, Shape.of(3, -1, 3), -1); + fail(); + } catch (IllegalArgumentException e) { + // ok + } + try { + RawTensor.allocate(scope, TString.class, Shape.unknown(), 100); + fail(); + } catch (IllegalArgumentException e) { + // ok + } } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java index cd8ac7e2ae4..726294547b5 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SavedModelBundleTest.java @@ -27,8 +27,8 @@ import java.nio.file.Path; import java.nio.file.Paths; import java.util.Collections; -import java.util.Map; import java.util.HashMap; +import java.util.Map; import org.junit.jupiter.api.Test; import org.tensorflow.exceptions.TensorFlowException; import org.tensorflow.ndarray.FloatNdArray; @@ -46,7 +46,9 @@ import org.tensorflow.proto.framework.TensorInfo; import org.tensorflow.types.TFloat32; -/** Unit tests for {@link org.tensorflow.SavedModelBundle}. */ +/** + * Unit tests for {@link org.tensorflow.SavedModelBundle}. + */ public class SavedModelBundleTest { private static final float EPSILON = 1e-7f; @@ -56,7 +58,8 @@ public class SavedModelBundleTest { static { try { SAVED_MODEL_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model").toURI()).toString(); - SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()).toString(); + SAVED_MODEL_PY_PATH = Paths.get(SavedModelBundleTest.class.getResource("/saved_model_using_python/model").toURI()) + .toString(); } catch (URISyntaxException e) { throw new RuntimeException(e); } @@ -102,15 +105,15 @@ public void exportFunctionWithVariables() throws IOException { float reducedSum; FloatNdArray xValue = StdArrays.ndCopyOf(new float[][]{{0, 1, 2}, {3, 4, 5}}); Shape xyShape = Shape.of(2, 3L); - try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape))) { + try (ConcreteFunction f = ConcreteFunction.create(tf -> buildGraphWithVariables(tf, xyShape)); + TensorScope scope = new TensorScope()) { // Init variable state by running the Init operation directly f.session().run(Init.DEFAULT_NAME); // Call the graph and remember the result of computation for later - try (TFloat32 xTensor = TFloat32.tensorOf(xValue); - TFloat32 zTensor = (TFloat32)f.call(xTensor)) { - reducedSum = zTensor.getFloat(); - } + TFloat32 xTensor = TFloat32.tensorOf(scope, xValue); + TFloat32 zTensor = (TFloat32) f.call(scope, xTensor); + reducedSum = zTensor.getFloat(); // Save/export the model (which is a single function in this case) f.save(testFolder.toString()); } @@ -121,7 +124,8 @@ public void exportFunctionWithVariables() throws IOException { // Reload the model just saved and validate its data try (SavedModelBundle savedModel = - SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG)) { + SavedModelBundle.load(testFolder.toString(), SavedModelBundle.DEFAULT_TAG); + TensorScope scope = new TensorScope()) { assertNotNull(savedModel.metaGraphDef()); assertNotNull(savedModel.metaGraphDef().getSaverDef()); assertEquals(1, savedModel.metaGraphDef().getSignatureDefCount()); @@ -153,17 +157,14 @@ public void exportFunctionWithVariables() throws IOException { assertNotNull(outputInfo); assertEquals(0, outputInfo.getTensorShape().getDimCount()); - try (TFloat32 xTensor = TFloat32.tensorOf(xValue)) { - // Call the saved model function and make sure it returns the same result as before - try (TFloat32 zTensor = (TFloat32)function.call(xTensor)) { - assertEquals(reducedSum, zTensor.getFloat(), EPSILON); - } - // Now call the same function directly from the model - try (TFloat32 zTensor = - (TFloat32)savedModel.call(Collections.singletonMap("input", xTensor)).get("reducedSum")) { - assertEquals(reducedSum, zTensor.getFloat(), EPSILON); - } - } + TFloat32 xTensor = TFloat32.tensorOf(scope, xValue); + // Call the saved model function and make sure it returns the same result as before + TFloat32 zTensor = (TFloat32) function.call(scope, xTensor); + assertEquals(reducedSum, zTensor.getFloat(), EPSILON); + // Now call the same function directly from the model + TFloat32 zTensor2 = + (TFloat32) savedModel.call(scope, Collections.singletonMap("input", xTensor)).get("reducedSum"); + assertEquals(reducedSum, zTensor2.getFloat(), EPSILON); } } @@ -177,32 +178,32 @@ public void exportMultipleFunctions() throws IOException { Signature f2Signature = buildIdentityGraph(tf, "identity"); try (Session s = new Session(g); ConcreteFunction f1 = ConcreteFunction.create(f1Signature, s); - ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s)) { + ConcreteFunction f2 = ConcreteFunction.create(f2Signature, s); + TensorScope scope = new TensorScope()) { f1.session().run(Init.DEFAULT_NAME); - try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - TFloat32 t = (TFloat32)f1.call(x)) { - reducedSum = t.getFloat(); - } + TFloat32 x = TFloat32.tensorOf(scope, StdArrays.ndCopyOf(new float[]{2, 2})); + TFloat32 t = (TFloat32) f1.call(scope, x); + reducedSum = t.getFloat(); SavedModelBundle.exporter(testFolder.toString()) .withFunction(f1) .withFunction(f2) .export(); } } - try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString())) { + try (SavedModelBundle model = SavedModelBundle.load(testFolder.toString()); + TensorScope scope = new TensorScope()) { assertEquals(2, model.signatures().size()); ConcreteFunction f1 = model.function(Signature.DEFAULT_KEY); assertNotNull(f1); - try (TFloat32 x = TFloat32.tensorOf(StdArrays.ndCopyOf(new float[]{2, 2})); - TFloat32 t = (TFloat32)f1.call(x)) { + try (TFloat32 x = TFloat32.tensorOf(scope, StdArrays.ndCopyOf(new float[]{2, 2})); + TFloat32 t = (TFloat32) f1.call(scope, x)) { assertEquals(reducedSum, t.getFloat(), EPSILON); } ConcreteFunction f2 = model.function("identity"); assertNotNull(f2); - try (TFloat32 x = TFloat32.scalarOf(10.0f); - TFloat32 t = (TFloat32)f2.call(x)) { - assertEquals(10.0f, t.getFloat(), 0.0f); - } + TFloat32 x = TFloat32.scalarOf(scope, 10.0f); + TFloat32 t = (TFloat32) f2.call(scope, x); + assertEquals(10.0f, t.getFloat(), 0.0f); try { model.function("NoSuchFunction"); fail(); @@ -284,23 +285,22 @@ public void cannotExportOrImportInvalidTags() { @Test public void pythonTfFunction() { // ConcreteFunctions on models saved using python - try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PY_PATH, "serve")) { + try (SavedModelBundle bundle = SavedModelBundle.load(SAVED_MODEL_PY_PATH, "serve"); + TensorScope scope = new TensorScope()) { /* * Test model was created in python * Signature name used for saving 'add', argument names 'a' and 'b' */ ConcreteFunction add = bundle.function("add"); Map args = new HashMap(); - try (TFloat32 a = TFloat32.scalarOf(10.0f); - TFloat32 b = TFloat32.scalarOf(15.5f)) { - args.put("a", a); - args.put("b", b); - Map result = add.call(args); - assertEquals(result.size(), 1); - try (TFloat32 c = (TFloat32)result.values().iterator().next()) { - assertEquals(25.5f, c.getFloat()); - } - } + TFloat32 a = TFloat32.scalarOf(scope, 10.0f); + TFloat32 b = TFloat32.scalarOf(scope, 15.5f); + args.put("a", a); + args.put("b", b); + Map result = add.call(scope, args); + assertEquals(result.size(), 1); + TFloat32 c = (TFloat32) result.values().iterator().next(); + assertEquals(25.5f, c.getFloat()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java index d7ea381d315..ca922cec6b8 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/SessionTest.java @@ -26,9 +26,13 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Path; +import java.nio.file.Paths; import java.util.Comparator; - +import java.util.List; import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Init; import org.tensorflow.op.core.Split; @@ -38,120 +42,114 @@ import org.tensorflow.proto.framework.ConfigProto; import org.tensorflow.proto.framework.GraphDef; import org.tensorflow.proto.framework.RunOptions; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.NdArrays; -import org.tensorflow.ndarray.StdArrays; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; -/** Unit tests for {@link org.tensorflow.Session}. */ +/** + * Unit tests for {@link org.tensorflow.Session}. + */ public class SessionTest { @Test public void runUsingOperationNames() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed("X", x).fetch("Y").run())) { - assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); - } + transpose_A_times_X(tf, new int[][]{{2}, {3}}); + TInt32 x = TInt32.tensorOf(scope, StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + List outputs = s.runner().feed("X", x).fetch("Y").run(scope); + assertEquals(1, outputs.size()); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } } @Test public void runUsingOperationHandles() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); + transpose_A_times_X(tf, new int[][]{{2}, {3}}); Output feed = g.operation("X").output(0); Output fetch = g.operation("Y").output(0); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}})); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().feed(feed, x).fetch(fetch).run())) { - assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); - } + TInt32 x = TInt32.tensorOf(scope, StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + List outputs = s.runner().feed(feed, x).fetch(fetch).run(scope); + assertEquals(1, outputs.size()); + assertEquals(31, ((TInt32) outputs.get(0)).getInt(0, 0)); } } @Test public void runUsingColonSeparatedNames() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Split split = tf.split(tf.constant(0), tf.array(1, 2, 3, 4), 2L); tf.math.add(split.output().get(0), split.output().get(1)); // Fetch using colon separated names. - try (TInt32 fetched = (TInt32)s.runner().fetch("Split:1").run().get(0)) { - assertEquals(3, fetched.getInt(0)); - assertEquals(4, fetched.getInt(1)); - } + TInt32 fetched = (TInt32) s.runner().fetch("Split:1").run(scope).get(0); + assertEquals(3, fetched.getInt(0)); + assertEquals(4, fetched.getInt(1)); // Feed using colon separated names. - try (TInt32 fed = TInt32.vectorOf(4, 3, 2, 1); - TInt32 fetched = (TInt32) s.runner() - .feed("Split:0", fed) - .feed("Split:1", fed) - .fetch("Add") - .run() - .get(0)) { - assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched); - } + TInt32 fed = TInt32.vectorOf(scope, 4, 3, 2, 1); + TInt32 fetched2 = (TInt32) s.runner() + .feed("Split:0", fed) + .feed("Split:1", fed) + .fetch("Add") + .run(scope) + .get(0); + assertEquals(NdArrays.vectorOf(8, 6, 4, 2), fetched2); } } @Test public void runWithMetadata() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); - transpose_A_times_X(tf, new int[][] {{2}, {3}}); - try (TInt32 x = TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{5}, {7}}))) { - Session.Run result = s.runner() - .feed("X", x) - .fetch("Y") - .setOptions(fullTraceRunOptions()) - .runAndFetchMetadata(); - // Sanity check on outputs. - AutoCloseableList outputs = new AutoCloseableList<>(result.outputs); - assertEquals(1, outputs.size()); - assertEquals(31, ((TInt32)outputs.get(0)).getInt(0, 0)); - // Sanity check on metadata - assertNotNull(result.metadata); - assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); - outputs.close(); - } + transpose_A_times_X(tf, new int[][]{{2}, {3}}); + TInt32 x = TInt32.tensorOf(scope, StdArrays.ndCopyOf(new int[][]{{5}, {7}})); + Session.Run result = s.runner() + .feed("X", x) + .fetch("Y") + .setOptions(fullTraceRunOptions()) + .runAndFetchMetadata(scope); + // Sanity check on outputs. + assertEquals(1, result.outputs.size()); + assertEquals(31, ((TInt32) result.outputs.get(0)).getInt(0, 0)); + // Sanity check on metadata + assertNotNull(result.metadata); + assertTrue(result.metadata.hasStepStats(), result.metadata.toString()); } } @Test public void runMultipleOutputs() { try (Graph g = new Graph(); - Session s = new Session(g)) { + Session s = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); tf.withName("c1").constant(2718); tf.withName("c2").constant(31415); - AutoCloseableList outputs = - new AutoCloseableList<>(s.runner().fetch("c2").fetch("c1").run()); + List outputs = s.runner().fetch("c2").fetch("c1").run(scope); assertEquals(2, outputs.size()); - assertEquals(31415, ((TInt32)outputs.get(0)).getInt()); - assertEquals(2718, ((TInt32)outputs.get(1)).getInt()); - outputs.close(); + assertEquals(31415, ((TInt32) outputs.get(0)).getInt()); + assertEquals(2718, ((TInt32) outputs.get(1)).getInt()); } } @Test public void failOnUseAfterClose() { - try (Graph g = new Graph()) { + try (Graph g = new Graph(); + TensorScope scope = new TensorScope()) { Session s = new Session(g); s.close(); try { - s.runner().run(); + s.runner().run(scope); fail("methods on a session should fail after close() is called"); } catch (IllegalStateException e) { // expected exception @@ -162,7 +160,8 @@ public void failOnUseAfterClose() { @Test public void createWithConfigProto() { try (Graph g = new Graph(); - Session s = new Session(g, singleThreadConfigProto())) {} + Session s = new Session(g, singleThreadConfigProto())) { + } } @Test @@ -175,12 +174,12 @@ public void runInit() { Variable var2 = tf.variable(tf.constant(20)); Add add = tf.math.add(var1, var2); - try (Session s = new Session(g)) { + try (Session s = new Session(g); + TensorScope scope = new TensorScope()) { s.run(tf.init()); - try (TInt32 t = (TInt32) s.runner().fetch(add).run().get(0)) { - assertEquals(30, t.getInt()); - } + TInt32 t = (TInt32) s.runner().fetch(add).run(scope).get(0); + assertEquals(30, t.getInt()); } } } @@ -196,12 +195,12 @@ public void runInitByName() { Add add = tf.math.add(var1, var2); tf.withName("init_test").init(); - try (Session s = new Session(g)) { + try (Session s = new Session(g); + TensorScope scope = new TensorScope()) { s.run("init_test"); - try (TInt32 t = (TInt32) s.runner().fetch(add).run().get(0)) { - assertEquals(30, t.getInt()); - } + TInt32 t = (TInt32) s.runner().fetch(add).run(scope).get(0); + assertEquals(30, t.getInt()); try { s.run("wrong_name"); fail(); @@ -213,9 +212,10 @@ public void runInitByName() { } @Test - public void saveAndRestore() throws IOException { + public void saveAndRestore() throws IOException { Path testFolder = Files.createTempDirectory("tf-session-save-restore-test"); - try (Graph g = new Graph()) { + try (Graph g = new Graph(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Variable x = tf.withName("x").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); Variable y = tf.withName("y").variable(tf.random.randomUniform(tf.constant(Shape.of(3, 3L)), TFloat32.class)); @@ -230,11 +230,10 @@ public void saveAndRestore() throws IOException { restoredGraph.importGraphDef(graphDef); try (Session restoredSession = new Session(restoredGraph)) { restoredSession.restore(testFolder.resolve("checkpoint").toString()); - try (AutoCloseableList oldList = new AutoCloseableList<>(s.runner().fetch("x").fetch("y").run()); - AutoCloseableList newList = new AutoCloseableList<>(restoredSession.runner().fetch("x").fetch("y").run())){ - assertEquals(oldList.get(0),newList.get(0)); - assertEquals(oldList.get(1),newList.get(1)); - } + List oldList = s.runner().fetch("x").fetch("y").run(scope); + List newList = restoredSession.runner().fetch("x").fetch("y").run(scope); + assertEquals(oldList.get(0),newList.get(0)); + assertEquals(oldList.get(1),newList.get(1)); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java new file mode 100644 index 00000000000..8c285392d7b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorScopeTest.java @@ -0,0 +1,76 @@ +/* + Copyright 2021 The TensorFlow Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ============================================================================== + */ +package org.tensorflow; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.TFloat32; + +/** + * Unit tests for {@link TensorScope} + */ +public class TensorScopeTest { + + private static TFloat32 makeTensor(TensorScope scope, long size) { + return TFloat32.tensorOf(scope, Shape.of(size), x -> { + for (long i = 0; i < size; i++) { + x.setFloat(0, i); + } + }); + } + + @Test + public void testBasicScope() { + TensorScope scope = new TensorScope(); + + TFloat32 tensor = makeTensor(scope, 10); + TFloat32 detachTensor = makeTensor(scope, 10); + detachTensor.detach(); + + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + assertFalse(detachTensor.isAttached()); + assertFalse(detachTensor.isClosed()); + + scope.close(); + + assertTrue(tensor.isClosed()); + assertTrue(scope.isClosed()); + assertFalse(detachTensor.isClosed()); + detachTensor.close(); + } + + @Test + public void testAttach() { + TensorScope firstScope = new TensorScope(); + TFloat32 tensor = makeTensor(firstScope, 10); + TensorScope secondScope = new TensorScope().withTensors(tensor); + + assertTrue(tensor.isAttached()); + assertFalse(tensor.isClosed()); + + secondScope.close(); + + assertTrue(tensor.isClosed()); + firstScope.close(); + } + +} diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java index 9415a986222..c6f26afaeda 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/TensorTest.java @@ -40,7 +40,6 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.DataBuffers; -import org.tensorflow.op.Ops; import org.tensorflow.proto.framework.DataType; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; @@ -50,69 +49,73 @@ import org.tensorflow.types.TString; import org.tensorflow.types.TUint8; -/** Unit tests for {@link org.tensorflow.Tensor}. */ +/** + * Unit tests for {@link org.tensorflow.Tensor}. + */ public class TensorTest { + private static final double EPSILON = 1e-7; private static final float EPSILON_F = 1e-7f; @Test public void createWithRawData() { - double[] doubles = {1d, 2d, 3d, 4d}; - Shape doubles_shape = Shape.of(4); - boolean[] bools = {true, false, true, false}; - byte[] bools_ = {1, 0, 1, 0}; - Shape bools_shape = Shape.of(4); - String strings = "test"; - Shape strings_shape = Shape.scalar(); - byte[] strings_; // raw TF_STRING - try (TString t = TString.tensorOf(NdArrays.scalarOfObject(strings))) { - strings_ = new byte[(int)t.numBytes()]; + try (TensorScope scope = new TensorScope()) { + double[] doubles = {1d, 2d, 3d, 4d}; + Shape doubles_shape = Shape.of(4); + boolean[] bools = {true, false, true, false}; + byte[] bools_ = {1, 0, 1, 0}; + Shape bools_shape = Shape.of(4); + String strings = "test"; + Shape strings_shape = Shape.scalar(); + byte[] strings_; // raw TF_STRING + TString t = TString.tensorOf(scope, NdArrays.scalarOfObject(strings)); + strings_ = new byte[(int) t.numBytes()]; t.asRawTensor().data().read(strings_); - } - // validate creating a tensor using a raw data byte buffers - { - try (TBool t = Tensor.of(TBool.class, bools_shape, DataBuffers.of(bools_))) { + // validate creating a tensor using a raw data byte buffers + { + TBool t1 = Tensor.of(scope, TBool.class, bools_shape, DataBuffers.of(bools_)); boolean[] actual = new boolean[bools_.length]; - t.read(DataBuffers.of(actual)); + t1.read(DataBuffers.of(actual)); assertArrayEquals(bools, actual); - } - // note: the buffer is expected to contain raw TF_STRING (as per C API) - try (TString t = Tensor.of(TString.class, strings_shape, DataBuffers.of(strings_))) { - assertEquals(strings, t.getObject()); + // note: the buffer is expected to contain raw TF_STRING (as per C API) + TString t2 = Tensor.of(scope, TString.class, strings_shape, DataBuffers.of(strings_)); + assertEquals(strings, t2.getObject()); } - } - // validate creating a tensor using a direct byte buffer (in host order) - { - DoubleBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder()) - .asDoubleBuffer().put(doubles); - try (TFloat64 t = TFloat64.tensorOf(doubles_shape, d -> d.write(DataBuffers.of(buf)))) { + // validate creating a tensor using a direct byte buffer (in host order) + { + DoubleBuffer buf = ByteBuffer.allocateDirect(8 * doubles.length).order(ByteOrder.nativeOrder()) + .asDoubleBuffer().put(doubles); + TFloat64 t1 = TFloat64.tensorOf(scope, doubles_shape, d -> d.write(DataBuffers.of(buf))); double[] actual = new double[doubles.length]; - t.read(DataBuffers.of(actual)); + t1.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); } - } - // validate shape checking - try (TBool t = Tensor.of(TBool.class, Shape.of(bools_.length * 2), DataBuffers.of(bools_))) { - fail("should have failed on incompatible buffer"); - } catch (IllegalArgumentException e) { - // expected + // validate shape checking + try { + TBool t1 = Tensor.of(scope, TBool.class, Shape.of(bools_.length * 2), DataBuffers.of(bools_)); + fail("should have failed on incompatible buffer"); + } catch (IllegalArgumentException e) { + // expected + } } + } @Test public void createFromBufferWithNativeByteOrder() { - double[] doubles = {1d, 2d, 3d, 4d}; - DoubleBuffer buf = - ByteBuffer.allocate(8 * doubles.length) - .order(ByteOrder.nativeOrder()) - .asDoubleBuffer() - .put(doubles); - flipBuffer(buf); - try (TFloat64 t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { + try (TensorScope scope = new TensorScope()) { + double[] doubles = {1d, 2d, 3d, 4d}; + DoubleBuffer buf = + ByteBuffer.allocate(8 * doubles.length) + .order(ByteOrder.nativeOrder()) + .asDoubleBuffer() + .put(doubles); + flipBuffer(buf); + TFloat64 t = TFloat64.tensorOf(scope, Shape.of(4), DataBuffers.of(buf)); double[] actual = new double[doubles.length]; t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); @@ -121,17 +124,18 @@ public void createFromBufferWithNativeByteOrder() { @Test public void createFromBufferWithNonNativeByteOrder() { - double[] doubles = {1d, 2d, 3d, 4d}; - DoubleBuffer buf = - ByteBuffer.allocate(8 * doubles.length) - .order( - ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN - ? ByteOrder.BIG_ENDIAN - : ByteOrder.LITTLE_ENDIAN) - .asDoubleBuffer() - .put(doubles); - flipBuffer(buf); - try (TFloat64 t = TFloat64.tensorOf(Shape.of(4), DataBuffers.of(buf))) { + try (TensorScope scope = new TensorScope()) { + double[] doubles = {1d, 2d, 3d, 4d}; + DoubleBuffer buf = + ByteBuffer.allocate(8 * doubles.length) + .order( + ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN + ? ByteOrder.BIG_ENDIAN + : ByteOrder.LITTLE_ENDIAN) + .asDoubleBuffer() + .put(doubles); + flipBuffer(buf); + TFloat64 t = TFloat64.tensorOf(scope, Shape.of(4), DataBuffers.of(buf)); double[] actual = new double[doubles.length]; t.read(DataBuffers.of(actual)); assertArrayEquals(doubles, actual, EPSILON); @@ -140,75 +144,78 @@ public void createFromBufferWithNonNativeByteOrder() { @Test public void createWithTypedBuffer() { - IntBuffer ints = IntBuffer.wrap(new int[]{1, 2, 3, 4}); - FloatBuffer floats = FloatBuffer.wrap(new float[]{1f, 2f, 3f, 4f}); - DoubleBuffer doubles = DoubleBuffer.wrap(new double[]{1d, 2d, 3d, 4d}); - LongBuffer longs = LongBuffer.wrap(new long[]{1L, 2L, 3L, 4L}); - - // validate creating a tensor using a typed buffer - { - Shape shape = Shape.of(4); - try (TFloat64 t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { - DoubleBuffer actual = DoubleBuffer.allocate(doubles.capacity()); - t.read(DataBuffers.of(actual)); - assertEquals(doubles, actual); - } - try (TFloat32 t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { - FloatBuffer actual = FloatBuffer.allocate(floats.capacity()); - t.read(DataBuffers.of(actual)); - assertEquals(floats, actual); - } - try (TInt32 t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { - IntBuffer actual = IntBuffer.allocate(ints.capacity()); - t.read(DataBuffers.of(actual)); - assertEquals(ints, actual); - } - try (TInt64 t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { - LongBuffer actual = LongBuffer.allocate(longs.capacity()); - t.read(DataBuffers.of(actual)); - assertEquals(longs, actual); - } - } + try (TensorScope scope = new TensorScope()) { + IntBuffer ints = IntBuffer.wrap(new int[]{1, 2, 3, 4}); + FloatBuffer floats = FloatBuffer.wrap(new float[]{1f, 2f, 3f, 4f}); + DoubleBuffer doubles = DoubleBuffer.wrap(new double[]{1d, 2d, 3d, 4d}); + LongBuffer longs = LongBuffer.wrap(new long[]{1L, 2L, 3L, 4L}); + + // validate creating a tensor using a typed buffer + { + Shape shape = Shape.of(4); + TFloat64 tDouble = TFloat64.tensorOf(scope, shape, DataBuffers.of(doubles)); + DoubleBuffer doubleActual = DoubleBuffer.allocate(doubles.capacity()); + tDouble.read(DataBuffers.of(doubleActual)); + assertEquals(doubles, doubleActual); + + TFloat32 tFloat = TFloat32.tensorOf(scope, shape, DataBuffers.of(floats)); + FloatBuffer floatActual = FloatBuffer.allocate(floats.capacity()); + tFloat.read(DataBuffers.of(floatActual)); + assertEquals(floats, floatActual); + + TInt32 tInt = TInt32.tensorOf(scope, shape, DataBuffers.of(ints)); + IntBuffer intActual = IntBuffer.allocate(ints.capacity()); + tInt.read(DataBuffers.of(intActual)); + assertEquals(ints, intActual); + + TInt64 tLong = TInt64.tensorOf(scope, shape, DataBuffers.of(longs)); + LongBuffer longActual = LongBuffer.allocate(longs.capacity()); + tLong.read(DataBuffers.of(longActual)); + assertEquals(longs, longActual); - // validate shape-checking - { - Shape shape = Shape.of(5); - try (TFloat64 t = TFloat64.tensorOf(shape, DataBuffers.of(doubles))) { - fail("should have failed on incompatible buffer"); - } catch (BufferUnderflowException e) { - // expected - } - try (TFloat32 t = TFloat32.tensorOf(shape, DataBuffers.of(floats))) { - fail("should have failed on incompatible buffer"); - } catch (BufferUnderflowException e) { - // expected - } - try (TInt32 t = TInt32.tensorOf(shape, DataBuffers.of(ints))) { - fail("should have failed on incompatible buffer"); - } catch (BufferUnderflowException e) { - // expected } - try (TInt64 t = TInt64.tensorOf(shape, DataBuffers.of(longs))) { - fail("should have failed on incompatible buffer"); - } catch (BufferUnderflowException e) { - // expected + + // validate shape-checking + { + Shape shape = Shape.of(5); + try (TFloat64 t = TFloat64.tensorOf(scope, shape, DataBuffers.of(doubles))) { + fail("should have failed on incompatible buffer"); + } catch (BufferUnderflowException e) { + // expected + } + try (TFloat32 t = TFloat32.tensorOf(scope, shape, DataBuffers.of(floats))) { + fail("should have failed on incompatible buffer"); + } catch (BufferUnderflowException e) { + // expected + } + try (TInt32 t = TInt32.tensorOf(scope, shape, DataBuffers.of(ints))) { + fail("should have failed on incompatible buffer"); + } catch (BufferUnderflowException e) { + // expected + } + try (TInt64 t = TInt64.tensorOf(scope, shape, DataBuffers.of(longs))) { + fail("should have failed on incompatible buffer"); + } catch (BufferUnderflowException e) { + // expected + } } } } @Test public void readFromRawData() { - int[] ints = {1, 2, 3}; - float[] floats = {1f, 2f, 3f}; - double[] doubles = {1d, 2d, 3d}; - long[] longs = {1L, 2L, 3L}; - boolean[] bools = {true, false, true}; - - try (TInt32 tints = TInt32.vectorOf(ints); - TFloat32 tfloats = TFloat32.vectorOf(floats); - TFloat64 tdoubles = TFloat64.vectorOf(doubles); - TInt64 tlongs = TInt64.vectorOf(longs); - TBool tbools = TBool.vectorOf(bools)) { + try (TensorScope scope = new TensorScope()) { + int[] ints = {1, 2, 3}; + float[] floats = {1f, 2f, 3f}; + double[] doubles = {1d, 2d, 3d}; + long[] longs = {1L, 2L, 3L}; + boolean[] bools = {true, false, true}; + + TInt32 tints = TInt32.vectorOf(scope, ints); + TFloat32 tfloats = TFloat32.vectorOf(scope, floats); + TFloat64 tdoubles = TFloat64.vectorOf(scope, doubles); + TInt64 tlongs = TInt64.vectorOf(scope, longs); + TBool tbools = TBool.vectorOf(scope, bools); // validate that any datatype is readable with ByteBuffer (content, position) { @@ -243,7 +250,7 @@ public void readFromRawData() { // validate the use of direct buffers { ByteBuffer bbuf = - ByteBuffer.allocateDirect((int)tdoubles.numBytes()).order(ByteOrder.nativeOrder()); + ByteBuffer.allocateDirect((int) tdoubles.numBytes()).order(ByteOrder.nativeOrder()); tdoubles.asRawTensor().data().copyTo(DataBuffers.of(bbuf), tdoubles.numBytes()); assertEquals(doubles[0], bbuf.asDoubleBuffer().get(0), EPSILON); } @@ -251,7 +258,7 @@ public void readFromRawData() { // validate byte order conversion { DoubleBuffer foreignBuf = - ByteBuffer.allocate((int)tdoubles.numBytes()) + ByteBuffer.allocate((int) tdoubles.numBytes()) .order( ByteOrder.nativeOrder() == ByteOrder.LITTLE_ENDIAN ? ByteOrder.BIG_ENDIAN @@ -267,142 +274,136 @@ public void readFromRawData() { @Test public void scalars() { - try (TFloat32 t = TFloat32.scalarOf(2.718f)) { - assertEquals(TFloat32.class, t.type()); - assertEquals(DataType.DT_FLOAT, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals(2.718f, t.getFloat(), EPSILON_F); - } - - try (TFloat64 t = TFloat64.scalarOf(3.1415)) { - assertEquals(TFloat64.class, t.type()); - assertEquals(DataType.DT_DOUBLE, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals(3.1415, t.getDouble(), EPSILON); - } - - try (TInt32 t = TInt32.scalarOf(-33)) { - assertEquals(TInt32.class, t.type()); - assertEquals(DataType.DT_INT32, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals(-33, t.getInt()); - } - - try (TInt64 t = TInt64.scalarOf(8589934592L)) { - assertEquals(TInt64.class, t.type()); - assertEquals(DataType.DT_INT64, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals(8589934592L, t.getLong()); - } - - try (TBool t = TBool.scalarOf(true)) { - assertEquals(TBool.class, t.type()); - assertEquals(DataType.DT_BOOL, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertTrue(t.getBoolean()); - } - - try (TString t = TString.scalarOf("sombrero")) { - assertEquals(TString.class, t.type()); - assertEquals(DataType.DT_STRING, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertEquals("sombrero", t.getObject()); - } - - final byte[] bytes = {1, 2, 3, 4}; - try (TString t = TString.tensorOfBytes(NdArrays.scalarOfObject(bytes))) { - assertEquals(TString.class, t.type()); - assertEquals(DataType.DT_STRING, t.dataType()); - assertEquals(0, t.shape().numDimensions()); - assertArrayEquals(bytes, t.asBytes().getObject()); + try (TensorScope scope = new TensorScope()) { + TFloat32 tFloat = TFloat32.scalarOf(scope, 2.718f); + assertEquals(TFloat32.class, tFloat.type()); + assertEquals(DataType.DT_FLOAT, tFloat.dataType()); + assertEquals(0, tFloat.shape().numDimensions()); + assertEquals(2.718f, tFloat.getFloat(), EPSILON_F); + + TFloat64 tDouble = TFloat64.scalarOf(scope, 3.1415); + assertEquals(TFloat64.class, tDouble.type()); + assertEquals(DataType.DT_DOUBLE, tDouble.dataType()); + assertEquals(0, tDouble.shape().numDimensions()); + assertEquals(3.1415, tDouble.getDouble(), EPSILON); + + TInt32 tInt = TInt32.scalarOf(scope, -33); + assertEquals(TInt32.class, tInt.type()); + assertEquals(DataType.DT_INT32, tInt.dataType()); + assertEquals(0, tInt.shape().numDimensions()); + assertEquals(-33, tInt.getInt()); + + TInt64 tLong = TInt64.scalarOf(scope, 8589934592L); + assertEquals(TInt64.class, tLong.type()); + assertEquals(DataType.DT_INT64, tLong.dataType()); + assertEquals(0, tLong.shape().numDimensions()); + assertEquals(8589934592L, tLong.getLong()); + + TBool tBool = TBool.scalarOf(scope, true); + assertEquals(TBool.class, tBool.type()); + assertEquals(DataType.DT_BOOL, tBool.dataType()); + assertEquals(0, tBool.shape().numDimensions()); + assertTrue(tBool.getBoolean()); + + TString tString = TString.scalarOf(scope, "sombrero"); + assertEquals(TString.class, tString.type()); + assertEquals(DataType.DT_STRING, tString.dataType()); + assertEquals(0, tString.shape().numDimensions()); + assertEquals("sombrero", tString.getObject()); + + final byte[] bytes = {1, 2, 3, 4}; + TString tByteString = TString.tensorOfBytes(scope, NdArrays.scalarOfObject(bytes)); + assertEquals(TString.class, tByteString.type()); + assertEquals(DataType.DT_STRING, tByteString.dataType()); + assertEquals(0, tByteString.shape().numDimensions()); + assertArrayEquals(bytes, tByteString.asBytes().getObject()); } } @Test public void nDimensional() { - DoubleNdArray vector = StdArrays.ndCopyOf(new double[]{1.414, 2.718, 3.1415}); - try (TFloat64 t = TFloat64.tensorOf(vector)) { - assertEquals(TFloat64.class, t.type()); - assertEquals(DataType.DT_DOUBLE, t.dataType()); - assertEquals(1, t.shape().numDimensions()); - assertEquals(3, t.shape().size(0)); - assertEquals(vector, t); - } - - IntNdArray matrix = StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {4, 5, 6}}); - try (TInt32 t = TInt32.tensorOf(matrix)) { - assertEquals(TInt32.class, t.type()); - assertEquals(DataType.DT_INT32, t.dataType()); - assertEquals(2, t.shape().numDimensions()); - assertEquals(2, t.shape().size(0)); - assertEquals(3, t.shape().size(1)); - assertEquals(matrix, t); - } - - LongNdArray threeD = StdArrays.ndCopyOf(new long[][][]{ - {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, - }); - try (TInt64 t = TInt64.tensorOf(threeD)) { - assertEquals(TInt64.class, t.type()); - assertEquals(DataType.DT_INT64, t.dataType()); - assertEquals(3, t.shape().numDimensions()); - assertEquals(2, t.shape().size(0)); - assertEquals(5, t.shape().size(1)); - assertEquals(1, t.shape().size(2)); - assertEquals(threeD, t); - } - - BooleanNdArray fourD = StdArrays.ndCopyOf(new boolean[][][][]{ - {{{false, false, false, true}, {false, false, true, false}}}, - {{{false, false, true, true}, {false, true, false, false}}}, - {{{false, true, false, true}, {false, true, true, false}}}, - }); - try (TBool t = TBool.tensorOf(fourD)) { - assertEquals(TBool.class, t.type()); - assertEquals(DataType.DT_BOOL, t.dataType()); - assertEquals(4, t.shape().numDimensions()); - assertEquals(3, t.shape().size(0)); - assertEquals(1, t.shape().size(1)); - assertEquals(2, t.shape().size(2)); - assertEquals(4, t.shape().size(3)); - assertEquals(fourD, t); + try (TensorScope scope = new TensorScope()) { + DoubleNdArray vector = StdArrays.ndCopyOf(new double[]{1.414, 2.718, 3.1415}); + TFloat64 tDouble = TFloat64.tensorOf(scope, vector); + assertEquals(TFloat64.class, tDouble.type()); + assertEquals(DataType.DT_DOUBLE, tDouble.dataType()); + assertEquals(1, tDouble.shape().numDimensions()); + assertEquals(3, tDouble.shape().size(0)); + assertEquals(vector, tDouble); + + IntNdArray matrix = StdArrays.ndCopyOf(new int[][]{{1, 2, 3}, {4, 5, 6}}); + TInt32 tInt = TInt32.tensorOf(scope, matrix); + assertEquals(TInt32.class, tInt.type()); + assertEquals(DataType.DT_INT32, tInt.dataType()); + assertEquals(2, tInt.shape().numDimensions()); + assertEquals(2, tInt.shape().size(0)); + assertEquals(3, tInt.shape().size(1)); + assertEquals(matrix, tInt); + + LongNdArray threeD = StdArrays.ndCopyOf(new long[][][]{ + {{1}, {3}, {5}, {7}, {9}}, {{2}, {4}, {6}, {8}, {0}}, + }); + TInt64 tLong = TInt64.tensorOf(scope, threeD); + assertEquals(TInt64.class, tLong.type()); + assertEquals(DataType.DT_INT64, tLong.dataType()); + assertEquals(3, tLong.shape().numDimensions()); + assertEquals(2, tLong.shape().size(0)); + assertEquals(5, tLong.shape().size(1)); + assertEquals(1, tLong.shape().size(2)); + assertEquals(threeD, tLong); + + BooleanNdArray fourD = StdArrays.ndCopyOf(new boolean[][][][]{ + {{{false, false, false, true}, {false, false, true, false}}}, + {{{false, false, true, true}, {false, true, false, false}}}, + {{{false, true, false, true}, {false, true, true, false}}}, + }); + TBool tBool = TBool.tensorOf(scope, fourD); + assertEquals(TBool.class, tBool.type()); + assertEquals(DataType.DT_BOOL, tBool.dataType()); + assertEquals(4, tBool.shape().numDimensions()); + assertEquals(3, tBool.shape().size(0)); + assertEquals(1, tBool.shape().size(1)); + assertEquals(2, tBool.shape().size(2)); + assertEquals(4, tBool.shape().size(3)); + assertEquals(fourD, tBool); } } @Test public void testNDimensionalStringTensor() { - NdArray matrix = NdArrays.ofObjects(String.class, Shape.of(4, 3)); - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < 3; ++j) { - matrix.setObject(String.format("(%d, %d) = %d", i, j, i << j), i, j); + try (TensorScope scope = new TensorScope()) { + NdArray matrix = NdArrays.ofObjects(String.class, Shape.of(4, 3)); + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 3; ++j) { + matrix.setObject(String.format("(%d, %d) = %d", i, j, i << j), i, j); + } } - } - try (TString t = TString.tensorOf(matrix)) { - assertEquals(TString.class, t.type()); - assertEquals(DataType.DT_STRING, t.dataType()); - assertEquals(2, t.shape().numDimensions()); - assertEquals(4, t.shape().size(0)); - assertEquals(3, t.shape().size(1)); - assertEquals(matrix, t); - } - - NdArray byteMatrix = NdArrays.ofObjects(byte[].class, matrix.shape()); - matrix.scalars().forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i)); - try (TString t = TString.tensorOfBytes(byteMatrix)) { - assertEquals(TString.class, t.type()); - assertEquals(DataType.DT_STRING, t.dataType()); - assertEquals(2, t.shape().numDimensions()); - assertEquals(4, t.shape().size(0)); - assertEquals(3, t.shape().size(1)); - assertEquals(byteMatrix, t.asBytes()); - assertEquals(matrix, t); + TString tString = TString.tensorOf(scope, matrix); + assertEquals(TString.class, tString.type()); + assertEquals(DataType.DT_STRING, tString.dataType()); + assertEquals(2, tString.shape().numDimensions()); + assertEquals(4, tString.shape().size(0)); + assertEquals(3, tString.shape().size(1)); + assertEquals(matrix, tString); + + NdArray byteMatrix = NdArrays.ofObjects(byte[].class, matrix.shape()); + matrix.scalars().forEachIndexed((i, s) -> byteMatrix.setObject(s.getObject().getBytes(UTF_8), i)); + TString tByteString = TString.tensorOfBytes(scope, byteMatrix); + assertEquals(TString.class, tByteString.type()); + assertEquals(DataType.DT_STRING, tByteString.dataType()); + assertEquals(2, tByteString.shape().numDimensions()); + assertEquals(4, tByteString.shape().size(0)); + assertEquals(3, tByteString.shape().size(1)); + assertEquals(byteMatrix, tByteString.asBytes()); + assertEquals(matrix, tByteString); } } @Test public void testUint8TensorFromArray() { - byte[] vector = new byte[] {1, 2, 3, 4}; - try (TUint8 t = TUint8.vectorOf(vector)) { + try (TensorScope scope = new TensorScope()) { + byte[] vector = new byte[]{1, 2, 3, 4}; + TUint8 t = TUint8.vectorOf(scope, vector); assertEquals(TUint8.class, t.type()); assertEquals(DataType.DT_UINT8, t.dataType()); assertEquals(1, t.shape().numDimensions()); @@ -416,8 +417,9 @@ public void testUint8TensorFromArray() { @Test public void testCreateFromArrayOfBoxed() { - Integer[] vector = new Integer[] {1, 2, 3, 4}; - try (TInt32 t = TInt32.tensorOf(Shape.of(4), d -> d.write(DataBuffers.ofObjects(vector)))) { + try (TensorScope scope = new TensorScope()) { + Integer[] vector = new Integer[]{1, 2, 3, 4}; + TInt32 t = TInt32.tensorOf(scope, Shape.of(4), d -> d.write(DataBuffers.ofObjects(vector))); assertEquals(TInt32.class, t.type()); assertEquals(DataType.DT_INT32, t.dataType()); assertEquals(1, t.shape().numDimensions()); @@ -431,113 +433,109 @@ public void testCreateFromArrayOfBoxed() { @Test public void failCreateOnMismatchedDimensions() { - int[][][] invalid = new int[3][1][]; - for (int x = 0; x < invalid.length; ++x) { - for (int y = 0; y < invalid[x].length; ++y) { - invalid[x][y] = new int[x + y + 1]; + try (TensorScope scope = new TensorScope()) { + int[][][] invalid = new int[3][1][]; + for (int x = 0; x < invalid.length; ++x) { + for (int y = 0; y < invalid[x].length; ++y) { + invalid[x][y] = new int[x + y + 1]; + } + } + try (TInt32 t = TInt32.tensorOf(scope, StdArrays.ndCopyOf(invalid))) { + fail("Tensor.create() should fail because of differing sizes in the 3rd dimension"); + } catch (IllegalArgumentException e) { + // The expected exception. } - } - try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(invalid))) { - fail("Tensor.create() should fail because of differing sizes in the 3rd dimension"); - } catch (IllegalArgumentException e) { - // The expected exception. } } @Test public void tensorWithZeroDimension() { - // Note: Historically, TF Java failed on purpose when trying to allocate a tensor with a shape - // that has one or more dimensions set to 0 elements. But Python API allows it, so we should do - // the same. - try (TInt32 t = TInt32.tensorOf(Shape.of(3, 0, 1))) { - assertEquals(0, t.numBytes()); - assertEquals(0, t.shape().size()); - } - try (TInt32 t = TInt32.tensorOf(StdArrays.ndCopyOf(new int[3][0][1]))) { + try (TensorScope scope = new TensorScope()) { + // Note: Historically, TF Java failed on purpose when trying to allocate a tensor with a shape + // that has one or more dimensions set to 0 elements. But Python API allows it, so we should do + // the same. + TInt32 t = TInt32.tensorOf(scope, Shape.of(3, 0, 1)); assertEquals(0, t.numBytes()); assertEquals(0, t.shape().size()); + + TInt32 t2 = TInt32.tensorOf(scope, StdArrays.ndCopyOf(new int[3][0][1])); + assertEquals(0, t2.numBytes()); + assertEquals(0, t2.shape().size()); } } @Test public void allocateTensorWithSize() { - try (TInt32 t = Tensor.of(TInt32.class, Shape.of(2, 2, 2), 8 * 4)) { + try (TensorScope scope = new TensorScope()) { // ok - } - try (TInt32 t = Tensor.of(TInt32.class, Shape.of(2, 2, 2), 9 * 4)) { + Tensor.of(scope, TInt32.class, Shape.of(2, 2, 2), 8 * 4); + // ok (size requested is larger that minimum space required) - } - try { - Tensor.of(TInt32.class, Shape.of(2, 2, 2), 8 * 4 - 1); - fail(); - } catch (IllegalArgumentException e) { - // as expected - } - } + Tensor.of(scope, TInt32.class, Shape.of(2, 2, 2), 9 * 4); - @Test - public void useAfterClose() { - int n = 4; - TInt32 t = TInt32.scalarOf(n); - t.close(); - try { - t.numBytes(); - } catch (IllegalStateException e) { - // The expected exception. + try { + Tensor.of(scope, TInt32.class, Shape.of(2, 2, 2), 8 * 4 - 1); + fail(); + } catch (IllegalArgumentException e) { + // as expected + } } } @Test - public void eagerTensorIsReleasedAfterSessionIsClosed() { - TInt32 sum; - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); - sum = tf.math.add(tf.constant(10), tf.constant(20)).asTensor(); - sum.asRawTensor().nativeHandle(); // does not throw - assertEquals(30, sum.getInt()); - } - try { - sum.asRawTensor().nativeHandle(); - fail("Tensor native handle should have been closed by ending eager session"); - } catch (IllegalStateException e) { - // as expected + public void useAfterClose() { + try (TensorScope scope = new TensorScope()) { + int n = 4; + TInt32 t = TInt32.scalarOf(scope, n); + t.close(); + try { + t.numBytes(); + } catch (IllegalStateException e) { + // The expected exception. + } } } @Test public void fromHandle() { - // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been - // created independently of the Java code. In practice, two Tensor instances MUST NOT have the - // same native handle. - // - // An exception is made for this test, where the pitfalls of this is avoided by not calling - // close() on both Tensors. - final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}}); - try (TFloat32 src = TFloat32.tensorOf(matrix)) { - TFloat32 cpy = (TFloat32)RawTensor.fromHandle(src.asRawTensor().nativeHandle()).asTypedTensor(); + try (TensorScope scope = new TensorScope()) { + // fromHandle is a package-visible method intended for use when the C TF_Tensor object has been + // created independently of the Java code. In practice, two Tensor instances MUST NOT have the + // same native handle. + // + // An exception is made for this test, where the pitfalls of this is avoided by not calling + // close() on both Tensors. + final FloatNdArray matrix = StdArrays.ndCopyOf(new float[][]{{1, 2, 3}, {4, 5, 6}}); + TFloat32 src = TFloat32.tensorOf(scope, matrix); + TFloat32 cpy = (TFloat32) RawTensor.fromHandle(scope, src.asRawTensor().nativeHandle()).asTypedTensor(); assertEquals(src.type(), cpy.type()); assertEquals(src.dataType(), cpy.dataType()); assertEquals(src.shape().numDimensions(), cpy.shape().numDimensions()); assertEquals(src.shape(), cpy.shape()); assertEquals(matrix, cpy); + + // don't want to call close + TensorScope.detach(cpy); } } @Test public void gracefullyFailCreationFromNullArrayForStringTensor() { - // Motivated by: https://github.com/tensorflow/tensorflow/issues/17130 - byte[][] array = new byte[1][]; - try { - TUint8.tensorOf(StdArrays.ndCopyOf(array)); - } catch (IllegalStateException e) { - // expected. - } - byte[][][] array2 = new byte[2][2][2]; - array2[1] = null; - try { - TUint8.tensorOf(StdArrays.ndCopyOf(array)); - } catch (IllegalStateException e) { - // expected. + try (TensorScope scope = new TensorScope()) { + // Motivated by: https://github.com/tensorflow/tensorflow/issues/17130 + byte[][] array = new byte[1][]; + try { + TUint8.tensorOf(scope, StdArrays.ndCopyOf(array)); + } catch (IllegalStateException e) { + // expected. + } + byte[][][] array2 = new byte[2][2][2]; + array2[1] = null; + try { + TUint8.tensorOf(scope, StdArrays.ndCopyOf(array)); + } catch (IllegalStateException e) { + // expected. + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java index b2fbc1e794a..9de1cc5ba8c 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/WrongEnvTest.java @@ -35,7 +35,8 @@ public class WrongEnvTest { @Test public void testTwoEagers() { try (EagerSession e1 = EagerSession.create(); - EagerSession e2 = EagerSession.create()) { + EagerSession e2 = EagerSession.create(); + TensorScope scope = new TensorScope()) { Ops tf1 = Ops.create(e1); Ops tf2 = Ops.create(e2); @@ -44,7 +45,7 @@ public void testTwoEagers() { Operand c = tf2.math.add(a, b); - try (TInt32 tensor = c.asTensor()) { + try (TInt32 tensor = c.asTensor(scope)) { assertEquals(11, tensor.getInt()); } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/benchmark/TensorBenchmark.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/benchmark/TensorBenchmark.java index 053738d2dd6..a3b23fa4abf 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/benchmark/TensorBenchmark.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/benchmark/TensorBenchmark.java @@ -12,10 +12,11 @@ import org.openjdk.jmh.annotations.State; import org.openjdk.jmh.annotations.Warmup; import org.openjdk.jmh.runner.RunnerException; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.StdArrays; import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.ndarray.buffer.IntDataBuffer; -import org.tensorflow.ndarray.StdArrays; import org.tensorflow.types.TInt32; @Fork(value = 1, jvmArgs = {"-Xms4G", "-Xmx4G"}) @@ -32,108 +33,114 @@ public static void main(String[] args) throws IOException, RunnerException { @Benchmark @Measurement(batchSize = 1000) public void initTensorByStdArrays() { - int[][][][] data = new int[][][][] { - { - { - {0, 0, 0}, {0, 0, 1}, {0, 0, 2} - }, - { - {0, 1, 0}, {0, 1, 1}, {0, 1, 2} - }, - { - {0, 2, 0}, {0, 2, 1}, {0, 2, 2} - } - }, { - { - {1, 0, 0}, {1, 0, 1}, {1, 0, 2} - }, - { - {1, 1, 0}, {1, 1, 1}, {1, 1, 2} - }, - { - {1, 2, 0}, {1, 2, 1}, {1, 2, 2} - } - }, { - { - {2, 0, 0}, {2, 0, 1}, {2, 0, 2} - }, - { - {2, 1, 0}, {2, 1, 1}, {2, 1, 2} - }, - { - {2, 2, 0}, {2, 2, 1}, {2, 2, 2} - } - } - }; - TInt32.tensorOf(StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d)); + try (TensorScope scope = new TensorScope()) { + int[][][][] data = new int[][][][]{ + { + { + {0, 0, 0}, {0, 0, 1}, {0, 0, 2} + }, + { + {0, 1, 0}, {0, 1, 1}, {0, 1, 2} + }, + { + {0, 2, 0}, {0, 2, 1}, {0, 2, 2} + } + }, { + { + {1, 0, 0}, {1, 0, 1}, {1, 0, 2} + }, + { + {1, 1, 0}, {1, 1, 1}, {1, 1, 2} + }, + { + {1, 2, 0}, {1, 2, 1}, {1, 2, 2} + } + }, { + { + {2, 0, 0}, {2, 0, 1}, {2, 0, 2} + }, + { + {2, 1, 0}, {2, 1, 1}, {2, 1, 2} + }, + { + {2, 2, 0}, {2, 2, 1}, {2, 2, 2} + } + } + }; + TInt32.tensorOf(scope, StdArrays.shapeOf(data), d -> StdArrays.copyTo(data, d)); + } } @Benchmark @Measurement(batchSize = 1000) public void initTensorByVectors() { - TInt32.tensorOf(Shape.of(3, 3, 3, 3), d -> d - .set(vectorOf(0, 0, 0), 0, 0, 0) - .set(vectorOf(0, 0, 1), 0, 0, 1) - .set(vectorOf(0, 0, 2), 0, 0, 2) - .set(vectorOf(0, 1, 0), 0, 1, 0) - .set(vectorOf(0, 1, 1), 0, 1, 1) - .set(vectorOf(0, 1, 2), 0, 1, 2) - .set(vectorOf(0, 2, 0), 0, 2, 0) - .set(vectorOf(0, 2, 1), 0, 2, 1) - .set(vectorOf(0, 2, 2), 0, 2, 2) - .set(vectorOf(1, 0, 0), 1, 0, 0) - .set(vectorOf(1, 0, 1), 1, 0, 1) - .set(vectorOf(1, 0, 2), 1, 0, 2) - .set(vectorOf(1, 1, 0), 1, 1, 0) - .set(vectorOf(1, 1, 1), 1, 1, 1) - .set(vectorOf(1, 1, 2), 1, 1, 2) - .set(vectorOf(1, 2, 0), 1, 2, 0) - .set(vectorOf(1, 2, 1), 1, 2, 1) - .set(vectorOf(1, 2, 2), 1, 2, 2) - .set(vectorOf(2, 0, 0), 2, 0, 0) - .set(vectorOf(2, 0, 1), 2, 0, 1) - .set(vectorOf(2, 0, 2), 2, 0, 2) - .set(vectorOf(2, 1, 0), 2, 1, 0) - .set(vectorOf(2, 1, 1), 2, 1, 1) - .set(vectorOf(2, 1, 2), 2, 1, 2) - .set(vectorOf(2, 2, 0), 2, 2, 0) - .set(vectorOf(2, 2, 1), 2, 2, 1) - .set(vectorOf(2, 2, 2), 2, 2, 2) - ); + try (TensorScope scope = new TensorScope()) { + TInt32.tensorOf(scope, Shape.of(3, 3, 3, 3), d -> d + .set(vectorOf(0, 0, 0), 0, 0, 0) + .set(vectorOf(0, 0, 1), 0, 0, 1) + .set(vectorOf(0, 0, 2), 0, 0, 2) + .set(vectorOf(0, 1, 0), 0, 1, 0) + .set(vectorOf(0, 1, 1), 0, 1, 1) + .set(vectorOf(0, 1, 2), 0, 1, 2) + .set(vectorOf(0, 2, 0), 0, 2, 0) + .set(vectorOf(0, 2, 1), 0, 2, 1) + .set(vectorOf(0, 2, 2), 0, 2, 2) + .set(vectorOf(1, 0, 0), 1, 0, 0) + .set(vectorOf(1, 0, 1), 1, 0, 1) + .set(vectorOf(1, 0, 2), 1, 0, 2) + .set(vectorOf(1, 1, 0), 1, 1, 0) + .set(vectorOf(1, 1, 1), 1, 1, 1) + .set(vectorOf(1, 1, 2), 1, 1, 2) + .set(vectorOf(1, 2, 0), 1, 2, 0) + .set(vectorOf(1, 2, 1), 1, 2, 1) + .set(vectorOf(1, 2, 2), 1, 2, 2) + .set(vectorOf(2, 0, 0), 2, 0, 0) + .set(vectorOf(2, 0, 1), 2, 0, 1) + .set(vectorOf(2, 0, 2), 2, 0, 2) + .set(vectorOf(2, 1, 0), 2, 1, 0) + .set(vectorOf(2, 1, 1), 2, 1, 1) + .set(vectorOf(2, 1, 2), 2, 1, 2) + .set(vectorOf(2, 2, 0), 2, 2, 0) + .set(vectorOf(2, 2, 1), 2, 2, 1) + .set(vectorOf(2, 2, 2), 2, 2, 2) + ); + } } @Benchmark @Measurement(batchSize = 1000) public void initTensorByFlatArray() { - IntDataBuffer data = DataBuffers.of( - 0, 0, 0, - 0, 0, 1, - 0, 0, 2, - 0, 1, 0, - 0, 1, 1, - 0, 1, 2, - 0, 2, 0, - 0, 2, 1, - 0, 2, 2, - 1, 0, 0, - 1, 0, 1, - 1, 0, 2, - 1, 1, 0, - 1, 1, 1, - 1, 1, 2, - 1, 2, 0, - 1, 2, 1, - 1, 2, 2, - 2, 0, 0, - 2, 0, 1, - 2, 0, 2, - 2, 1, 0, - 2, 1, 1, - 2, 1, 2, - 2, 2, 0, - 2, 2, 1, - 2, 2, 2 - ); - TInt32.tensorOf(Shape.of(3, 3, 3, 3), data); + try (TensorScope scope = new TensorScope()) { + IntDataBuffer data = DataBuffers.of( + 0, 0, 0, + 0, 0, 1, + 0, 0, 2, + 0, 1, 0, + 0, 1, 1, + 0, 1, 2, + 0, 2, 0, + 0, 2, 1, + 0, 2, 2, + 1, 0, 0, + 1, 0, 1, + 1, 0, 2, + 1, 1, 0, + 1, 1, 1, + 1, 1, 2, + 1, 2, 0, + 1, 2, 1, + 1, 2, 2, + 2, 0, 0, + 2, 0, 1, + 2, 0, 2, + 2, 1, 0, + 2, 1, 1, + 2, 1, 2, + 2, 2, 0, + 2, 2, 1, + 2, 2, 2 + ); + TInt32.tensorOf(scope, Shape.of(3, 3, 3, 3), data); + } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java index 62881dcee8c..874680a4015 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/ScopeTest.java @@ -23,10 +23,13 @@ import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TType; -/** Unit tests for {@link org.tensorflow.op.Scope}. */ +/** + * Unit tests for {@link org.tensorflow.op.Scope}. + */ public class ScopeTest { @Test @@ -85,9 +88,9 @@ public void validateNames() { Scope root = new Scope(g); final String[] invalid_names = { - "_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.] - null, "", "a$", // Invalid characters - "a/b", // slashes not allowed + "_", "-", "-x", // Names are constrained to start with [A-Za-z0-9.] + null, "", "a$", // Invalid characters + "a/b", // slashes not allowed }; for (String name : invalid_names) { @@ -144,10 +147,11 @@ public void hierarchy() { @Test public void composite() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Scope s = new Scope(g); Output data = - Const.create(s.withName("data"), new int[] {600, 470, 170, 430, 300}).output(); + Const.create(s.withName("data"), new int[]{600, 470, 170, 430, 300}).output(); // Create a composite op with a customized name Variance var1 = Variance.create(s.withName("example"), data); @@ -168,23 +172,28 @@ public void composite() { // assertNotNull(g.operation("variance/zero")); // Verify correct results as well. - TInt32 result = (TInt32)sess.runner().fetch(var1.output()).run().get(0); + TInt32 result = (TInt32) sess.runner().fetch(var1.output()).run(scope).get(0); assertEquals(21704, result.getInt()); - result = (TInt32)sess.runner().fetch(var2.output()).run().get(0); + result = (TInt32) sess.runner().fetch(var2.output()).run(scope).get(0); assertEquals(21704, result.getInt()); } } // "handwritten" sample operator classes private static final class Const { + private final Output output; static Const create(Scope s, int v) { - return create(s, TInt32.scalarOf(v)); + try (TensorScope scope = new TensorScope()) { + return create(s, TInt32.scalarOf(scope, v)); + } } static Const create(Scope s, int[] v) { - return create(s, TInt32.vectorOf(v)); + try (TensorScope scope = new TensorScope()) { + return create(s, TInt32.vectorOf(scope, v)); + } } static Const create(Scope s, T value) { @@ -207,6 +216,7 @@ Output output() { } private static final class Mean { + private final Output output; static Mean create(Scope s, Output input, Output reductionIndices) { @@ -229,6 +239,7 @@ Output output() { } private static final class SquaredDifference { + private final Output output; static SquaredDifference create(Scope s, Output x, Output y) { @@ -251,17 +262,20 @@ Output output() { } private static final class Variance { + private final Output output; static Variance create(Scope base, Output x) { - Scope s = base.withSubScope("variance"); - Output zero = Const.create(base, TInt32.scalarOf(0)).output(); - Output sqdiff = - SquaredDifference.create( - s.withName("squared_deviation"), x, Mean.create(s, x, zero).output()) - .output(); - - return new Variance<>(Mean.create(s.withName("variance"), sqdiff, zero).output()); + try (TensorScope scope = new TensorScope()) { + Scope s = base.withSubScope("variance"); + Output zero = Const.create(base, TInt32.scalarOf(scope, 0)).output(); + Output sqdiff = + SquaredDifference.create( + s.withName("squared_deviation"), x, Mean.create(s, x, zero).output()) + .output(); + + return new Variance<>(Mean.create(s.withName("variance"), sqdiff, zero).output()); + } } Variance(Output o) { diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java index a4d9293ccf8..22249d4d9ba 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskTest.java @@ -22,6 +22,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TBool; @@ -29,10 +30,12 @@ import org.tensorflow.types.TInt32; public class BooleanMaskTest { + @Test - public void testBooleanMask(){ + public void testBooleanMask() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Operand input = Constant.arrayOf(scope, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9); @@ -43,25 +46,23 @@ public void testBooleanMask(){ Operand output1 = BooleanMask.create(scope, input, mask); Operand output2 = BooleanMask.create(scope, input2, mask, BooleanMask.axis(1)); - try (TFloat32 result = (TFloat32) sess.runner().fetch(output1).run().get(0)) { - // expected shape from Python tensorflow - assertEquals(Shape.of(5), result.shape()); - assertEquals(0, result.getFloat(0)); - assertEquals(1, result.getFloat(1)); - assertEquals(4, result.getFloat(2)); - assertEquals(5, result.getFloat(3)); - assertEquals(6, result.getFloat(4)); - } + TFloat32 result = (TFloat32) sess.runner().fetch(output1).run(tensorScope).get(0); + // expected shape from Python tensorflow + assertEquals(Shape.of(5), result.shape()); + assertEquals(0, result.getFloat(0)); + assertEquals(1, result.getFloat(1)); + assertEquals(4, result.getFloat(2)); + assertEquals(5, result.getFloat(3)); + assertEquals(6, result.getFloat(4)); - try (TFloat32 result = (TFloat32) sess.runner().fetch(output2).run().get(0)) { - // expected shape from Python tensorflow - assertEquals(Shape.of(5), result.shape()); - assertEquals(0, result.getFloat(0)); - assertEquals(1, result.getFloat(1)); - assertEquals(4, result.getFloat(2)); - assertEquals(5, result.getFloat(3)); - assertEquals(6, result.getFloat(4)); - } + TFloat32 result2 = (TFloat32) sess.runner().fetch(output2).run(tensorScope).get(0); + // expected shape from Python tensorflow + assertEquals(Shape.of(5), result2.shape()); + assertEquals(0, result2.getFloat(0)); + assertEquals(1, result2.getFloat(1)); + assertEquals(4, result2.getFloat(2)); + assertEquals(5, result2.getFloat(3)); + assertEquals(6, result2.getFloat(4)); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java index ab852bbffb2..ed3fd3614de 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/BooleanMaskUpdateTest.java @@ -24,6 +24,7 @@ import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TBool; @@ -34,7 +35,8 @@ public class BooleanMaskUpdateTest { @Test public void testBooleanMaskUpdateSlice() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); @@ -47,31 +49,31 @@ public void testBooleanMaskUpdateSlice() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); - try (TInt32 result = (TInt32) results.get(0); - TInt32 bcastResult = (TInt32) results.get(1)) { + List results = sess.runner().fetch(output).fetch(bcastOutput).run(tensorScope); + TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1); - assertEquals(Shape.of(3, 3), result.shape()); + assertEquals(Shape.of(3, 3), result.shape()); - assertEquals(-1, result.getInt(0, 0)); - assertEquals(-1, result.getInt(0, 1)); - assertEquals(-1, result.getInt(0, 2)); - assertEquals(1, result.getInt(1, 0)); - assertEquals(1, result.getInt(1, 1)); - assertEquals(1, result.getInt(1, 2)); - assertEquals(2, result.getInt(2, 0)); - assertEquals(2, result.getInt(2, 1)); - assertEquals(2, result.getInt(2, 2)); + assertEquals(-1, result.getInt(0, 0)); + assertEquals(-1, result.getInt(0, 1)); + assertEquals(-1, result.getInt(0, 2)); + assertEquals(1, result.getInt(1, 0)); + assertEquals(1, result.getInt(1, 1)); + assertEquals(1, result.getInt(1, 2)); + assertEquals(2, result.getInt(2, 0)); + assertEquals(2, result.getInt(2, 1)); + assertEquals(2, result.getInt(2, 2)); - assertEquals(result, bcastResult); - } + assertEquals(result, bcastResult); } } @Test public void testBooleanMaskUpdateSliceWithBroadcast() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Operand input = Constant.tensorOf(scope, new int[][]{{0, 0, 0}, {1, 1, 1}, {2, 2, 2}}); @@ -84,31 +86,31 @@ public void testBooleanMaskUpdateSliceWithBroadcast() { Operand bcastOutput = BooleanMaskUpdate.create(scope, input, mask, Constant.scalarOf(scope, -1)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); - try (TInt32 result = (TInt32) results.get(0); - TInt32 bcastResult = (TInt32) results.get(1)) { + List results = sess.runner().fetch(output).fetch(bcastOutput).run(tensorScope); + TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1); - assertEquals(Shape.of(3, 3), result.shape()); + assertEquals(Shape.of(3, 3), result.shape()); - assertEquals(-1, result.getInt(0, 0)); - assertEquals(-1, result.getInt(0, 1)); - assertEquals(-1, result.getInt(0, 2)); - assertEquals(1, result.getInt(1, 0)); - assertEquals(1, result.getInt(1, 1)); - assertEquals(1, result.getInt(1, 2)); - assertEquals(2, result.getInt(2, 0)); - assertEquals(2, result.getInt(2, 1)); - assertEquals(2, result.getInt(2, 2)); + assertEquals(-1, result.getInt(0, 0)); + assertEquals(-1, result.getInt(0, 1)); + assertEquals(-1, result.getInt(0, 2)); + assertEquals(1, result.getInt(1, 0)); + assertEquals(1, result.getInt(1, 1)); + assertEquals(1, result.getInt(1, 2)); + assertEquals(2, result.getInt(2, 0)); + assertEquals(2, result.getInt(2, 1)); + assertEquals(2, result.getInt(2, 2)); - assertEquals(result, bcastResult); - } + assertEquals(result, bcastResult); } } @Test public void testBooleanMaskUpdateAxis() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Operand input = Constant.tensorOf(scope, new int[][][]{{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}}}); @@ -122,25 +124,24 @@ public void testBooleanMaskUpdateAxis() { Operand bcastOutput = BooleanMaskUpdate .create(scope, input, mask, Constant.scalarOf(scope, -1), BooleanMaskUpdate.axis(2)); - List results = sess.runner().fetch(output).fetch(bcastOutput).run(); - try (TInt32 result = (TInt32) results.get(0); - TInt32 bcastResult = (TInt32) results.get(1)) { - - assertEquals(Shape.of(1, 1, 10), result.shape()); - - assertEquals(-1, result.getInt(0, 0, 0)); - assertEquals(-1, result.getInt(0, 0, 1)); - assertEquals(2, result.getInt(0, 0, 2)); - assertEquals(3, result.getInt(0, 0, 3)); - assertEquals(-1, result.getInt(0, 0, 4)); - assertEquals(-1, result.getInt(0, 0, 5)); - assertEquals(-1, result.getInt(0, 0, 6)); - assertEquals(7, result.getInt(0, 0, 7)); - assertEquals(8, result.getInt(0, 0, 8)); - assertEquals(9, result.getInt(0, 0, 9)); - - assertEquals(result, bcastResult); - } + List results = sess.runner().fetch(output).fetch(bcastOutput).run(tensorScope); + TInt32 result = (TInt32) results.get(0); + TInt32 bcastResult = (TInt32) results.get(1); + + assertEquals(Shape.of(1, 1, 10), result.shape()); + + assertEquals(-1, result.getInt(0, 0, 0)); + assertEquals(-1, result.getInt(0, 0, 1)); + assertEquals(2, result.getInt(0, 0, 2)); + assertEquals(3, result.getInt(0, 0, 3)); + assertEquals(-1, result.getInt(0, 0, 4)); + assertEquals(-1, result.getInt(0, 0, 5)); + assertEquals(-1, result.getInt(0, 0, 6)); + assertEquals(7, result.getInt(0, 0, 7)); + assertEquals(8, result.getInt(0, 0, 8)); + assertEquals(9, result.getInt(0, 0, 9)); + + assertEquals(result, bcastResult); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java index 6df73261867..8bf54b27ec0 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ConstantTest.java @@ -18,13 +18,14 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import java.io.IOException; +import java.util.List; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.EagerSession; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.DoubleNdArray; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.IntNdArray; @@ -61,15 +62,14 @@ public void createInts() { IntNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @@ -80,15 +80,14 @@ public void createFloats() { FloatNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @@ -99,15 +98,14 @@ public void createDoubles() { DoubleNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @@ -118,15 +116,14 @@ public void createLongs() { LongNdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @@ -137,36 +134,36 @@ public void createStrings() throws IOException { NdArray array = NdArrays.wrap(shape, buffer); try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); Constant op1 = Constant.tensorOf(scope, shape, buffer); Constant op2 = Constant.tensorOf(scope, array); - try (AutoCloseableList t = - new AutoCloseableList<>(sess.runner().fetch(op1).fetch(op2).run())) { - assertEquals(array, t.get(0)); - assertEquals(array, t.get(1)); - } + List t = sess.runner().fetch(op1).fetch(op2).run(tensorScope); + assertEquals(array, t.get(0)); + assertEquals(array, t.get(1)); } } @Test public void createFromTensorsInEagerMode() throws IOException { try (EagerSession s = EagerSession.create(); - TInt32 t = TInt32.vectorOf(1, 2, 3, 4)) { + TensorScope tensorScope = new TensorScope(); + TInt32 t = TInt32.vectorOf(tensorScope, 1, 2, 3, 4)) { Ops tf = Ops.create(s); Constant c1 = tf.constant(t); - assertEquals(c1.asTensor(), t); + assertEquals(c1.asTensor(tensorScope), t); // A different endpoint for capturing a tensor as a constant, which supports all data types Constant c2 = tf.constantOf(t); - assertEquals(c2.asTensor(), t); - assertEquals(c1.asTensor(), c2.asTensor()); + assertEquals(c2.asTensor(tensorScope), t); + assertEquals(c1.asTensor(tensorScope), c2.asTensor(tensorScope)); // Permute data in the tensor to make sure that constant copies are independent t.setInt(10); assertEquals(NdArrays.vectorOf(10, 2, 3, 4), t); - assertEquals(NdArrays.vectorOf(1, 2, 3, 4), c1.asTensor()); + assertEquals(NdArrays.vectorOf(1, 2, 3, 4), c1.asTensor(tensorScope)); } } @@ -174,7 +171,8 @@ private static void testCreateFromNumber(Ops tf, Class type) Operand constant = tf.constant(type, 10); assertEquals(type, constant.type()); - try (TFloat64 t = tf.dtypes.cast(constant, TFloat64.class).asTensor()) { + try (TensorScope tensorScope = new TensorScope(); + TFloat64 t = tf.dtypes.cast(constant, TFloat64.class).asTensor(tensorScope)) { assertEquals(10.0, t.getDouble()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java index b1ebd469eb3..d686dcae81b 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GeneratedOperationsTest.java @@ -22,6 +22,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -32,28 +33,28 @@ public final class GeneratedOperationsTest { @Test public void tensorInputTensorOutput() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops ops = Ops.create(g); Operand x = ops.math.add(ops.constant(1), ops.constant(2)); - try (TInt32 result = (TInt32)sess.runner().fetch(x).run().get(0)) { - assertEquals(3, result.getInt()); - } + TInt32 result = (TInt32) sess.runner().fetch(x).run(scope).get(0); + assertEquals(3, result.getInt()); } } @Test public void testListInputTensorOutput() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops ops = Ops.create(g); ArrayList> inputs = new ArrayList<>(); inputs.add(ops.constant(1)); inputs.add(ops.constant(2)); inputs.add(ops.constant(3)); Operand x = ops.math.addN(inputs); - try (TInt32 result = (TInt32)sess.runner().fetch(x).run().get(0)) { - assertEquals(6, result.getInt()); - } + TInt32 result = (TInt32) sess.runner().fetch(x).run(scope).get(0); + assertEquals(6, result.getInt()); } } @@ -61,13 +62,14 @@ public void testListInputTensorOutput() { * Test for Ops.withControlDependencies. * *

    Creates an add node with a control dependency to an assign node. In other words, the assign - * node is a control input to the add node. When the add node is run, the assign node is expected - * to have run beforehand due to the control dependency. + * node is a control input to the add node. When the add node is run, the assign node is expected to have run + * beforehand due to the control dependency. */ @Test public void testControlDependencies() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops ops = Ops.create(g); Operand variable = ops.variable(Shape.scalar(), TInt32.class); Operand initVariable = ops.assign(variable, ops.constant(0)); @@ -75,10 +77,9 @@ public void testControlDependencies() { controls.add(ops.assign(variable, ops.constant(3))); Operand x = ops.withControlDependencies(controls).math.add(variable, ops.constant(0)); - sess.runner().addTarget(initVariable).run(); - try (TInt32 result = (TInt32)sess.runner().fetch(x).run().get(0)) { - assertEquals(3, result.getInt()); - } + sess.runner().addTarget(initVariable).run(scope); + TInt32 result = (TInt32) sess.runner().fetch(x).run(scope).get(0); + assertEquals(3, result.getInt()); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java index 80150b64bb6..3b6f0ada75d 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/GradientsTest.java @@ -20,12 +20,13 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; -import org.tensorflow.AutoCloseableList; import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.Session; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; @@ -34,7 +35,8 @@ public class GradientsTest { @Test public void createGradients() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); @@ -47,21 +49,19 @@ public void createGradients() { assertNotNull(grads.dy()); assertEquals(2, grads.dy().size()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run())) { + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + List outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).fetch(grads.dy(1)).run(scope); - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - assertEquals(18.0f, ((TFloat32)outputs.get(1)).getFloat(), 0.0f); - } + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); + assertEquals(18.0f, ((TFloat32) outputs.get(1)).getFloat(), 0.0f); } } @Test public void createGradientsWithSum() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); @@ -74,19 +74,18 @@ public void createGradientsWithSum() { assertNotNull(grads.dy()); assertEquals(1, grads.dy().size()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>(sess.runner().feed(x, c).fetch(grads.dy(0)).run())) { + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + List outputs = sess.runner().feed(x, c).fetch(grads.dy(0)).run(scope); - assertEquals(114.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - } + assertEquals(114.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); } } @Test public void createGradientsWithInitialValues() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(g); Output x = tf.placeholder(TFloat32.class).output(); @@ -100,13 +99,10 @@ public void createGradientsWithInitialValues() { assertNotNull(grads1.dy()); assertEquals(1, grads1.dy().size()); - try (TFloat32 c = TFloat32.scalarOf(3.0f); - AutoCloseableList outputs = - new AutoCloseableList<>( - sess.runner().feed(x, c).fetch(grads1.dy(0)).run())) { + TFloat32 c = TFloat32.scalarOf(scope, 3.0f); + List outputs = sess.runner().feed(x, c).fetch(grads1.dy(0)).run(scope); - assertEquals(108.0f, ((TFloat32)outputs.get(0)).getFloat(), 0.0f); - } + assertEquals(108.0f, ((TFloat32) outputs.get(0)).getFloat(), 0.0f); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java index 6e86573b7cf..a78422664e3 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/IndexingTest.java @@ -20,9 +20,10 @@ import org.junit.Test; import org.tensorflow.Graph; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; -import org.tensorflow.ndarray.index.Indices; import org.tensorflow.ndarray.index.Index; +import org.tensorflow.ndarray.index.Indices; import org.tensorflow.op.Scope; import org.tensorflow.types.TFloat32; @@ -62,9 +63,11 @@ public void testStridedSliceIndex(){ long[] shape = {10, 10, 10, 10, 10, 10, 10, 10}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class); StridedSlice output = StridedSliceHelper.stridedSlice(scope, op, slice); - try (TFloat32 result = (TFloat32) sess.runner().fetch(output.asOutput()).run().get(0)) { + try (TensorScope tensorScope = new TensorScope(); + TFloat32 result = (TFloat32) sess.runner().fetch(output.asOutput()).run(tensorScope).get(0)) { // expected shape from Python tensorflow - assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.shape(), "Slice index didn't match expected (Python)"); + assertEquals(Shape.of(1, 10, 1, 10, 10, 10, 4, 3), result.shape(), + "Slice index didn't match expected (Python)"); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java index 39c04c942af..46604a53ddb 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ShapesTest.java @@ -22,6 +22,7 @@ import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.op.Scope; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt32; @@ -29,153 +30,165 @@ public class ShapesTest { - /** Test of flatten method, of class Shapes. */ + /** + * Test of flatten method, of class Shapes. + */ @Test public void testFlatten_Operand() { - try (Graph g = new Graph(); + try (TensorScope tensorScope = new TensorScope(); + Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Shape expResult = Shape.create(scope, operand, TInt64.class); Operand reshaped = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Operand actual = Shapes.flatten(scope, reshaped); Shape tfshape = Shape.create(scope, actual, TInt64.class); AtomicInteger index = new AtomicInteger(); - try (TInt64 result1 = (TInt64)session.runner().fetch(tfshape.asOutput()).run().get(0); - TInt64 result2 = (TInt64)session.runner().fetch(expResult.asOutput()).run().get(0)) { - result1 - .scalars() - .forEach( - s -> assertEquals(result2.getLong(index.getAndIncrement()), s.getLong())); - } + TInt64 result1 = (TInt64) session.runner().fetch(tfshape.asOutput()).run(tensorScope).get(0); + TInt64 result2 = (TInt64) session.runner().fetch(expResult.asOutput()).run(tensorScope).get(0); + result1 + .scalars() + .forEach( + s -> assertEquals(result2.getLong(index.getAndIncrement()), s.getLong())); } } - /** Test of flatten method, of class Shapes. */ + /** + * Test of flatten method, of class Shapes. + */ @Test public void testFlatten_Shape() { - try (EagerSession session = EagerSession.create()) { + try (TensorScope tensorScope = new TensorScope(); + EagerSession session = EagerSession.create()) { Scope scope = new Scope(session); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Shape expShape = Shape.create(scope, operand, TInt64.class); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand flattened = Shapes.flatten(scope, tfshape, TInt64.class); AtomicInteger index = new AtomicInteger(); flattened - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> assertEquals( - expShape.asTensor().getLong(index.getAndIncrement()), s.getLong())); + expShape.asTensor(tensorScope).getLong(index.getAndIncrement()), s.getLong())); } } - /** Test of size method, of class Shapes. */ + /** + * Test of size method, of class Shapes. + */ @Test public void testSize_Shape() { - try (Graph g = new Graph(); + try (TensorScope tensorScope = new TensorScope(); + Graph g = new Graph(); Session session = new Session(g)) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand size = Shapes.size(scope, tfshape, TInt64.class); AtomicInteger index = new AtomicInteger(); - try (TInt64 result1 = (TInt64)session.runner().fetch(size.asOutput()).run().get(0)) { - result1.scalars().forEach(s -> assertEquals(8, s.getLong())); - } + TInt64 result1 = (TInt64) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result1.scalars().forEach(s -> assertEquals(8, s.getLong())); } } - /** Test of size method, of class Shapes. */ + /** + * Test of size method, of class Shapes. + */ @Test public void testSize_Shape_Operand() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 0)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(4, s.getInt())); - } + TInt32 result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(4, s.getInt())); size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 1)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(2, s.getInt())); - } + result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(2, s.getInt())); size = Shapes.size(scope, tfshape, Constant.scalarOf(scope, 2)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(1, s.getInt())); - } + result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(1, s.getInt())); } } - /** Test of size method, of class Shapes. */ + /** + * Test of size method, of class Shapes. + */ @Test public void testSize_Operand_Operand() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Operand size = Shapes.size(scope, actual, Constant.scalarOf(scope, 0)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(4, s.getInt())); - } + TInt32 result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(4, s.getInt())); size = Shapes.size(scope, actual, Constant.scalarOf(scope, 1)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(2, s.getInt())); - } + result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(2, s.getInt())); size = Shapes.size(scope, actual, Constant.scalarOf(scope, 2)); - try (TInt32 result = (TInt32)session.runner().fetch(size.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(1, s.getInt())); - } + result = (TInt32) session.runner().fetch(size.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(1, s.getInt())); } } - /** Test of numDimensions method, of class Shapes. */ + /** + * Test of numDimensions method, of class Shapes. + */ @Test public void testNumDimensions() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand nDims = Shapes.numDimensions(scope, tfshape); - try (TInt32 result = (TInt32)session.runner().fetch(nDims.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(3, s.getInt())); - } + TInt32 result = (TInt32) session.runner().fetch(nDims.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(3, s.getInt())); } } - /** Test of reduceDims method, of class Shapes. */ + /** + * Test of reduceDims method, of class Shapes. + */ @Test public void testReduceDims_Operand_Operand() { - try (EagerSession session = EagerSession.create()) { + try (EagerSession session = EagerSession.create(); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(session); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {2, 2, 2})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{2, 2, 2})); Shape tfshape = Shape.create(scope, actual); Operand reduced = Shapes.reduceDims(scope, actual, Constant.scalarOf(scope, 0)); @@ -183,7 +196,7 @@ public void testReduceDims_Operand_Operand() { AtomicInteger index = new AtomicInteger(); int[] expected = {8}; reducedShape - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> { @@ -193,14 +206,17 @@ public void testReduceDims_Operand_Operand() { } } - /** Test of reduceDims method, of class Shapes. */ + /** + * Test of reduceDims method, of class Shapes. + */ @Test public void testReduceDims_Shape_Operand() { - try (EagerSession session = EagerSession.create()) { + try (EagerSession session = EagerSession.create(); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(session); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {2, 2, 2})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{2, 2, 2})); Shape tfshape = Shape.create(scope, actual); Operand reduced = Shapes.reduceDims(scope, actual, Constant.scalarOf(scope, 0)); @@ -208,7 +224,7 @@ public void testReduceDims_Shape_Operand() { AtomicInteger index = new AtomicInteger(); int[] expected1 = {8}; reducedShape - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> { @@ -221,7 +237,7 @@ public void testReduceDims_Shape_Operand() { index.set(0); int[] expected2 = {2, 4}; reducedShape - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> { @@ -234,7 +250,7 @@ public void testReduceDims_Shape_Operand() { index.set(0); int[] expected3 = {2, 2, 2}; reducedShape - .asTensor() + .asTensor(tensorScope) .scalars() .forEach( s -> { @@ -244,28 +260,30 @@ public void testReduceDims_Shape_Operand() { } } - /** Test of squeeze method, of class Shapes. */ + /** + * Test of squeeze method, of class Shapes. + */ @Test public void testSqueeze() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand squeezed = Shapes.squeeze(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2}; - try (TInt32 result = (TInt32)session.runner().fetch(squeezed.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(squeezed.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -273,24 +291,24 @@ public void testSqueeze() { @Test public void testHead() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand head = Shapes.head(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {4}; - try (TInt32 result = (TInt32)session.runner().fetch(head.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(head.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -298,24 +316,24 @@ public void testHead() { @Test public void testTake() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand take = Shapes.take(scope, tfshape, Constant.scalarOf(scope, 2)); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 1}; - try (TInt32 result = (TInt32)session.runner().fetch(take.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(take.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -323,24 +341,24 @@ public void testTake() { @Test public void testTail() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand tail = Shapes.tail(scope, tfshape); AtomicInteger index = new AtomicInteger(); int[] expected = {1}; - try (TInt32 result = (TInt32)session.runner().fetch(tail.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(tail.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -348,24 +366,24 @@ public void testTail() { @Test public void testTakeLast() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual = - Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 1, 2, 1})); + Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 1, 2, 1})); Shape tfshape = Shape.create(scope, actual); Operand takeLast = Shapes.takeLast(scope, tfshape, Constant.scalarOf(scope, 3)); AtomicInteger index = new AtomicInteger(); int[] expected = {1, 2, 1}; - try (TInt32 result = (TInt32)session.runner().fetch(takeLast.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(takeLast.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -373,23 +391,23 @@ public void testTakeLast() { @Test public void testPrependInt() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2})); Shape tfshape = Shape.create(scope, actual); Operand prepend = Shapes.prepend(scope, tfshape, 3); AtomicInteger index = new AtomicInteger(); int[] expected = {3, 4, 2}; - try (TInt32 result = (TInt32)session.runner().fetch(prepend.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(prepend.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -397,23 +415,23 @@ public void testPrependInt() { @Test public void testPrependLong() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2})); Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand prepend = Shapes.prepend(scope, tfshape, 1L); AtomicInteger index = new AtomicInteger(); long[] expected = {1, 4, 2}; - try (TInt64 result = (TInt64)session.runner().fetch(prepend.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getLong()); - }); - } + TInt64 result = (TInt64) session.runner().fetch(prepend.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getLong()); + }); assertEquals(expected.length, index.get()); } } @@ -421,28 +439,28 @@ public void testPrependLong() { @Test public void testPrependShapeTInt32() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand1 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual1 = - Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[]{4, 2})); + Operand operand2 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = - Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); + Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[]{2, 4})); Shape tfshape1 = Shape.create(scope, actual1); Shape tfshape2 = Shape.create(scope, actual2); Operand prepend = Shapes.prepend(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); int[] expected = {2, 4, 4, 2}; - try (TInt32 result = (TInt32)session.runner().fetch(prepend.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(prepend.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -450,28 +468,28 @@ public void testPrependShapeTInt32() { @Test public void testPrependShapeTInt64() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand1 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual1 = - Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[]{4, 2})); + Operand operand2 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = - Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); + Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[]{2, 4})); Shape tfshape1 = Shape.create(scope, actual1, TInt64.class); Shape tfshape2 = Shape.create(scope, actual2, TInt64.class); Operand prepend = Shapes.prepend(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); long[] expected = {2, 4, 4, 2}; - try (TInt64 result = (TInt64)session.runner().fetch(prepend.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getLong()); - }); - } + TInt64 result = (TInt64) session.runner().fetch(prepend.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getLong()); + }); assertEquals(expected.length, index.get()); } } @@ -479,23 +497,23 @@ public void testPrependShapeTInt64() { @Test public void testAppendLong() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2})); Shape tfshape = Shape.create(scope, actual, TInt64.class); Operand append = Shapes.append(scope, tfshape, 2L); AtomicInteger index = new AtomicInteger(); long[] expected = {4L, 2L, 2L}; - try (TInt64 result = (TInt64)session.runner().fetch(append.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getLong()); - }); - } + TInt64 result = (TInt64) session.runner().fetch(append.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getLong()); + }); assertEquals(expected.length, index.get()); } } @@ -503,23 +521,23 @@ public void testAppendLong() { @Test public void testAppendInt() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); - Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[] {4, 2})); + Operand operand = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); + Operand actual = Reshape.create(scope, operand, Constant.vectorOf(scope, new long[]{4, 2})); Shape tfshape = Shape.create(scope, actual); Operand append = Shapes.append(scope, tfshape, 2); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2, 2}; - try (TInt32 result = (TInt32)session.runner().fetch(append.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(append.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -527,28 +545,28 @@ public void testAppendInt() { @Test public void testAppendShapeTInt32() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand1 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual1 = - Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[]{4, 2})); + Operand operand2 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = - Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); + Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[]{2, 4})); Shape tfshape1 = Shape.create(scope, actual1); Shape tfshape2 = Shape.create(scope, actual2); Operand append = Shapes.append(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); int[] expected = {4, 2, 2, 4}; - try (TInt32 result = (TInt32)session.runner().fetch(append.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getInt()); - }); - } + TInt32 result = (TInt32) session.runner().fetch(append.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getInt()); + }); assertEquals(expected.length, index.get()); } } @@ -556,28 +574,28 @@ public void testAppendShapeTInt32() { @Test public void testAppendShapeTInt64() { try (Graph g = new Graph(); - Session session = new Session(g)) { + Session session = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); - Operand operand1 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Operand operand1 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual1 = - Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[] {4, 2})); - Operand operand2 = Constant.arrayOf(scope, new float[] {1, 2, 3, 4, 5, 6, 7, 8}); + Reshape.create(scope, operand1, Constant.vectorOf(scope, new long[]{4, 2})); + Operand operand2 = Constant.arrayOf(scope, new float[]{1, 2, 3, 4, 5, 6, 7, 8}); Operand actual2 = - Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[] {2, 4})); + Reshape.create(scope, operand2, Constant.vectorOf(scope, new long[]{2, 4})); Shape tfshape1 = Shape.create(scope, actual1, TInt64.class); Shape tfshape2 = Shape.create(scope, actual2, TInt64.class); Operand append = Shapes.append(scope, tfshape1, tfshape2); AtomicInteger index = new AtomicInteger(); long[] expected = {4, 2, 2, 4}; - try (TInt64 result = (TInt64)session.runner().fetch(append.asOutput()).run().get(0)) { - result - .scalars() - .forEach( - s -> { - assertEquals(expected[index.getAndIncrement()], s.getLong()); - }); - } + TInt64 result = (TInt64) session.runner().fetch(append.asOutput()).run(tensorScope).get(0); + result + .scalars() + .forEach( + s -> { + assertEquals(expected[index.getAndIncrement()], s.getLong()); + }); assertEquals(expected.length, index.get()); } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java index 4121baf3af1..4e379efecbc 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/op/core/ZerosTest.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.op.Scope; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat32; @@ -37,102 +38,104 @@ public class ZerosTest { @Test public void createIntZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt32.class); - try (TInt32 result = (TInt32)sess.runner().fetch(op).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0, s.getInt())); - } + TInt32 result = (TInt32) sess.runner().fetch(op).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0, s.getInt())); } } @Test public void createFloatZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat32.class); - try (TFloat32 result = (TFloat32)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0.0f, s.getFloat(), 0)); - } + TFloat32 result = (TFloat32) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0.0f, s.getFloat(), 0)); + } } @Test public void createDoubleZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TFloat64.class); - try (TFloat64 result = (TFloat64)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0.0f, s.getDouble(), 0)); - } + TFloat64 result = (TFloat64) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0.0f, s.getDouble(), 0)); } } @Test public void createLongZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TInt64.class); - try (TInt64 result = (TInt64)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0L, s.getLong())); - } + TInt64 result = (TInt64) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0L, s.getLong())); } } @Test public void createBooleanZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TBool.class); - try (TBool result = (TBool)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertFalse(s.getBoolean())); - } - } + TBool result = (TBool) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertFalse(s.getBoolean())); + } } @Test public void createUint8Zeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TUint8.class); - try (TUint8 result = (TUint8)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertEquals(0, s.getByte())); - } + TUint8 result = (TUint8) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertEquals(0, s.getByte())); } } @Test public void createStringZeros() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros op = Zeros.create(scope, Constant.vectorOf(scope, shape), TString.class); - try (TString result = (TString)sess.runner().fetch(op.asOutput()).run().get(0)) { - result.scalars().forEach(s -> assertTrue(s.getObject().isEmpty())); - } + TString result = (TString) sess.runner().fetch(op.asOutput()).run(tensorScope).get(0); + result.scalars().forEach(s -> assertTrue(s.getObject().isEmpty())); } } @Test public void operationsComposingZerosAreCorrectlyNamed() { try (Graph g = new Graph(); - Session sess = new Session(g)) { + Session sess = new Session(g); + TensorScope tensorScope = new TensorScope()) { Scope scope = new Scope(g); long[] shape = {2, 2}; Zeros zeros = Zeros.create(scope.withSubScope("test"), Constant.vectorOf(scope, shape), TFloat32.class); - List results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(); + List results = sess.runner().addTarget("test/Zeros/Zero").addTarget("test/Zeros/Fill").run(tensorScope); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java index faddc7c5826..03254e65196 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/NumericTypesTestBase.java @@ -21,6 +21,7 @@ import org.junit.jupiter.api.Test; import org.tensorflow.EagerSession; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.index.Indices; @@ -34,78 +35,82 @@ abstract class NumericTypesTestBase { @Test public void initializeTensorsWithZeros() { - // Allocate a tensor of 32-bits integer of the shape (2, 3, 2) - T tensor = allocateTensor(Shape.of(2, 3, 2)); + try (TensorScope scope = new TensorScope()) { + // Allocate a tensor of 32-bits integer of the shape (2, 3, 2) + T tensor = allocateTensor(scope, Shape.of(2, 3, 2)); - assertEquals(3, tensor.rank()); - assertEquals(12, tensor.size()); - NdArray data = (NdArray)tensor; + assertEquals(3, tensor.rank()); + assertEquals(12, tensor.size()); + NdArray data = (NdArray) tensor; - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); + try (EagerSession session = EagerSession.create()) { + Ops tf = Ops.create(session); - // Initialize tensor memory with zeros and take a snapshot - data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(0))); - Constant x = tf.constantOf(tensor); + // Initialize tensor memory with zeros and take a snapshot + data.scalars().forEach(scalar -> ((NdArray) scalar).setObject(valueOf(0))); + Constant x = tf.constantOf(tensor); - // Initialize the same tensor memory with ones and take a snapshot - data.scalars().forEach(scalar -> ((NdArray)scalar).setObject(valueOf(1))); - Constant y = tf.constantOf(tensor); + // Initialize the same tensor memory with ones and take a snapshot + data.scalars().forEach(scalar -> ((NdArray) scalar).setObject(valueOf(1))); + Constant y = tf.constantOf(tensor); - // Subtract y from x and validate the result - Sub sub = tf.math.sub(x, y); - ((NdArray)sub.asTensor()).scalars().forEach(scalar -> - assertEquals(valueOf(-1), scalar.getObject()) - ); + // Subtract y from x and validate the result + Sub sub = tf.math.sub(x, y); + ((NdArray) sub.asTensor(scope)).scalars().forEach(scalar -> + assertEquals(valueOf(-1), scalar.getObject()) + ); + } } } @Test public void setAndCompute() { - NdArray heapData = allocateNdArray(Shape.of(4)) - .setObject(valueOf(0), 0) - .setObject(valueOf(1), 1) - .setObject(valueOf(2), 2) - .setObject(valueOf(3), 3); - - // Creates a 2x2 matrix - try (T tensor = allocateTensor(Shape.of(2, 2))) { - NdArray data = (NdArray)tensor; - - // Copy first 2 values of the vector to the first row of the matrix - data.set(heapData.slice(Indices.range(0, 2)), 0); - - // Copy values at an odd position in the vector as the second row of the matrix - data.set(heapData.slice(Indices.odd()), 1); - - assertEquals(valueOf(0), data.getObject(0, 0)); - assertEquals(valueOf(1), data.getObject(0, 1)); - assertEquals(valueOf(1), data.getObject(1, 0)); - assertEquals(valueOf(3), data.getObject(1, 1)); - - // Read rows of the tensor in reverse order - NdArray flippedData = data.slice(Indices.flip(), Indices.flip()); - - assertEquals(valueOf(3), flippedData.getObject(0, 0)); - assertEquals(valueOf(1), flippedData.getObject(0, 1)); - assertEquals(valueOf(1), flippedData.getObject(1, 0)); - assertEquals(valueOf(0), flippedData.getObject(1, 1)); - - try (EagerSession session = EagerSession.create()) { - Ops tf = Ops.create(session); - - Add add = tf.math.add(tf.constantOf(tensor), tf.constantOf(tensor)); - NdArray result = (NdArray)add.asTensor(); - - assertEquals(valueOf(0), result.getObject(0, 0)); - assertEquals(valueOf(2), result.getObject(0, 1)); - assertEquals(valueOf(2), result.getObject(1, 0)); - assertEquals(valueOf(6), result.getObject(1, 1)); + try (TensorScope scope = new TensorScope()) { + NdArray heapData = allocateNdArray(Shape.of(4)) + .setObject(valueOf(0), 0) + .setObject(valueOf(1), 1) + .setObject(valueOf(2), 2) + .setObject(valueOf(3), 3); + + // Creates a 2x2 matrix + try (T tensor = allocateTensor(scope, Shape.of(2, 2))) { + NdArray data = (NdArray) tensor; + + // Copy first 2 values of the vector to the first row of the matrix + data.set(heapData.slice(Indices.range(0, 2)), 0); + + // Copy values at an odd position in the vector as the second row of the matrix + data.set(heapData.slice(Indices.odd()), 1); + + assertEquals(valueOf(0), data.getObject(0, 0)); + assertEquals(valueOf(1), data.getObject(0, 1)); + assertEquals(valueOf(1), data.getObject(1, 0)); + assertEquals(valueOf(3), data.getObject(1, 1)); + + // Read rows of the tensor in reverse order + NdArray flippedData = data.slice(Indices.flip(), Indices.flip()); + + assertEquals(valueOf(3), flippedData.getObject(0, 0)); + assertEquals(valueOf(1), flippedData.getObject(0, 1)); + assertEquals(valueOf(1), flippedData.getObject(1, 0)); + assertEquals(valueOf(0), flippedData.getObject(1, 1)); + + try (EagerSession session = EagerSession.create()) { + Ops tf = Ops.create(session); + + Add add = tf.math.add(tf.constantOf(tensor), tf.constantOf(tensor)); + NdArray result = (NdArray) add.asTensor(scope); + + assertEquals(valueOf(0), result.getObject(0, 0)); + assertEquals(valueOf(2), result.getObject(0, 1)); + assertEquals(valueOf(2), result.getObject(1, 0)); + assertEquals(valueOf(6), result.getObject(1, 1)); + } } } } - abstract T allocateTensor(Shape shape); + abstract T allocateTensor(TensorScope scope, Shape shape); abstract NdArray allocateNdArray(Shape shape); diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java index 17a6e0dd2b5..9d2890ebdce 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TBfloat16Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TBfloat16Test extends NumericTypesTestBase { @Override - TBfloat16 allocateTensor(Shape shape) { - return TBfloat16.tensorOf(shape); + TBfloat16 allocateTensor(TensorScope scope, Shape shape) { + return TBfloat16.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java index c1ae8ad3b6d..4339be8a2c9 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat16Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TFloat16Test extends NumericTypesTestBase { @Override - TFloat16 allocateTensor(Shape shape) { - return TFloat16.tensorOf(shape); + TFloat16 allocateTensor(TensorScope scope, Shape shape) { + return TFloat16.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java index 8df96f2871a..45bf8278487 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat32Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TFloat32Test extends NumericTypesTestBase { @Override - TFloat32 allocateTensor(Shape shape) { - return TFloat32.tensorOf(shape); + TFloat32 allocateTensor(TensorScope scope, Shape shape) { + return TFloat32.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java index 47b4b6d936a..fd12b6f2666 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TFloat64Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TFloat64Test extends NumericTypesTestBase { @Override - TFloat64 allocateTensor(Shape shape) { - return TFloat64.tensorOf(shape); + TFloat64 allocateTensor(TensorScope scope, Shape shape) { + return TFloat64.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java index a2ab28b6219..364ac21e8bf 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt32Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TInt32Test extends NumericTypesTestBase { @Override - TInt32 allocateTensor(Shape shape) { - return TInt32.tensorOf(shape); + TInt32 allocateTensor(TensorScope scope, Shape shape) { + return TInt32.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java index a88f3fb4d6d..2e04cb21688 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TInt64Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TInt64Test extends NumericTypesTestBase { @Override - TInt64 allocateTensor(Shape shape) { - return TInt64.tensorOf(shape); + TInt64 allocateTensor(TensorScope scope, Shape shape) { + return TInt64.tensorOf(scope, shape); } @Override diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java index 015f93b70e7..6b4893baf2a 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TStringTest.java @@ -23,6 +23,7 @@ import java.nio.charset.StandardCharsets; import org.junit.jupiter.api.Test; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -31,77 +32,94 @@ public class TStringTest { @Test public void createScalar() { - TString tensor = TString.scalarOf("Pretty vacant"); - assertNotNull(tensor); - assertEquals(Shape.scalar(), tensor.shape()); - assertEquals("Pretty vacant", tensor.getObject()); + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.scalarOf(scope, "Pretty vacant"); + assertNotNull(tensor); + assertEquals(Shape.scalar(), tensor.shape()); + assertEquals("Pretty vacant", tensor.getObject()); + } } - @Test - public void createrScalarLongerThan127() { - TString tensor = TString.scalarOf("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); - assertNotNull(tensor); - assertEquals(Shape.scalar(), tensor.shape()); - assertEquals("Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", tensor.getObject()); + @Test + public void createrScalarLongerThan127() { + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.scalarOf(scope, + "Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !"); + assertNotNull(tensor); + assertEquals(Shape.scalar(), tensor.shape()); + assertEquals( + "Long String 1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890 !", + tensor.getObject()); } + } - @Test + @Test public void createVector() { - TString tensor = TString.vectorOf("Pretty", "vacant"); - assertNotNull(tensor); - assertEquals(Shape.of(2), tensor.shape()); - assertEquals("Pretty", tensor.getObject(0)); - assertEquals("vacant", tensor.getObject(1)); + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.vectorOf(scope, "Pretty", "vacant"); + assertNotNull(tensor); + assertEquals(Shape.of(2), tensor.shape()); + assertEquals("Pretty", tensor.getObject(0)); + assertEquals("vacant", tensor.getObject(1)); + } } @Test public void createCopy() { - NdArray strings = NdArrays.ofObjects(String.class, Shape.of(2, 2)) - .setObject("Pretty", 0, 0) - .setObject("vacant", 0, 1) - .setObject("New", 1, 0) - .setObject("York", 1, 1); + try (TensorScope scope = new TensorScope()) { + NdArray strings = NdArrays.ofObjects(String.class, Shape.of(2, 2)) + .setObject("Pretty", 0, 0) + .setObject("vacant", 0, 1) + .setObject("New", 1, 0) + .setObject("York", 1, 1); - TString tensor = TString.tensorOf(strings); - assertNotNull(tensor); - strings.scalars().forEachIndexed((idx, s) -> - assertEquals(s.getObject(), tensor.getObject(idx)) - ); + TString tensor = TString.tensorOf(scope, strings); + assertNotNull(tensor); + strings.scalars().forEachIndexed((idx, s) -> + assertEquals(s.getObject(), tensor.getObject(idx)) + ); + } } @Test public void defaultCharsetIsUtf8() { - TString tensor = TString.tensorOf(NdArrays.scalarOfObject(BABY_CHICK)); - byte[] bytes = tensor.asBytes().getObject(); - assertArrayEquals(new byte[] { (byte)0xF0, (byte)0x9F, (byte)0x90, (byte)0xA5 }, bytes); - assertEquals(BABY_CHICK, tensor.getObject()); + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.tensorOf(scope, NdArrays.scalarOfObject(BABY_CHICK)); + byte[] bytes = tensor.asBytes().getObject(); + assertArrayEquals(new byte[]{(byte) 0xF0, (byte) 0x9F, (byte) 0x90, (byte) 0xA5}, bytes); + assertEquals(BABY_CHICK, tensor.getObject()); + } } @Test public void usingDifferentCharset() { - TString tensor = TString.tensorOf(StandardCharsets.UTF_16LE, NdArrays.scalarOfObject(BABY_CHICK)); - byte[] bytes = tensor.asBytes().getObject(); - assertArrayEquals(new byte[] { (byte)0x3D, (byte)0xD8, (byte)0x25, (byte)0xDC }, bytes); - assertEquals(BABY_CHICK, tensor.using(StandardCharsets.UTF_16LE).getObject()); + try (TensorScope scope = new TensorScope()) { + TString tensor = TString.tensorOf(scope, StandardCharsets.UTF_16LE, NdArrays.scalarOfObject(BABY_CHICK)); + byte[] bytes = tensor.asBytes().getObject(); + assertArrayEquals(new byte[]{(byte) 0x3D, (byte) 0xD8, (byte) 0x25, (byte) 0xDC}, bytes); + assertEquals(BABY_CHICK, tensor.using(StandardCharsets.UTF_16LE).getObject()); + } } @Test public void initializingTensorWithRawBytes() { - String[] strings = new String[] { "TensorFlow", "For", "Java", "Rocks", "!" }; - NdArray bytes = NdArrays.ofObjects(byte[].class, Shape.of(strings.length)); - for (int i = 0; i < strings.length; ++i) { - bytes.setObject(strings[i].getBytes(), i); - } - TString tensor = TString.tensorOfBytes(bytes); - assertNotNull(tensor); - assertEquals(bytes.shape(), tensor.shape()); + try (TensorScope scope = new TensorScope()) { + String[] strings = new String[]{"TensorFlow", "For", "Java", "Rocks", "!"}; + NdArray bytes = NdArrays.ofObjects(byte[].class, Shape.of(strings.length)); + for (int i = 0; i < strings.length; ++i) { + bytes.setObject(strings[i].getBytes(), i); + } + TString tensor = TString.tensorOfBytes(scope, bytes); + assertNotNull(tensor); + assertEquals(bytes.shape(), tensor.shape()); - NdArray tensorBytes = tensor.asBytes(); - for (int i = 0; i < strings.length; ++i) { - assertArrayEquals(bytes.getObject(i), tensorBytes.getObject(i)); + NdArray tensorBytes = tensor.asBytes(); + for (int i = 0; i < strings.length; ++i) { + assertArrayEquals(bytes.getObject(i), tensorBytes.getObject(i)); + } } } - private static final String BABY_CHICK = "\uD83D\uDC25"; + private static final String BABY_CHICK = "\uD83D\uDC25"; } diff --git a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java index ce7397d5878..b7fc379b6fc 100644 --- a/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java +++ b/tensorflow-core/tensorflow-core-api/src/test/java/org/tensorflow/types/TUint8Test.java @@ -17,6 +17,7 @@ package org.tensorflow.types; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.NdArray; import org.tensorflow.ndarray.NdArrays; import org.tensorflow.ndarray.Shape; @@ -24,8 +25,8 @@ public class TUint8Test extends NumericTypesTestBase { @Override - TUint8 allocateTensor(Shape shape) { - return TUint8.tensorOf(shape); + TUint8 allocateTensor(TensorScope scope, Shape shape) { + return TUint8.tensorOf(scope, shape); } @Override diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java index a3aa290a8c8..bc3c4eeaabf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/data/DatasetIterator.java @@ -15,15 +15,15 @@ */ package org.tensorflow.framework.data; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.ndarray.Shape; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; import org.tensorflow.types.family.TType; /** @@ -272,7 +272,9 @@ public Iterator>> iterator() { @Override public boolean hasNext() { - return nextOptional.hasValue().asTensor().getBoolean(); + try (TensorScope scope = new TensorScope()) { + return nextOptional.hasValue().asTensor(scope).getBoolean(); + } } @Override diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java index f6b0de71b0d..6a8521235fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/impl/LossesHelper.java @@ -14,7 +14,12 @@ =======================================================================*/ package org.tensorflow.framework.losses.impl; +import static org.tensorflow.framework.utils.CastHelper.cast; + +import java.util.Arrays; +import java.util.Collections; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.framework.losses.Reduction; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -26,11 +31,6 @@ import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TNumber; -import java.util.Arrays; -import java.util.Collections; - -import static org.tensorflow.framework.utils.CastHelper.cast; - /** * These are helper methods for Losses and Metrics and will be module private when Java modularity * is applied to TensorFlow Java. These methods should not be used outside of the losses and metrics @@ -52,12 +52,11 @@ public class LossesHelper { * @param tf the TensorFlow Ops * @param predictions Predicted values, a Operand of arbitrary dimensions. * @param labels Optional label Operand whose dimensions match prediction - * . + * . * @param the data type for the labels, predictions and result - * @return LossTuple of prediction, label,sampleWeight will - * be null. Each of them possibly has the last dimension squeezed, sampleWeight - * could be extended by one dimension. If sampleWeight is null, (prediction, - * label) is returned. + * @return LossTuple of prediction, label,sampleWeight will be null. Each of h + * of them possibly has the last dimension squeezed, sampleWeight could be extended by one dimension. If + * sampleWeight is null, (prediction, label) is returned. */ public static LossTuple squeezeOrExpandDimensions( Ops tf, Operand labels, Operand predictions) { @@ -141,8 +140,7 @@ public static LossTuple squeezeOrExpandDimensions( * Squeeze or expand the sampleWeight based on the rank difference * *

    If the rank difference is +1, squeeze the last dimension of sampleWeight, If the rank - * difference is -1, expand the last dimension of sampleWeight. Otherwise, leave the shape of - * sampleWeight as is. + * difference is -1, expand the last dimension of sampleWeight. Otherwise, leave the shape of sampleWeight as is. * * @param tf the TensorFlow Ops * @param sampleWeight the sample weights @@ -180,7 +178,7 @@ private static Operand maybeExpandWeights( * * @param tf the TensorFlowOps * @param labels Label values, a Tensor whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param the data type for the labels, predictions and result * @return labels and predictions, possibly with last dim squeezed. @@ -195,7 +193,7 @@ public static LossTuple removeSqueezableDimensions( * * @param tf the TensorFlowOps * @param labels Label values, a Operand whose dimensions match predictions - * . + * . * @param predictions Predicted values, a Tensor of arbitrary dimensions. * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). * @param the data type for the labels, predictions and result @@ -299,8 +297,8 @@ private static Operand reduceWeightedLoss( * @param losses Operand whose elements contain individual loss measurements. * @param numElements The number of measurable elements in losses. * @param the data type of the losses - * @return A scalar representing the mean of losses. If numElements is - * zero, then zero is returned. + * @return A scalar representing the mean of losses. If numElements is zero, then zero is + * returned. */ public static Operand safeMean( Ops tf, Operand losses, long numElements) { @@ -338,8 +336,7 @@ public static Operand allAxes(Ops tf, Operand op) * @param minValue the minimum value * @param maxValue the maximum value * @param the datatype for the values - * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph - * Session + * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph Session * @throws IllegalArgumentException if the TensorFlow Ops represents an Eager Session */ public static Operand rangeCheck( @@ -351,41 +348,44 @@ public static Operand rangeCheck( tf.reduceAll(tf.math.lessEqual(values, maxValue), allDims)); // Graph and Eager mode need to be handled differently, control dependencies are not allowed in // Eager mode - if (tf.scope().env().isGraph()) { - AssertThat assertThat = - tf.assertThat( - cond, - Arrays.asList( - tf.constant(prefix), - tf.constant(": values out of range, "), - tf.constant("minimum = "), - minValue, - tf.constant(", maximum = "), - maxValue)); - Ops ltf = - tf.withSubScope("rangeCheck") - .withControlDependencies(Collections.singletonList(assertThat)); - return ltf.identity(values); - } else if (!cond.asTensor().getBoolean()) - throw new IllegalArgumentException(String.format("%s : values out of range", prefix)); - else return values; + try (TensorScope scope = new TensorScope()) { + if (tf.scope().env().isGraph()) { + AssertThat assertThat = + tf.assertThat( + cond, + Arrays.asList( + tf.constant(prefix), + tf.constant(": values out of range, "), + tf.constant("minimum = "), + minValue, + tf.constant(", maximum = "), + maxValue)); + Ops ltf = + tf.withSubScope("rangeCheck") + .withControlDependencies(Collections.singletonList(assertThat)); + return ltf.identity(values); + } else if (!cond.asTensor(scope).getBoolean()) { + throw new IllegalArgumentException(String.format("%s : values out of range", prefix)); + } else { + return values; + } + } } /** - * Checks to see if all the values are in the allowed values set. Running the operand in Graph - * mode will throw {@link org.tensorflow.exceptions.TFInvalidArgumentException}, if at least one - * value is not in the allowed values set. In Eager mode, this method will throw an {@link - * IllegalArgumentException} if at least one value is not in the allowed values set. + * Checks to see if all the values are in the allowed values set. Running the operand in Graph mode will throw {@link + * org.tensorflow.exceptions.TFInvalidArgumentException}, if at least one value is not in the allowed values set. In + * Eager mode, this method will throw an {@link IllegalArgumentException} if at least one value is not in the allowed + * values set. * * @param tf The TensorFlow Ops * @param prefix A String prefix to include in the error message * @param values the values to check * @param allowedValues the allowed values * @param the data type for values and allowed values - * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph - * Session - * @throws IllegalArgumentException if the Session is in Eager mode and at least one value is not - * in the allowed values set + * @return the values possibly with control dependencies if the TensorFlow Ops represents a Graph Session + * @throws IllegalArgumentException if the Session is in Eager mode and at least one value is not in the allowed + * values set */ public static Operand valueCheck( Ops tf, String prefix, Operand values, Operand allowedValues) { @@ -396,26 +396,32 @@ public static Operand valueCheck( if (diffSize != Shape.UNKNOWN_SIZE) { if (diffSize != 0) { // at least 1 value in the diff did not match the allowed values. throw new IllegalArgumentException(String.format("%s : values not in value set,", prefix)); - } else return values; + } else { + return values; + } } else { // use dynamic shape - Operand cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0)); - // Graph and Eager mode need to be handled differently, control dependencies are not allowed - // in Eager mode - if (tf.scope().env().isGraph()) { - AssertThat assertThat = - tf.assertThat( - cond, - Arrays.asList( - tf.constant(prefix), - tf.constant(": values not in value set, values = "), - values)); - Ops ltf = - tf.withSubScope("valueCheck") - .withControlDependencies(Collections.singletonList(assertThat)); - return ltf.identity(values); - } else if (!cond.asTensor().getBoolean()) - throw new IllegalArgumentException(String.format("%s : values not in value set", prefix)); - else return values; + try (TensorScope scope = new TensorScope()) { + Operand cond = tf.math.equal(tf.shape.size(tf.shape(diff.out())), tf.constant(0)); + // Graph and Eager mode need to be handled differently, control dependencies are not allowed + // in Eager mode + if (tf.scope().env().isGraph()) { + AssertThat assertThat = + tf.assertThat( + cond, + Arrays.asList( + tf.constant(prefix), + tf.constant(": values not in value set, values = "), + values)); + Ops ltf = + tf.withSubScope("valueCheck") + .withControlDependencies(Collections.singletonList(assertThat)); + return ltf.identity(values); + } else if (!cond.asTensor(scope).getBoolean()) { + throw new IllegalArgumentException(String.format("%s : values not in value set", prefix)); + } else { + return values; + } + } } } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java index e730c79cfbf..4c915054077 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/utils/ShapeUtils.java @@ -14,9 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.utils; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.types.TInt32; @@ -24,11 +28,9 @@ import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TIntegral; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -/** Various methods for processing with Shapes and Operands */ +/** + * Various methods for processing with Shapes and Operands + */ public class ShapeUtils { /** @@ -66,12 +68,14 @@ public static int[] getIntArray(Scope scope, Operand dims) { * @throws java.lang.IllegalArgumentException if the dims type is not an integer */ public static long[] getLongArray(Scope scope, Operand dims) { - if (scope.env().isEager()) { - return getLongArray(dims.asTensor()); - } - try (Session session = new Session((Graph) scope.env()); - TIntegral tensor = (TIntegral) session.runner().fetch(dims).run().get(0)) { - return getLongArray(tensor); + try (TensorScope tensorScope = new TensorScope()) { + if (scope.env().isEager()) { + return getLongArray(dims.asTensor(tensorScope)); + } + try (Session session = new Session((Graph) scope.env())) { + TIntegral tensor = (TIntegral) session.runner().fetch(dims).run(tensorScope).get(0); + return getLongArray(tensor); + } } } @@ -112,12 +116,16 @@ public static Shape reduce(Shape shape, int axis) { axis = shape.numDimensions() + axis; } long[] array = shape.asArray(); - if (array == null) return Shape.unknown(); + if (array == null) { + return Shape.unknown(); + } long[] newArray = new long[axis]; System.arraycopy(array, 0, newArray, 0, axis - 1); long prod = array[axis - 1]; for (int i = axis; i < array.length; i++) { - if (array[i] != Shape.UNKNOWN_SIZE) prod *= array[i]; + if (array[i] != Shape.UNKNOWN_SIZE) { + prod *= array[i]; + } } newArray[axis - 1] = prod; return Shape.of(newArray); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java index 8c2c3a54ff9..370512609c2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/constraints/MinMaxNormTest.java @@ -1,7 +1,10 @@ package org.tensorflow.framework.constraints; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -10,9 +13,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; -import java.util.Random; -import java.util.concurrent.atomic.AtomicInteger; - class MinMaxNormTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -27,12 +27,15 @@ private float[] getSampleArray() { return result; } - /** Test of call method, of class MinMaxNorm. */ + /** + * Test of call method, of class MinMaxNorm. + */ @Test public void testCall() { float[] testValues = {0.1f, 0.5f, 3f, 8f, 1e-7f}; - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { + for (TestSession.Mode tfMode : tfModes) { + try (TestSession session = TestSession.createTestSession(tfMode); + TensorScope scope = new TensorScope()) { Ops tf = session.getTF(); final float[] array = getSampleArray(); Operand weights = tf.reshape(tf.constant(array), tf.constant(Shape.of(100, 100))); @@ -41,15 +44,17 @@ public void testCall() { i.getAndIncrement()) { MinMaxNorm instance = new MinMaxNorm(tf, testValues[i.get()], testValues[i.get()] * 2); Operand result = instance.call(weights); - if (tfMode == TestSession.Mode.EAGER) - evaluate(session, result.asTensor(), testValues[i.get()]); - else + if (tfMode == TestSession.Mode.EAGER) { + evaluate(session, result.asTensor(scope), testValues[i.get()]); + } else { try (TFloat32 tensor = - (TFloat32) session.getGraphSession().runner().fetch(result).run().get(0)) { + (TFloat32) session.getGraphSession().runner().fetch(result).run(scope).get(0)) { evaluate(session, tensor, testValues[i.get()]); } + } } } + } } private void evaluate(TestSession session, TFloat32 tensor, float m) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java index 2d282e5dcf7..0545f25b407 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/BatchDatasetTest.java @@ -15,18 +15,18 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.tensorflow.ndarray.index.Indices.range; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertTrue; -import static org.tensorflow.ndarray.index.Indices.range; - public class BatchDatasetTest extends DatasetTestBase { @Test @@ -44,9 +44,9 @@ public void testEagerBatchDataset() { int count = 0; for (List> components : dataset) { - try (TInt32 batch1 = - (TInt32)components.get(0).asTensor(); - TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + try (TensorScope scope = new TensorScope()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); assertEquals(testMatrix1.slice(range(count, count + 2)), batch1); assertEquals(testMatrix2.slice(range(count, count + 2)), batch2); @@ -68,15 +68,16 @@ public void testDropLastBatch() { int count = 0; for (List> components : dataset) { + try (TensorScope scope = new TensorScope()) { - try (TInt32 batch1 = - (TInt32)components.get(0).asTensor(); - TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); assertEquals(testMatrix1.slice(range(count, count + 3)), batch1); assertEquals(testMatrix2.slice(range(count, count + 3)), batch2); count += 3; } + } } @@ -95,10 +96,9 @@ public void testKeepLastBatch() { boolean foundLastBatch = false; for (List> components : dataset) { - try (TInt32 batch1 = - (TInt32)components.get(0).asTensor(); - TInt32 batch2 = - (TInt32)components.get(1).asTensor();) { + try (TensorScope scope = new TensorScope()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); if (count == 0) { assertEquals(testMatrix1.slice(range(count, count + 3)), batch1); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java index 882a64ba54d..70da093bb7c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/DatasetIteratorTest.java @@ -15,25 +15,26 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.types.family.TType; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TFOutOfRangeException; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; - -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; +import org.tensorflow.types.family.TType; public class DatasetIteratorTest extends DatasetTestBase { @Test public void testGraphIteration() { - try (Graph graph = new Graph()) { + try (Graph graph = new Graph(); + TensorScope scope = new TensorScope()) { Ops tf = Ops.create(graph); List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); @@ -53,10 +54,10 @@ public void testGraphIteration() { int batches = 0; while (true) { try { - List outputs = session.runner().fetch(x).fetch(y).run(); + List outputs = session.runner().fetch(x).fetch(y).run(scope); - try (TInt32 xBatch = (TInt32)outputs.get(0); - TInt32 yBatch = (TInt32)outputs.get(1)) { + try (TInt32 xBatch = (TInt32) outputs.get(0); + TInt32 yBatch = (TInt32) outputs.get(1)) { assertEquals(testMatrix1.get(batches), xBatch); assertEquals(testMatrix2.get(batches), yBatch); batches++; @@ -71,22 +72,24 @@ public void testGraphIteration() { @Test public void testEagerIteration() { + try (TensorScope scope = new TensorScope()) { - Ops tf = Ops.create(); - - List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); + Ops tf = Ops.create(); - List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); + List> tensors = Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)); - Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); - int count = 0; - for (List> outputs : dataset) { - try (TInt32 batch1 = (TInt32)outputs.get(0).asTensor(); - TInt32 batch2 = (TInt32)outputs.get(1).asTensor()) { - assertEquals(testMatrix1.get(count), batch1); - assertEquals(testMatrix2.get(count), batch2); + List> dataTypes = Arrays.asList(TInt32.class, TInt32.class); - count++; + Dataset dataset = Dataset.fromTensorSlices(tf, tensors, dataTypes); + int count = 0; + for (List> outputs : dataset) { + try (TInt32 batch1 = (TInt32) outputs.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) outputs.get(1).asTensor(scope)) { + assertEquals(testMatrix1.get(count), batch1); + assertEquals(testMatrix2.get(count), batch2); + + count++; + } } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java index 5f203427563..16f7edff143 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/MapDatasetTest.java @@ -15,24 +15,25 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Operand; import org.tensorflow.Session; -import org.tensorflow.types.family.TType; +import org.tensorflow.TensorScope; import org.tensorflow.exceptions.TFOutOfRangeException; -import org.tensorflow.op.Ops; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.StdArrays; +import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; - -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; +import org.tensorflow.types.family.TType; public class MapDatasetTest extends DatasetTestBase { + IntNdArray mapped1; IntNdArray mapped2; @@ -41,14 +42,14 @@ public void setUp() { super.setUp(); mapped1 = StdArrays.ndCopyOf( - new int[][] { - {2, 4, 6, 8, 10}, - {4, 8, 12, 16, 20}, - {6, 12, 18, 24, 30}, - {8, 16, 24, 32, 40} + new int[][]{ + {2, 4, 6, 8, 10}, + {4, 8, 12, 16, 20}, + {6, 12, 18, 24, 30}, + {8, 16, 24, 32, 40} }); - mapped2 = StdArrays.ndCopyOf(new int[][] {{2}, {0}, {2}, {2}}); + mapped2 = StdArrays.ndCopyOf(new int[][]{{2}, {0}, {2}, {2}}); } @Test @@ -77,17 +78,16 @@ public void testGraphIteration() { int batches = 0; while (true) { - try { - List outputs = session.runner().fetch(X).fetch(y).run(); + try (TensorScope scope = new TensorScope()) { + List outputs = session.runner().fetch(X).fetch(y).run(scope); - try (TInt32 XBatch = (TInt32)outputs.get(0); - TInt32 yBatch = (TInt32)outputs.get(1)) { + TInt32 XBatch = (TInt32) outputs.get(0); + TInt32 yBatch = (TInt32) outputs.get(1); - assertEquals(mapped1.get(batches), XBatch); - assertEquals(mapped2.get(batches), yBatch); + assertEquals(mapped1.get(batches), XBatch); + assertEquals(mapped2.get(batches), yBatch); - batches++; - } + batches++; } catch (TFOutOfRangeException e) { break; } @@ -113,8 +113,9 @@ public void testEagerIteration() { int count = 0; for (List> outputs : dataset) { - try (TInt32 XBatch = (TInt32)outputs.get(0).asTensor(); - TInt32 yBatch = (TInt32)outputs.get(1).asTensor()) { + try (TensorScope scope = new TensorScope()) { + TInt32 XBatch = (TInt32) outputs.get(0).asTensor(scope); + TInt32 yBatch = (TInt32) outputs.get(1).asTensor(scope); assertEquals(mapped1.get(count), XBatch); assertEquals(mapped2.get(count), yBatch); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java index d0cdb4527a5..70e297fa1e9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/SkipDatasetTest.java @@ -15,32 +15,34 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - public class SkipDatasetTest extends DatasetTestBase { + @Test public void testEagerSkipDataset() { Ops tf = Ops.create(); Dataset dataset = Dataset.fromTensorSlices( - tf, - Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.class, TInt32.class)) + tf, + Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), + Arrays.asList(TInt32.class, TInt32.class)) .skip(2); int count = 2; for (List> components : dataset) { - try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); - TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + try (TensorScope scope = new TensorScope()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); assertEquals(testMatrix1.get(count), batch1); assertEquals(testMatrix2.get(count), batch2); count++; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java index 79a2e79c72e..19fbaf9da61 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/data/TakeDatasetTest.java @@ -15,16 +15,16 @@ */ package org.tensorflow.framework.data; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.Arrays; +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.TensorScope; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt32; -import java.util.Arrays; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - public class TakeDatasetTest extends DatasetTestBase { @Test @@ -33,15 +33,16 @@ public void testEagerTakeDataset() { Dataset dataset = Dataset.fromTensorSlices( - tf, - Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), - Arrays.asList(TInt32.class, TInt32.class)) + tf, + Arrays.asList(tf.constant(testMatrix1), tf.constant(testMatrix2)), + Arrays.asList(TInt32.class, TInt32.class)) .take(4); int count = 0; for (List> components : dataset) { - try (TInt32 batch1 = (TInt32)components.get(0).asTensor(); - TInt32 batch2 = (TInt32)components.get(1).asTensor()) { + try (TensorScope scope = new TensorScope()) { + TInt32 batch1 = (TInt32) components.get(0).asTensor(scope); + TInt32 batch2 = (TInt32) components.get(1).asTensor(scope); assertEquals(testMatrix1.get(count), batch1); assertEquals(testMatrix2.get(count), batch2); count++; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java index 63d666f8640..d83c50c725c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -14,9 +14,13 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -26,37 +30,33 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertThrows; - public class AssertBroadcastableTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; int[][][] valueArrayI = - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new int[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; long[][][] valueArrayL = - new long[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new long[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; float[][][] valueArrayF = - new float[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new float[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; double[][][] valueArrayD = - new double[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new double[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; private void testValid( @@ -68,10 +68,11 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - List tensors = - testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); - try (Tensor weightsTensor = tensors.get(0); - Tensor valuesTensor = tensors.get(1)) { + try (TensorScope scope = new TensorScope()) { + List tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run(scope); + Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1); Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); testSession @@ -80,7 +81,7 @@ private void testValid( .feed(weightsPlaceholder, weightsTensor) .feed(valuesPlaceholder, valuesTensor) .addTarget(dynamicOp) - .run(); + .run(scope); } } @@ -103,7 +104,7 @@ public void test1x1x1() { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayD); - Operand weights = tf.constant(new double[][][] {{{5}}}); + Operand weights = tf.constant(new double[][][]{{{5}}}); testValid(testSession, tf, weights, values, TFloat64.class); } } @@ -114,7 +115,7 @@ public void test1x1xN() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayL); - Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); + Operand weights = tf.constant(new long[][][]{{{5, 7, 11, 3}}}); testValid(testSession, tf, weights, values, TInt64.class); } } @@ -125,7 +126,7 @@ public void test1xNx1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); + Operand weights = tf.constant(new int[][][]{{{5}, {11}}}); testValid(testSession, tf, weights, values, TInt32.class); } } @@ -137,7 +138,7 @@ public void test1xNxN() { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); + Operand weights = tf.constant(new int[][][]{{{5, 7, 11, 3}, {2, 13, 7, 5}}}); testValid(testSession, tf, weights, values, TInt32.class); } } @@ -149,7 +150,7 @@ public void testNx1x1() { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); + Operand weights = tf.constant(new int[][][]{{{5}}, {{7}}, {{11}}}); testValid(testSession, tf, weights, values, TInt32.class); } } @@ -162,7 +163,7 @@ public void testNx1xN() { Operand values = tf.constant(valueArrayI); Operand weights = - tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); + tf.constant(new int[][][]{{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); testValid(testSession, tf, weights, values, TInt32.class); } } @@ -176,10 +177,10 @@ public void testNxNxN() { Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( - new int[][][] { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} + new int[][][]{ + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} }); testValid(testSession, tf, weights, values, TInt32.class); } @@ -199,7 +200,7 @@ public void testInvalid1x1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5}}); + Operand weights = tf.constant(new int[][]{{5}}); testValid(testSession, tf, weights, values, TInt32.class); } }); @@ -213,7 +214,7 @@ public void testInvalidPrefixMatch() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); + Operand weights = tf.constant(new int[][]{{5, 7}, {11, 3}, {2, 12}}); testValid(testSession, tf, weights, values, TInt32.class); } }); @@ -227,7 +228,7 @@ public void testInvalidSuffixMatch() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); + Operand weights = tf.constant(new int[][]{{5, 7, 11, 3}, {2, 12, 7, 5}}); testValid(testSession, tf, weights, values, TInt32.class); } }); @@ -241,7 +242,7 @@ public void testInvalidOnesExtraDim() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][][] {{{{5}}}}); + Operand weights = tf.constant(new int[][][][]{{{{5}}}}); testValid(testSession, tf, weights, values, TInt32.class); } }); @@ -258,10 +259,10 @@ public void testInvalidPrefixMatchExtraDim() { Operand weights = tf.constant( - new int[][][][] { - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, - {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} + new int[][][][]{ + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, + {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} }); testValid(testSession, tf, weights, values, TInt32.class); } @@ -278,12 +279,12 @@ public void testInvalidSuffixMatchExtraDim() { Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( - new int[][][][] { - { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} - } + new int[][][][]{ + { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + } }); testValid(testSession, tf, weights, values, TInt32.class); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java index 3322a81fe5b..3fde6d1b204 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java @@ -14,9 +14,15 @@ =======================================================================*/ package org.tensorflow.framework.metrics.impl; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.Tensor; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; @@ -25,38 +31,33 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import java.util.List; -import java.util.concurrent.atomic.AtomicInteger; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - public class BroadcastWeightsTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; int[][][] valueArrayI = - new int[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new int[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; long[][][] valueArrayL = - new long[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new long[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; float[][][] valueArrayF = - new float[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new float[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; double[][][] valueArrayD = - new double[][][] { - {{1, 2, 3, 4}, {5, 6, 7, 8}}, - {{9, 10, 11, 12}, {13, 14, 15, 16}}, - {{17, 18, 19, 20}, {21, 22, 23, 24}} + new double[][][]{ + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} }; private void testValid( @@ -78,10 +79,10 @@ private void testValid( Operand weightsPlaceholder = tf.placeholder(type); Operand valuesPlaceholder = tf.placeholder(type); - List tensors = - testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); - try (Tensor weightsTensor = tensors.get(0); - Tensor valuesTensor = tensors.get(1)) { + try (TensorScope scope = new TensorScope()) { + List tensors = testSession.getGraphSession().runner().fetch(weights).fetch(values).run(scope); + Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1); Operand dynamicOp = MetricsHelper.broadcastWeights(tf, weightsPlaceholder, valuesPlaceholder); @@ -93,7 +94,7 @@ private void testValid( .feed(weightsPlaceholder, weightsTensor) .feed(valuesPlaceholder, valuesTensor) .fetch(dynamicOp) - .run(); + .run(scope); if (expected != null) { if (type.equals(TInt32.class)) { @@ -140,8 +141,8 @@ public void testValidScalar() { Operand values = tf.constant(valueArrayF); Operand weights = tf.constant(5f); Float[] expected = { - 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, - 5f + 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, + 5f }; testValid(testSession, tf, weights, values, expected, TFloat32.class); } @@ -153,10 +154,10 @@ public void test1x1x1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayD); - Operand weights = tf.constant(new double[][][] {{{5}}}); + Operand weights = tf.constant(new double[][][]{{{5}}}); Double[] expected = { - 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., - 5. + 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., + 5. }; testValid(testSession, tf, weights, values, expected, TFloat64.class); @@ -169,10 +170,10 @@ public void test1x1xN() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayL); - Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); + Operand weights = tf.constant(new long[][][]{{{5, 7, 11, 3}}}); Long[] expected = { - 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, - 11L, 3L, + 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, + 11L, 3L, }; testValid(testSession, tf, weights, values, expected, TInt64.class); } @@ -184,9 +185,9 @@ public void test1xNx1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); + Operand weights = tf.constant(new int[][][]{{{5}, {11}}}); Integer[] expected = { - 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11 + 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11 }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -198,9 +199,9 @@ public void test1xNxN() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); + Operand weights = tf.constant(new int[][][]{{{5, 7, 11, 3}, {2, 13, 7, 5}}}); Integer[] expected = { - 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, + 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -212,9 +213,9 @@ public void testNx1x1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); + Operand weights = tf.constant(new int[][][]{{{5}}, {{7}}, {{11}}}); Integer[] expected = { - 5, 5, 5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 11, 11, 11, 11, 11, 11, 11, 11 + 5, 5, 5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 11, 11, 11, 11, 11, 11, 11, 11 }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -227,9 +228,9 @@ public void testNx1xN() { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); Operand weights = - tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); + tf.constant(new int[][][]{{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); Integer[] expected = { - 5, 7, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3 + 5, 7, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3 }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -244,13 +245,13 @@ public void testNxNxN() { Operand weights = tf.constant( - new int[][][] { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} + new int[][][]{ + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} }); Integer[] expected = { - 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5 + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5 }; testValid(testSession, tf, weights, values, expected, TInt32.class); } @@ -270,7 +271,7 @@ public void testInvalid1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[] {5}); + Operand weights = tf.constant(new int[]{5}); testValid(testSession, tf, weights, values, null, TInt32.class); } @@ -286,7 +287,7 @@ public void testInvalid1x1() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5}}); + Operand weights = tf.constant(new int[][]{{5}}); testValid(testSession, tf, weights, values, null, TInt32.class); } @@ -301,7 +302,7 @@ public void testInvalidPrefixMatch() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); + Operand weights = tf.constant(new int[][]{{5, 7}, {11, 3}, {2, 12}}); testValid(testSession, tf, weights, values, null, TInt32.class); } }); @@ -315,7 +316,7 @@ public void testInvalidSuffixMatch() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); + Operand weights = tf.constant(new int[][]{{5, 7, 11, 3}, {2, 12, 7, 5}}); testValid(testSession, tf, weights, values, null, TInt32.class); } }); @@ -329,7 +330,7 @@ public void testInvalidOnesExtraDim() { try (TestSession testSession = TestSession.createTestSession(tfMode)) { Ops tf = testSession.getTF(); Operand values = tf.constant(valueArrayI); - Operand weights = tf.constant(new int[][][][] {{{{5}}}}); + Operand weights = tf.constant(new int[][][][]{{{{5}}}}); testValid(testSession, tf, weights, values, null, TInt32.class); } }); @@ -346,10 +347,10 @@ public void testInvalidPrefixMatchExtraDim() { Operand weights = tf.constant( - new int[][][][] { - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, - {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, - {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} + new int[][][][]{ + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, + {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} }); testValid(testSession, tf, weights, values, null, TInt32.class); } @@ -366,12 +367,12 @@ public void testInvalidSuffixMatchExtraDim() { Operand values = tf.constant(valueArrayI); Operand weights = tf.constant( - new int[][][][] { - { - {{5, 7, 11, 3}, {2, 12, 7, 5}}, - {{2, 17, 11, 3}, {2, 17, 11, 3}}, - {{5, 7, 11, 3}, {2, 12, 7, 5}} - } + new int[][][][]{ + { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + } }); testValid(testSession, tf, weights, values, null, TInt32.class); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java index 49154882a0f..c3121325a48 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamTest.java @@ -14,8 +14,19 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; +import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -29,13 +40,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.Adam.FIRST_MOMENT; -import static org.tensorflow.framework.optimizers.Adam.SECOND_MOMENT; - /** Test cases for Adam Optimizer */ public class AdamTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -138,24 +142,26 @@ public void testBasic() { final float[] powers = { (float) Math.pow(beta1, step + 1), (float) Math.pow(beta2, step + 1) }; - - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("beta1_power") - .run() - .get(0)) { - result.scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); - } - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("beta2_power") - .run() - .get(0)) { - result.scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + try (TensorScope scope = new TensorScope()) { + + try (TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run(scope) + .get(0)) { + result.scalars().forEach(f -> assertEquals(powers[0], f.getFloat(), epsilon1)); + } + try (TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("beta2_power") + .run(scope) + .get(0)) { + result.scalars().forEach(f -> assertEquals(powers[1], f.getFloat(), epsilon1)); + } } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java index 60c17674dfe..920b3a38212 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/AdamaxTest.java @@ -14,8 +14,22 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.optimizers.Adamax.BETA_ONE_DEFAULT; +import static org.tensorflow.framework.optimizers.Adamax.BETA_TWO_DEFAULT; +import static org.tensorflow.framework.optimizers.Adamax.FIRST_MOMENT; +import static org.tensorflow.framework.optimizers.Adamax.GradAndVar; +import static org.tensorflow.framework.optimizers.Adamax.SECOND_MOMENT; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -29,35 +43,39 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.tensorflow.framework.optimizers.Adamax.*; - -/** Test cases for Adamax Optimizer */ +/** + * Test cases for Adamax Optimizer + */ public class AdamaxTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private static final int VAR = 0; private static final int M = 1; private static final int V = 2; - public AdamaxTest() {} + public AdamaxTest() { + } @BeforeAll - public static void setUpClass() {} + public static void setUpClass() { + } @AfterAll - public static void tearDownClass() {} + public static void tearDownClass() { + } @BeforeEach - public void setUp() {} + public void setUp() { + } @AfterEach - public void tearDown() {} + public void tearDown() { + } - /** Test of getOptimizerName method, of class Adamax. */ + /** + * Test of getOptimizerName method, of class Adamax. + */ @Test public void testGetOptimizerName() { try (TestSession session = TestSession.createTestSession(tfMode)) { @@ -69,7 +87,9 @@ public void testGetOptimizerName() { } } - /** Test of applyDense method, of class Adamax. */ + /** + * Test of applyDense method, of class Adamax. + */ @Test public void testBasic() { @@ -148,13 +168,14 @@ public void testBasic() { // Test powers final float beta1Power = (float) Math.pow(BETA_ONE_DEFAULT, step + 1); - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("beta1_power") - .run() - .get(0)) { + try (TensorScope scope = new TensorScope()) { + TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("beta1_power") + .run(scope) + .get(0); result.scalars().forEach(f -> assertEquals(beta1Power, f.getFloat(), epsilon1)); } session.run(update); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java index 849f2fbfec1..7d96bc37446 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/NadamTest.java @@ -14,8 +14,17 @@ =======================================================================*/ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; +import org.tensorflow.TensorScope; import org.tensorflow.framework.utils.ND; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.FloatNdArray; @@ -29,13 +38,11 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; -import java.util.ArrayList; -import java.util.List; - -import static org.junit.jupiter.api.Assertions.assertEquals; - -/** Test cases for Nadam Optimizer */ +/** + * Test cases for Nadam Optimizer + */ public class NadamTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; private static final int VAR = 0; @@ -44,21 +51,28 @@ public class NadamTest { float momentum = 1; - public NadamTest() {} + public NadamTest() { + } @BeforeAll - public static void setUpClass() {} + public static void setUpClass() { + } @AfterAll - public static void tearDownClass() {} + public static void tearDownClass() { + } @BeforeEach - public void setUp() {} + public void setUp() { + } @AfterEach - public void tearDown() {} + public void tearDown() { + } - /** Test of getOptimizerName method, of class Nadam. */ + /** + * Test of getOptimizerName method, of class Nadam. + */ @Test public void testGetOptimizerName() { try (TestSession session = TestSession.createTestSession(tfMode)) { @@ -70,7 +84,9 @@ public void testGetOptimizerName() { } } - /** Test of applyDense method, of class Nadam. */ + /** + * Test of applyDense method, of class Nadam. + */ @Test public void testBasic() { @@ -146,13 +162,15 @@ public void testBasic() { session.evaluate(var0Init, var0); session.evaluate(var1Init, var1); - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0)) { + try (TensorScope scope = new TensorScope()) { + + TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("momentum") + .run(scope) + .get(0); result.scalars().forEach(f -> assertEquals(1F, f.getFloat(), epsilon1)); } momentum = 1F; @@ -165,13 +183,14 @@ public void testBasic() { Nadam.BETA_ONE_DEFAULT * (1F - 0.5F * (float) Math.pow(0.96F, (0.004F * (step + 1)))); momentum = momentum * mut; - try (TFloat32 result = - (TFloat32)session - .getGraphSession() - .runner() - .fetch("momentum") - .run() - .get(0)) { + try (TensorScope scope = new TensorScope()) { + TFloat32 result = + (TFloat32) session + .getGraphSession() + .runner() + .fetch("momentum") + .run(scope) + .get(0); result.scalars().forEach(f -> assertEquals(momentum, f.getFloat(), epsilon1)); } mcache = ND.mul(mcache, momentum); @@ -198,6 +217,7 @@ public void testBasic() { session.evaluate(var1Np, var1); } } + } private FloatNdArray[] nadamUpdateNdArray( diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java index 7884308c9fb..0db5784082b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java @@ -14,35 +14,53 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Predicate; +import org.tensorflow.EagerSession; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.IntNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Predicate; - -import static org.junit.jupiter.api.Assertions.*; - -/** Eager Mode Test Session */ +/** + * Eager Mode Test Session + */ public class EagerTestSession extends TestSession { private final EagerSession session; private final Ops tf; - /** Create an Eager mode test session. */ + /** + * Create an Eager mode test session. + */ public EagerTestSession() { this.session = EagerSession.create(); this.tf = Ops.create(session).withName("test"); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public Ops getTF() { return tf; @@ -57,702 +75,746 @@ public EagerSession getSession() { return session; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void close() { session.close(); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public boolean isEager() { return true; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public Session getGraphSession() { return null; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public EagerSession getEagerSession() { return this.session; } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(double expected, Operand input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } else if (inputType == TInt64.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } else if (inputType == TUint8.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + try (TensorScope scope = new TensorScope()) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); + } else if (inputType == TFloat64.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); + } else if (inputType == TInt32.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals((int) expected, f.getInt())); + } else if (inputType == TInt64.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals((long) expected, f.getLong())); + } else if (inputType == TUint8.class) { + Operand o = (Operand) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + } + index.set(0); + o.asTensor(scope).scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Number[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + Class inputType = input.type(); + if (inputType == TFloat32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> + assertEquals( + expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); + } else if (inputType == TFloat64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> + assertEquals( + expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); + } else if (inputType == TInt32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); + } else if (inputType == TInt64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); + } else if (inputType == TUint8.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%x). %d\n", index.getAndIncrement(), f.getByte())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%x). %d\n", index.getAndIncrement(), f.getByte())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()].byteValue(), f.getByte())); } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].byteValue(), f.getByte())); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(FloatNdArray expected, Output input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicLong index = new AtomicLong(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + try (TensorScope scope = new TensorScope()) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + Output o = (Output) input; + AtomicLong index = new AtomicLong(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); + } else if (inputType == TFloat64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - for (IntNdArray f : o.asTensor().scalars()) { - assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); - } - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> + assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); + } else if (inputType == TInt32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + index.set(0); + for (IntNdArray f : o.asTensor(scope).scalars()) { + assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); + } + } else if (inputType == TInt64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() + .forEach( + f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); + } else if (inputType == TUint8.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + } + index.set(0); + o.asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + .forEach( + f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluateString(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %s\n", predicate.test(input.asTensor().getObject()), input.asTensor().getObject()); - } else { - input - .asTensor() - .scalars() - .forEachIndexed( - (idx, s) -> - System.out.printf( - "%d). %b <==> %s\n", - index.getAndIncrement(), predicate.test(s.getObject()), s.getObject())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(input.asTensor().getObject())); - } else { - input.asTensor().scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); - } - } - - /** {@inheritDoc} */ - @Override - public void evaluate(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - Output o = (Output) input; + try (TensorScope scope = new TensorScope()) { + AtomicInteger index = new AtomicInteger(); + boolean isScalar = input.shape().equals(Shape.scalar()); if (debug) { if (isScalar) { System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); + "0). %b <==> %s\n", predicate.test(input.asTensor(scope).getObject()), input.asTensor(scope).getObject()); } else { - o.asTensor() + input + .asTensor(scope) .scalars() .forEachIndexed( - (idx, f) -> + (idx, s) -> System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + "%d). %b <==> %s\n", + index.getAndIncrement(), predicate.test(s.getObject()), s.getObject())); } } index.set(0); if (isScalar) { - assertTrue(predicate.test(o.asTensor().getFloat())); + assertTrue(predicate.test(input.asTensor(scope).getObject())); } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); + input.asTensor(scope).scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); } - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - if (debug) { + } + } + + /** + * {@inheritDoc} + */ + @Override + public void evaluate(Output input, Predicate predicate) { + try (TensorScope scope = new TensorScope()) { + AtomicInteger index = new AtomicInteger(); + Class inputType = input.type(); + boolean isScalar = input.shape().equals(Shape.scalar()); + if (inputType == TFloat32.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", predicate.test(o.asTensor(scope).getFloat()), o.asTensor(scope).getFloat()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getDouble()), o.asTensor().getDouble()); + assertTrue(predicate.test(o.asTensor(scope).getFloat())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getFloat()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getDouble())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getDouble()))); - } - } else if (inputType == TFloat16.class) { - Output o = (Output) input; - if (debug) { + } else if (inputType == TFloat64.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", predicate.test(o.asTensor(scope).getDouble()), o.asTensor(scope).getDouble()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); + assertTrue(predicate.test(o.asTensor(scope).getDouble())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getDouble()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getFloat())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); - } - } else if (inputType == TInt32.class) { - Output o = (Output) input; - if (debug) { + } else if (inputType == TFloat16.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", predicate.test(o.asTensor(scope).getFloat()), o.asTensor(scope).getFloat()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(o.asTensor().getInt()), o.asTensor().getInt()); + assertTrue(predicate.test(o.asTensor(scope).getFloat())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getFloat()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getInt())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getInt()))); - } - } else if (inputType == TInt64.class) { - Output o = (Output) input; - if (debug) { + } else if (inputType == TInt32.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", predicate.test(o.asTensor(scope).getInt()), o.asTensor(scope).getInt()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(o.asTensor().getLong()), o.asTensor().getLong()); + assertTrue(predicate.test(o.asTensor(scope).getInt())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getInt()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getLong())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getLong()))); - } - } else if (inputType == TUint8.class) { - Output o = (Output) input; - if (debug) { + } else if (inputType == TInt64.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", predicate.test(o.asTensor(scope).getLong()), o.asTensor(scope).getLong()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); + } + } + index.set(0); if (isScalar) { - System.out.printf( - "0). %b <==> %x\n", predicate.test(o.asTensor().getByte()), o.asTensor().getByte()); + assertTrue(predicate.test(o.asTensor(scope).getLong())); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %x\n", - index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getLong()))); + } + } else if (inputType == TUint8.class) { + Output o = (Output) input; + if (debug) { + if (isScalar) { + System.out.printf( + "0). %b <==> %x\n", predicate.test(o.asTensor(scope).getByte()), o.asTensor(scope).getByte()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %x\n", + index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); + } + } + index.set(0); + if (isScalar) { + assertTrue(predicate.test(o.asTensor(scope).getByte())); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor(scope).getByte()))); } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getByte())); } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getByte()))); + fail("Unexpected Class: " + inputType); } - } else { - fail("Unexpected Class: " + inputType); } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(String[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + AtomicInteger index = new AtomicInteger(); + if (debug) { + input + .asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + } + index.set(0); input - .asTensor() + .asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } - index.set(0); - input - .asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Boolean[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected size (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected size (%d) != to input length (%d)", expected.length, size)); + AtomicInteger index = new AtomicInteger(); + if (debug) { + input + .asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + } + index.set(0); input - .asTensor() + .asTensor(scope) .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean())); } - index.set(0); - input - .asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean())); } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void evaluate(Output expected, Output input) { - assert input.shape().equals(expected.shape()) - : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); - Class inputType = input.asOutput().type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + assert input.shape().equals(expected.shape()) + : String.format( + "expected shape (%s) != to input shape (%s)", + expected.shape().toString(), input.shape().toString()); + Class inputType = input.asOutput().type(); + boolean isScalar = input.shape().equals(Shape.scalar()); + if (inputType == TFloat32.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %f <==> %f\n", x.asTensor(scope).getFloat(), o.asTensor(scope).getFloat()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), x.asTensor(scope).getFloat(idx), f.getFloat())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.asTensor().getFloat(), o.asTensor().getFloat()); + assertEquals(x.asTensor(scope).getFloat(), o.asTensor(scope).getFloat(), epsilon); } else { - o.asTensor() + o.asTensor(scope) .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), x.asTensor().getFloat(idx), f.getFloat())); + (idx, f) -> assertEquals(x.asTensor(scope).getFloat(idx), f.getFloat(), epsilon)); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getFloat(), o.asTensor().getFloat(), epsilon); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(x.asTensor().getFloat(idx), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TFloat64.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %f <==> %f\n", x.asTensor(scope).getDouble(), o.asTensor(scope).getDouble()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), x.asTensor(scope).getDouble(idx), f.getDouble())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.asTensor().getDouble(), o.asTensor().getDouble()); + assertEquals(x.asTensor(scope).getDouble(), o.asTensor(scope).getDouble(), epsilon); } else { - o.asTensor() + o.asTensor(scope) .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), x.asTensor().getDouble(idx), f.getDouble())); + (idx, f) -> assertEquals(x.asTensor(scope).getDouble(idx), f.getDouble(), epsilon)); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getDouble(), o.asTensor().getDouble(), epsilon); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(x.asTensor().getDouble(idx), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt32.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %d <==> %d\n", x.asTensor(scope).getInt(), o.asTensor(scope).getInt()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), x.asTensor(scope).getInt(idx), f.getInt())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.asTensor().getInt(), o.asTensor().getInt()); + assertEquals(x.asTensor(scope).getInt(), o.asTensor(scope).getInt()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), x.asTensor().getInt(idx), f.getInt())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getInt(idx), f.getInt())); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getInt(), o.asTensor().getInt()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getInt(idx), f.getInt())); - } - } else if (inputType == TInt64.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt64.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %d <==> %d\n", x.asTensor(scope).getLong(), o.asTensor(scope).getLong()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), x.asTensor(scope).getLong(idx), f.getLong())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.asTensor().getLong(), o.asTensor().getLong()); + assertEquals(x.asTensor(scope).getLong(), o.asTensor(scope).getLong()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), x.asTensor().getLong(idx), f.getLong())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getLong(idx), f.getLong())); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getLong(), o.asTensor().getLong()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getLong(idx), f.getLong())); - } - } else if (inputType == TUint8.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TUint8.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %x <==> %x\n", x.asTensor(scope).getByte(), o.asTensor(scope).getByte()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %x <==> %x\n", + index.getAndIncrement(), x.asTensor(scope).getByte(idx), f.getByte())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %x <==> %x\n", x.asTensor().getByte(), o.asTensor().getByte()); + assertEquals(x.asTensor(scope).getByte(), o.asTensor(scope).getByte()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %x <==> %x\n", - index.getAndIncrement(), x.asTensor().getByte(idx), f.getByte())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getByte(idx), f.getByte())); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getByte(), o.asTensor().getByte()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getByte(idx), f.getByte())); - } - } else if (inputType == TString.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TString.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %s <==> %s\n", x.asTensor(scope).getObject(), o.asTensor(scope).getObject()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %s <==> %s\n", + index.getAndIncrement(), x.asTensor(scope).getObject(idx), f.getObject())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %s <==> %s\n", x.asTensor().getObject(), o.asTensor().getObject()); + assertEquals(x.asTensor(scope).getObject(), o.asTensor(scope).getObject()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %s <==> %s\n", - index.getAndIncrement(), x.asTensor().getObject(idx), f.getObject())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getObject(idx), f.getObject())); } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getObject(), o.asTensor().getObject()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getObject(idx), f.getObject())); - } - } else if (inputType == TBool.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TBool.class) { + Output x = (Output) expected; + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + if (debug) { + if (isScalar) { + System.out.printf("0). %b <==> %b\n", x.asTensor(scope).getBoolean(), o.asTensor(scope).getBoolean()); + } else { + o.asTensor(scope) + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %b\n", + index.getAndIncrement(), x.asTensor(scope).getBoolean(idx), f.getBoolean())); + } + } + index.set(0); if (isScalar) { - System.out.printf("0). %b <==> %b\n", x.asTensor().getBoolean(), o.asTensor().getBoolean()); + assertEquals(x.asTensor(scope).getBoolean(), o.asTensor(scope).getBoolean()); } else { - o.asTensor() + o.asTensor(scope) .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %b\n", - index.getAndIncrement(), x.asTensor().getBoolean(idx), f.getBoolean())); + .forEachIndexed((idx, f) -> assertEquals(x.asTensor(scope).getBoolean(idx), f.getBoolean())); } } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getBoolean(), o.asTensor().getBoolean()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getBoolean(idx), f.getBoolean())); - } } } - /** {@inheritDoc} */ + /** + * {@inheritDoc} + */ @Override public void print(PrintWriter writer, Output input) { - Class inputType = input.asOutput().type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } else if (inputType == TString.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } else if (inputType == TBool.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); - } else { - writer.println("Unexpected Class: " + inputType); + try (TensorScope scope = new TensorScope()) { + Class inputType = input.asOutput().type(); + if (inputType == TFloat32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } else if (inputType == TFloat64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } else if (inputType == TInt32.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } else if (inputType == TInt64.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } else if (inputType == TUint8.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + } else if (inputType == TString.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + } else if (inputType == TBool.class) { + Output o = (Output) input; + AtomicInteger index = new AtomicInteger(); + o.asTensor(scope) + .scalars() + .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + } else { + writer.println("Unexpected Class: " + inputType); + } + writer.flush(); } - writer.flush(); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 43c0642939e..70e2b92ba67 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -14,22 +14,35 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import java.io.PrintWriter; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Predicate; +import org.tensorflow.EagerSession; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Output; +import org.tensorflow.Session; +import org.tensorflow.TensorScope; import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TString; +import org.tensorflow.types.TUint8; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; -import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; -import java.util.function.Predicate; - -import static org.junit.jupiter.api.Assertions.*; - /** * Graph Mode Test Session */ @@ -110,7 +123,9 @@ public EagerSession getEagerSession() { */ @Override public void initialize() { - graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run()); + try (TensorScope scope = new TensorScope()) { + graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run(scope)); + } } /** @@ -126,84 +141,86 @@ public void run(Op op) { */ @Override public void evaluate(double expected, Operand input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TFloat64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + } + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + } + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + } + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TUint8.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + } + } + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -212,108 +229,108 @@ public void evaluate(double expected, Operand input) { */ @Override public void evaluate(Number[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - if (size != Shape.UNKNOWN_SIZE) { - assertEquals( + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + if (size != Shape.UNKNOWN_SIZE) {assertEquals( expected.length, - size, - () -> - String.format("expected length (%d) != to input length (%d)", expected.length, size)); + size,() -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); } - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + .forEach( + f -> + assertEquals( + expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TFloat64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + } + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + .forEach( + f -> + assertEquals( + expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + } + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + } + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TUint8.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + } + } + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -322,103 +339,105 @@ public void evaluate(Number[] expected, Output input) { */ @Override public void evaluate(FloatNdArray expected, Output input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicLong index = new AtomicLong(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + Class inputType = input.type(); + if (inputType == TFloat32.class) { + AtomicLong index = new AtomicLong(); + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + .forEach( + f -> + assertEquals( + expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TFloat64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } + } + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + .forEach( + f -> + assertEquals( + expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt32.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } + } + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + .forEach( + f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TInt64.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } + } + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + .forEach( + f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { + } else if (inputType == TUint8.class) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + } + } + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); + .forEach( + f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -427,30 +446,30 @@ public void evaluate(FloatNdArray expected, Output input) { */ @Override public void evaluate(String[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - if (size != Shape.UNKNOWN_SIZE) { - assertEquals( + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + if (size != Shape.UNKNOWN_SIZE) {assertEquals( expected.length, - size, - () -> - String.format("expected length (%d) != to input length (%d)", expected.length, size)); + size,() -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); } - AtomicInteger index = new AtomicInteger(); - if (debug) { + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TString result = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + } + } + index.set(0); try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); - } } /** @@ -458,27 +477,29 @@ public void evaluate(String[] expected, Output input) { */ @Override public void evaluate(Boolean[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TBool result = + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + result + .scalars() + .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getObject())); + } + } + index.set(0); try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { result .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getObject())); + .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); } } - index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); - } } /** @@ -486,320 +507,322 @@ public void evaluate(Boolean[] expected, Output input) { */ @Override public void evaluate(Output expected, Output input) { - assert input.shape().equals(expected.shape()) - : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - if (!inputType.equals(expected.type())) { - throw new IllegalArgumentException( - String.format( - "Both data type must be equal, inout = %s, expected = %s", - inputType, expected.dataType())); - } - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - final Output finalExpected = (Output) expected; - if (debug) { + try (TensorScope scope = new TensorScope()) { + assert input.shape().equals(expected.shape()) + : String.format( + "expected shape (%s) != to input shape (%s)", + expected.shape().toString(), input.shape().toString()); + AtomicInteger index = new AtomicInteger(); + Class inputType = input.type(); + if (!inputType.equals(expected.type())) { + throw new IllegalArgumentException( + String.format( + "Both data type must be equal, inout = %s, expected = %s", + inputType, expected.dataType())); + } + boolean isScalar = input.shape().equals(Shape.scalar()); + if (inputType == TFloat32.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TFloat32 expectedResult = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getFloat(idx), + f.getFloat())); + } + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0); TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getFloat(idx), - f.getFloat())); + assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); + } else if (inputType == TFloat64.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TFloat64 expectedResult = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getDouble(idx), + f.getDouble())); + } + } } - } - } else if (inputType == TFloat64.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0); TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); + assertEquals(expectedResult.getDouble(), result.getDouble(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getDouble(idx), - f.getDouble())); + assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); } } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getDouble(), result.getDouble(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); + } else if (inputType == TFloat16.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TFloat16 result = + (TFloat16) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TFloat16 expectedResult = + (TFloat16) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %f <==> %f\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getFloat(idx), + f.getFloat())); + } + } } - } - } else if (inputType == TFloat16.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); + (TFloat16) this.getGraphSession().runner().fetch(input).run(scope).get(0); TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat16) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); + assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getFloat(idx), - f.getFloat())); + assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); } } - } - index.set(0); - try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); + } else if (inputType == TInt32.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TInt32 expectedResult = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), finalExpected.asTensor(scope).getInt(idx), f.getInt())); + } + } } - } - } else if (inputType == TInt32.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0); TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); + assertEquals(expectedResult.getInt(), result.getInt(), epsilon); } else { result .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), finalExpected.asTensor().getInt(idx), f.getInt())); + (idx, f) -> assertEquals(expectedResult.getInt(idx), f.getInt(), epsilon)); } } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); - TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getInt(), result.getInt(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getInt(idx), f.getInt(), epsilon)); + } else if (inputType == TInt64.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TInt64 expectedResult = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getLong(idx), + f.getLong())); + } + } } - } - } else if (inputType == TInt64.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0); TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); + assertEquals(expectedResult.getLong(), result.getLong(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), - finalExpected.asTensor().getLong(idx), - f.getLong())); + assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); } } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); - TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getLong(), result.getLong(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); + } else if (inputType == TUint8.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TUint8 expectedResult = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %d <==> %d\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getByte(idx), + f.getByte())); + } + } } - } - } else if (inputType == TUint8.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0); TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); + assertEquals(expectedResult.getByte(), result.getByte(), epsilon); } else { result .scalars() .forEachIndexed( (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), - finalExpected.asTensor().getByte(idx), - f.getByte())); + assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); } } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); - TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getByte(), result.getByte(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); + } else if (inputType == TBool.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TBool result = + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TBool expectedResult = + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %b\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getBoolean(idx), + f.getBoolean())); + } + } } - } - } else if (inputType == TBool.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0); TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); + assertEquals(expectedResult.getBoolean(), result.getBoolean()); } else { result .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %b\n", - index.getAndIncrement(), - finalExpected.asTensor().getBoolean(idx), - f.getBoolean())); + (idx, f) -> assertEquals(expectedResult.getBoolean(idx), f.getBoolean())); } } - } - index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); - TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getBoolean(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getBoolean(idx), f.getBoolean())); + } else if (inputType == TString.class) { + final Output finalExpected = (Output) expected; + if (debug) { + try (TString result = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0); + TString expectedResult = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %s <==> %s\n", + index.getAndIncrement(), + finalExpected.asTensor(scope).getObject(idx), + f.getObject())); + } + } } - } - } else if (inputType == TString.class) { - final Output finalExpected = (Output) expected; - if (debug) { + index.set(0); try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0); TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); + assertEquals(expectedResult.getObject(), result.getObject()); } else { result .scalars() .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %s <==> %s\n", - index.getAndIncrement(), - finalExpected.asTensor().getObject(idx), - f.getObject())); + (idx, f) -> assertEquals(expectedResult.getObject(idx), f.getObject())); } } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); - TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getObject(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getObject(idx), f.getObject())); - } - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -808,37 +831,39 @@ public void evaluate(Output expected, Output input) { */ @Override public void evaluateString(Output input, Predicate predicate) { - boolean isScalar = input.shape().equals(Shape.scalar()); - AtomicInteger index = new AtomicInteger(); - if (debug) { + try (TensorScope scope = new TensorScope()) { + boolean isScalar = input.shape().equals(Shape.scalar()); + AtomicInteger index = new AtomicInteger(); + if (debug) { + try (TString result = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %s\n", + predicate.test(result.getObject()), result.getObject()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %s\n", + index.getAndIncrement(), predicate.test(f.getObject()), f.getObject())); + } + } + } + index.set(0); try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %s\n", - predicate.test(result.getObject()), result.getObject()); + assertTrue(predicate.test(result.getObject())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %s\n", - index.getAndIncrement(), predicate.test(f.getObject()), f.getObject())); + .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); } } } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getObject())); - } else { - result - .scalars() - .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); - } - } } /** @@ -846,160 +871,162 @@ public void evaluateString(Output input, Predicate predicate) { */ @Override public void evaluate(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - if (debug) { + try (TensorScope scope = new TensorScope()) { + AtomicInteger index = new AtomicInteger(); + Class inputType = input.type(); + boolean isScalar = input.shape().equals(Shape.scalar()); + if (inputType == TFloat32.class) { + if (debug) { + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", + predicate.test(result.getFloat()), result.getFloat()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + } + } + } + index.set(0); try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getFloat()), result.getFloat()); + assertTrue(predicate.test(result.getFloat())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getFloat()))); } } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getFloat())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getFloat()))); + } else if (inputType == TFloat64.class) { + if (debug) { + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %f\n", + predicate.test(result.getDouble()), result.getDouble()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %f\n", + index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); + } + } } - } - } else if (inputType == TFloat64.class) { - if (debug) { + index.set(0); try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getDouble()), result.getDouble()); + assertTrue(predicate.test(result.getDouble())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getDouble()))); } } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getDouble())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getDouble()))); + } else if (inputType == TInt32.class) { + if (debug) { + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); + } + } } - } - } else if (inputType == TInt32.class) { - if (debug) { + index.set(0); try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); + assertTrue(predicate.test(result.getInt())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); } } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getInt())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); + } else if (inputType == TInt64.class) { + if (debug) { + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", + predicate.test(result.getLong()), result.getLong()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); + } + } } - } - } else if (inputType == TInt64.class) { - if (debug) { + index.set(0); try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getLong()), result.getLong()); + assertTrue(predicate.test(result.getLong())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); } } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getLong())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); + } else if (inputType == TUint8.class) { + if (debug) { + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + System.out.printf( + "0). %b <==> %d\n", + predicate.test(result.getByte()), result.getByte()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> + System.out.printf( + "%d). %b <==> %d\n", + index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); + } + } } - } - } else if (inputType == TUint8.class) { - if (debug) { + index.set(0); try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getByte()), result.getByte()); + assertTrue(predicate.test(result.getByte())); } else { result .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); + .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); } } + } else { + fail("Unexpected type class: " + inputType); } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getByte())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); - } - } - } else { - fail("Unexpected type class: " + inputType); } } @@ -1008,115 +1035,117 @@ public void evaluate(Output input, Predicate predic */ @Override public void print(PrintWriter writer, Output input) { - boolean isScalar = input.shape().size() == 1; + try (TensorScope scope = new TensorScope()) { + boolean isScalar = input.shape().size() == 1; - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf("%d). %f\n", index.getAndIncrement(), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + Class inputType = input.type(); + if (inputType == TFloat32.class) { + AtomicInteger index = new AtomicInteger(); + try (TFloat32 result = + (TFloat32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf("%d). %f\n", index.getAndIncrement(), result.getFloat()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); + } } - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TFloat64.class) { + AtomicInteger index = new AtomicInteger(); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %f\n", index.getAndIncrement(), result.getDouble()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + try (TFloat64 result = + (TFloat64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %f\n", index.getAndIncrement(), result.getDouble()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); + } } - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TInt32.class) { + AtomicInteger index = new AtomicInteger(); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(),result.getInt()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + try (TInt32 result = + (TInt32) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %d\n", index.getAndIncrement(), result.getInt()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); + } } - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TInt64.class) { + AtomicInteger index = new AtomicInteger(); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(), result.getLong()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + try (TInt64 result = + (TInt64) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %d\n", index.getAndIncrement(), result.getLong()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); + } } - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TUint8.class) { + AtomicInteger index = new AtomicInteger(); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %x\n", index.getAndIncrement(), result.getByte()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + try (TUint8 result = + (TUint8) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %x\n", index.getAndIncrement(), result.getByte()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); + } } - } - } else if (inputType == TBool.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TBool.class) { + AtomicInteger index = new AtomicInteger(); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %b\n", index.getAndIncrement(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + try (TBool result = + (TBool) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %b\n", index.getAndIncrement(), result.getBoolean()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); + } } - } - } else if (inputType == TString.class) { - AtomicInteger index = new AtomicInteger(); + } else if (inputType == TString.class) { + AtomicInteger index = new AtomicInteger(); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %s\n", index.getAndIncrement(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + try (TString result = + (TString) this.getGraphSession().runner().fetch(input).run(scope).get(0)) { + if (isScalar) { + writer.printf( + "%d). %s\n", index.getAndIncrement(), result.getObject()); + } else { + result + .scalars() + .forEachIndexed( + (idx, f) -> writer.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); + } } + } else { + writer.println("Unexpected type class: " + inputType); } - } else { - writer.println("Unexpected type class: " + inputType); + writer.flush(); } - writer.flush(); } }