From 28a34dd19fd88c892e7b3ea62489bb062fb9b11b Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Wed, 3 Mar 2021 11:21:50 -0500 Subject: [PATCH 01/13] Clean up generics, remove generics from class and fix call method to have generic. --- .../annotations/org/tensorflow/op/Ops.java | 12 ++-- .../framework/activations/Activation.java | 26 ++++----- .../tensorflow/framework/activations/ELU.java | 12 ++-- .../framework/activations/Exponential.java | 9 ++- .../framework/activations/HardSigmoid.java | 9 ++- .../framework/activations/Linear.java | 8 +-- .../framework/activations/ReLU.java | 6 +- .../framework/activations/SELU.java | 8 +-- .../framework/activations/Sigmoid.java | 9 ++- .../framework/activations/Softmax.java | 9 ++- .../framework/activations/Softplus.java | 6 +- .../framework/activations/Softsign.java | 9 ++- .../framework/activations/Swish.java | 7 +-- .../framework/activations/Tanh.java | 8 +-- .../initializers/BaseInitializer.java | 3 +- .../framework/initializers/Constant.java | 9 ++- .../framework/initializers/Glorot.java | 8 +-- .../tensorflow/framework/initializers/He.java | 12 ++-- .../framework/initializers/Identity.java | 8 +-- .../framework/initializers/Initializer.java | 11 ++-- .../framework/initializers/LeCun.java | 11 ++-- .../framework/initializers/Ones.java | 9 ++- .../framework/initializers/Orthogonal.java | 38 +++++++----- .../framework/initializers/RandomNormal.java | 29 ++++++---- .../framework/initializers/RandomUniform.java | 39 ++++++++----- .../initializers/TruncatedNormal.java | 33 +++++++---- .../initializers/VarianceScaling.java | 38 +++++++----- .../framework/initializers/Zeros.java | 8 +-- .../framework/activations/ELUTest.java | 10 +--- .../activations/ExponentialTest.java | 8 +-- .../activations/HardSigmoidTest.java | 8 +-- .../framework/activations/LinearTest.java | 6 +- .../framework/activations/ReLUTest.java | 20 +++---- .../framework/activations/SELUTest.java | 8 +-- .../framework/activations/SigmoidTest.java | 7 +-- .../framework/activations/SoftmaxTest.java | 21 +++---- .../framework/activations/SoftplusTest.java | 4 +- .../framework/activations/SoftsignTest.java | 4 +- .../framework/activations/SwishTest.java | 8 +-- .../framework/activations/TanhTest.java | 4 +- .../framework/initializers/ConstantTest.java | 16 ++--- .../framework/initializers/GlorotTest.java | 15 +++-- .../framework/initializers/HeTest.java | 15 +++-- .../framework/initializers/IdentityTest.java | 10 +--- .../framework/initializers/LeCunTest.java | 14 ++--- .../framework/initializers/OnesTest.java | 16 ++--- .../initializers/OrthogonalTest.java | 10 +--- .../initializers/RandomNormalTest.java | 9 +-- .../initializers/RandomUniformTest.java | 12 ++-- .../initializers/TruncatedNormalTest.java | 9 +-- .../initializers/VarianceScalingTest.java | 58 ++++++++----------- .../framework/initializers/ZerosTest.java | 16 ++--- 52 files changed, 327 insertions(+), 375 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index ea3ef31313e..acbae4dac6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -354,20 +354,20 @@ public final class Ops { public final SparseOps sparse; - public final TpuOps tpu; - public final BitwiseOps bitwise; + public final TpuOps tpu; + public final MathOps math; public final AudioOps audio; 
public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -385,13 +385,13 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - tpu = new TpuOps(this); bitwise = new BitwiseOps(this); + tpu = new TpuOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java index e1482a51a8a..104e3726ee4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java @@ -20,13 +20,8 @@ /** * Abstract base class for Activations - * - *

Note: The {@link #tf} attribute must be set prior to invoking the call method. See - * {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}. - * - * @param the data type of the activation */ -public abstract class Activation { +public abstract class Activation { /** The TensorFlow Ops */ protected Ops tf; @@ -41,21 +36,21 @@ protected Activation(Ops tf) { } /** - * Sets the TensorFlow Ops + * Gets the TensorFlow Ops * - * @param tf the TensorFlow Ops + * @return the TensorFlow Ops */ - protected void setTF(Ops tf) { - this.tf = tf; + protected Ops getTF() { + return this.tf; } /** - * Gets the TensorFlow Ops + * Sets the TensorFlow Ops * - * @return the TensorFlow Ops + * @param tf the TensorFlow Ops */ - protected Ops getTF() { - return this.tf; + protected void setTF(Ops tf) { + this.tf = tf; } /** @@ -63,6 +58,7 @@ protected Ops getTF() { * * @param input the input tensor * @return The operand for the activation + * @param the data type of the input and result */ - public abstract Operand call(Operand input); + public abstract Operand call(Operand input); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java index 2f2f16f2752..e7ad39fde62 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java @@ -17,7 +17,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Exponential linear unit. @@ -44,11 +44,10 @@ * Operand<TFloat32> result = elu.call(input); * * - * @param the data type of the activation * @see Clevert et al, 2016, Fast and Accurate Deep * Network Learning by Exponential Linear Units (ELUs) */ -public class ELU extends Activation { +public class ELU extends Activation { private static final double ALPHA_DEFAULT = 1.0; @@ -83,11 +82,12 @@ public ELU(Ops tf, double alpha) { * @return The operand for the activation */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { Operand result = tf.nn.elu(input); - if (alpha == 1.0) return result; - else { + if (alpha == 1.0) { + return result; + } else { Class inputType = input.type(); Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType)); Operand cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), inputType)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java index d5fdff36c61..2debca9ddac 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Exponential activation function. @@ -30,10 +30,8 @@ * Operand<TFloat32> result = exp.call(input); * // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f] * - * - * @param the data type of the activation */ -public class Exponential extends Activation { +public class Exponential extends Activation { /** * Creates an Exponential activation. 
@@ -49,9 +47,10 @@ public Exponential(Ops tf) { * * @param input the input tensor * @return an Operand for the exponential activation: exp(x). + * @param the data type of the activation */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.math.exp(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java index 0b7cf573b8e..0e2fd6f9342 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Hard sigmoid activation. @@ -40,10 +40,8 @@ * Operand<TFloat32> result = hardSigmoid.call(input); * // result is [0.f , 0.3f, 0.5f, 0.7f, 1.f] * - * - * @param the data type of the result */ -public class HardSigmoid extends Activation { +public class HardSigmoid extends Activation { /** * Creates Hard sigmoid activation. @@ -59,9 +57,10 @@ public HardSigmoid(Ops tf) { * * @param input the input tensor * @return The operand for the activation + * @param the data type of the result */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { Class inputType = input.type(); Operand point2 = tf.dtypes.cast(tf.constant(0.2), inputType); Operand point5 = tf.dtypes.cast(tf.constant(0.5), inputType); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java index d907397995d..06aba774423 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java @@ -19,9 +19,9 @@ import org.tensorflow.types.family.TNumber; /** - * Linear activation function (pass-through). + * Linear activation function (pass-through). * - *

<p>The linear activation returns its input. It is also known as the Identity activation function.
+ * <p>The linear activation returns its input. It is also known as the Identity activation function.
 *
 * <p>
For example: * @@ -33,7 +33,7 @@ * // result is [-3.0f,-1.0f, 0.0f,1.0f,3.0f] * */ -public class Linear extends Activation { +public class Linear extends Activation { /** * Creates a linear activation. @@ -46,7 +46,7 @@ public Linear(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return input; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java index aef6ebf2992..974d266143f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java @@ -55,10 +55,8 @@ * result = relu.call(input); * // result is [-0.f, -0.f, 0.f, 0.f, 10.f] * - * - * @param the data type of the result */ -public class ReLU extends Activation { +public class ReLU extends Activation { public static final float ALPHA_DEFAULT = 0.0f; public static final float MAX_VALUE_DEFAULT = Float.NaN; @@ -96,7 +94,7 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { Class inputType = input.type(); boolean clipMax = !Float.isNaN(maxValue); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java index f24731049fb..d3328aeb110 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Scaled Exponential Linear Unit (SELU). @@ -42,10 +42,9 @@ *

Notes: To be used together with the {@link * org.tensorflow.framework.initializers.LeCun} initializer with Normal Distribution. * - * @param the data type of the activation * @see Klambauer et al., 2017 */ -public class SELU extends Activation { +public class SELU extends Activation { /** * Creates a Scaled Exponential Linear Unit (SELU) activation. @@ -61,9 +60,10 @@ public SELU(Ops tf) { * * @param input the input tensor * @return The operand for the activation + * @param the data type of the activation */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.nn.selu(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java index 5d507b38483..c4781ba50d9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Sigmoid activation. sigmoid(x) = 1 / (1 + exp(-x)). @@ -38,10 +38,8 @@ * // result is [2.0611537e-09f, 2.6894143e-01f, * // 5.0000000e-01f,7.3105860e-01f, 1.f] * - * - * @param the data type of the activation */ -public class Sigmoid extends Activation { +public class Sigmoid extends Activation { /** * Creates a Sigmoid activation. @@ -57,9 +55,10 @@ public Sigmoid(Ops tf) { * * @param input the input tensor * @return The operand for the activation + * @param the data type of the activation */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.math.sigmoid(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java index 154e1ecc84a..8051cee9c04 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java @@ -19,7 +19,7 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.core.ReduceMax; import org.tensorflow.op.core.ReduceSum; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Softmax converts a real vector to a vector of categorical probabilities. @@ -35,10 +35,8 @@ *

The softmax of each vector x is computed as: exp(x) / tf.sum(exp(x)).
 *
 * <p>
The input values in are the log-odds of the resulting probability. - * - * @param the data type of the activation */ -public class Softmax extends Activation { +public class Softmax extends Activation { private static final int AXIS_DEFAULT = -1; @@ -70,9 +68,10 @@ public Softmax(Ops tf, int axis) { * * @param input the input tensor * @return The operand for the activation + * @param the data type of the activation */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { Shape shape = input.shape(); int numDimensions = shape.numDimensions(); if (numDimensions == 2) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java index 65a183ea047..762c0ae6bcd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Softplus activation function, softplus(x) = log(exp(x) + 1). @@ -32,7 +32,7 @@ * // 1.3132616e+00f, 2.0000000e+01f] * */ -public class Softplus extends Activation { +public class Softplus extends Activation { /** * Creates a Softplus activation function. @@ -50,7 +50,7 @@ public Softplus(Ops tf) { * @return The operand for the activation */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.math.softplus(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java index 1f691e71862..b6e9f874914 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Softsign activation function, softsign(x) = x / (abs(x) + 1). @@ -30,10 +30,8 @@ * Operand<TFloat32> result = softsign.call(input); * // result is [-0.5f, 0.f, 0.5f] * - * - * @param the data type of the activation */ -public class Softsign extends Activation { +public class Softsign extends Activation { /** * Creates a Softsign activation. @@ -49,9 +47,10 @@ public Softsign(Ops tf) { * * @param input the input tensor * @return The operand for the activation + * @param the data type of the activation */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.nn.softsign(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java index d9f73a422d5..5a6fc4c7765 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Swish activation function. swish(x) = x * sigmoid(x). 
@@ -37,10 +37,9 @@ * * * - * @param the data type of the activation * @see Ramachandran et al., 2017 */ -public class Swish extends Activation { +public class Swish extends Activation { /** * Creates a Swish activation, swish(x) = x * sigmoid(x). @@ -57,7 +56,7 @@ public Swish(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { // TODO Python Keras returns a "grad", which is an optimization not implemented in Java. return tf.math.mul(input, tf.math.sigmoid(input)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java index 4fe02eed048..a485638aa4a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Hyperbolic tangent activation function. @@ -30,10 +30,8 @@ * Operand<TFloat32> result = tanh.call(input); * // result = [-0.9950547f, -0.7615942f, 0.f, 0.7615942f, 0.9950547f] * - * - * @param the data type of the activation */ -public class Tanh extends Activation { +public class Tanh extends Activation { /** * Creates a Hyperbolic tangent activation. @@ -46,7 +44,7 @@ public Tanh(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.math.tanh(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java index 9c1fa9ac287..7efd02a6db1 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/BaseInitializer.java @@ -15,10 +15,9 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TType; /** Abstract base class for all Initializers */ -public abstract class BaseInitializer implements Initializer { +public abstract class BaseInitializer implements Initializer { protected final Ops tf; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java index 4a2df86d74b..5aae5b90e5b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java @@ -32,10 +32,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The Type for the call operation */ -public class Constant extends BaseInitializer { +public class Constant extends BaseInitializer { private final double doubleValue; private final long longValue; @@ -86,9 +84,10 @@ public Constant(Ops tf, boolean value) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); + throw new 
IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); } switch (valueType) { case LONG: diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 894bd073758..5a3c291785f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -16,7 +16,6 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; /** * The Glorot initializer, also called Xavier initializer. @@ -58,16 +57,17 @@ * * *

NOTE:
+ *
 * <p>For a GlorotNormal equivalent initializer, use {@link
 * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
+ *
 * <p>
For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. * - * @param The TType for the call operation * @see VarianceScaling.Distribution * @see Glorot et al., 2010 */ -public class Glorot extends VarianceScaling { +public class Glorot extends VarianceScaling { public static final double SCALE = 1.0; @@ -77,7 +77,7 @@ public class Glorot extends VarianceScaling { * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. * @see VarianceScaling.Distribution */ public Glorot(Ops tf, Distribution distribution, long seed) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index 3a91b72b0d0..ac64e449265 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; /** * He initializer. @@ -53,17 +52,18 @@ * * *

NOTE:
+ *
 * <p>For an HeNormal equivalent initializer, use {@link
 * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
- * <p>For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM}
- * for the distribution parameter.
 *
- * @param <T> The TType for the call operation
+ * <p>
For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} for + * the distribution parameter. + * * @see He * et al., 2015 */ -public class He extends VarianceScaling { +public class He extends VarianceScaling { public static final double SCALE = 2.0; @@ -73,7 +73,7 @@ public class He extends VarianceScaling { * @param tf the TensorFlow Ops * @param distribution The distribution type for the He initializer. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. * @see VarianceScaling.Distribution */ public He(Ops tf, Distribution distribution, long seed) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java index f672c9f1e85..e1f50e145a2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java @@ -19,7 +19,7 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; /** * Initializer that generates the identity matrix. @@ -34,10 +34,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class Identity extends BaseInitializer { +public class Identity extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; private final double gain; @@ -65,7 +63,7 @@ public Identity(Ops tf, double gain) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { Shape shape = ShapeUtils.toShape(tf.scope(), dims); if (shape.numDimensions() != 2) { throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions()); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java index 4beb218783b..032f9c92792 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java @@ -18,12 +18,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; -/** - * An interface for Initializers - * - * @param The data Type for initializer operation - */ -public interface Initializer { +/** An interface for Initializers */ +public interface Initializer { /** * Generates the operation used to perform the initialization. @@ -31,6 +27,7 @@ public interface Initializer { * @param dims the shape dimensions * @param type the type of tensor * @return An operand for the initialization. 
+ * @param <T> The data Type for initializer operation
   */
-  Operand<T> call(Operand<TInt64> dims, Class<T> type);
+  <T extends TType> Operand<T> call(Operand<TInt64> dims, Class<T> type);
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java
index 38e68ef688b..b82f40918c0 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java
@@ -15,7 +15,6 @@
 package org.tensorflow.framework.initializers;

 import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TFloating;

 /**
  * LeCun normal initializer.
 *
 * stddev = sqrt(1 / fanIn) where fanIn is the number of input units in the
 * weight tensor.
 *
- * <p>If the distribution is UNIFORM, itraws samples from a uniform distribution within
+ * <p>
If the distribution is UNIFORM, it draws samples from a uniform distribution within * [-limit, limit], where limit = Math.sqrt(3 / fanIn) (fanIn is * the number of input units in the weight tensor) * @@ -59,14 +58,14 @@ * *

NOTE:
 *
- * <p>For a LeCunNormal equivalent initializer, use {@link VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
- *
+ * <p>For a LeCunNormal equivalent initializer, use {@link
+ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
 *
 * <p>For a LeCunUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM}
 * for the distribution parameter.
 *
* - * @param The TType for the call operation * @see Self-Normalizing * Neural Networks, Klambauer et al., 2017 @@ -74,7 +73,7 @@ * al., 1998 * @see VarianceScaling.Distribution */ -public class LeCun extends VarianceScaling { +public class LeCun extends VarianceScaling { /** * Creates a LeCunNormal Initializer @@ -82,7 +81,7 @@ public class LeCun extends VarianceScaling { * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public LeCun(Ops tf, Distribution distribution, long seed) { super(tf, 1.0, Mode.FAN_IN, distribution, seed); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java index b8eb0c418e9..9094c2add97 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java @@ -32,10 +32,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class Ones extends BaseInitializer { +public class Ones extends BaseInitializer { /** * Creates an Initializer that sets all values to one. @@ -57,9 +55,10 @@ public Ones(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); + throw new IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); } return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), type)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java index a5b466e118e..b6dffd2d768 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java @@ -21,7 +21,10 @@ import org.tensorflow.op.Ops; import org.tensorflow.op.linalg.Qr; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import static org.tensorflow.framework.utils.CastHelper.cast; /** * Initializer that generates an orthogonal matrix. @@ -44,10 +47,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class Orthogonal extends BaseInitializer { +public class Orthogonal extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; @@ -59,7 +60,7 @@ public class Orthogonal extends BaseInitializer { * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. 
*/ public Orthogonal(Ops tf, long seed) { this(tf, GAIN_DEFAULT, seed); @@ -71,7 +72,7 @@ public Orthogonal(Ops tf, long seed) { * @param tf the TensorFlow Ops * @param gain the gain to be applied to the Matrix. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public Orthogonal(Ops tf, double gain, long seed) { super(tf); @@ -81,7 +82,11 @@ public Orthogonal(Ops tf, double gain, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type)) { + throw new IllegalArgumentException("Tensor type must be numeric: " + type.getSimpleName()); + } + Class nType = (Class) type; Shape dimsShape = ShapeUtils.toShape(tf.scope(), dims); if (dimsShape.numDimensions() < 2) { throw new IllegalArgumentException( @@ -94,17 +99,18 @@ public Operand call(Operand dims, Class type) { long numCols = dimsShape.size(i); Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols)); long[] seeds = {seed, 0}; - Operand op = - tf.random.statelessRandomNormal(tf.constant(flatShape), tf.constant(seeds), type); + Operand op = + tf.random.statelessRandomNormal(tf.constant(flatShape), tf.constant(seeds), nType); + Qr.Options qrOptions = Qr.fullMatrices(false); - Qr qrOp = tf.linalg.qr(op, qrOptions); - Output qo = qrOp.q(); - Output ro = qrOp.r(); - Operand diagOp = - tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), type)); - Operand qop = tf.math.mul(qo, tf.math.sign(diagOp)); + Qr qrOp = tf.linalg.qr(op, qrOptions); + Output qo = qrOp.q(); + Output ro = qrOp.r(); + Operand diagOp = + tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), op.type())); + Operand qop = tf.math.mul(qo, tf.math.sign(diagOp)); if (numRows < numCols) qop = tf.linalg.transpose(qop, null); - return tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), type)); + return cast(tf, tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), op.type())), type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java index 38ab194a56b..8f8edcc56ca 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java @@ -17,7 +17,10 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import static org.tensorflow.framework.utils.CastHelper.cast; /** * Initializer that generates tensors with a normal distribution. 
@@ -31,10 +34,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class RandomNormal extends BaseInitializer { +public class RandomNormal extends BaseInitializer { public static final double MEAN_DEFAULT = 0.0; public static final double STDDEV_DEFAULT = 1.0; @@ -49,7 +50,7 @@ public class RandomNormal extends BaseInitializer { * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public RandomNormal(Ops tf, long seed) { this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); @@ -61,7 +62,7 @@ public RandomNormal(Ops tf, long seed) { * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public RandomNormal(Ops tf, double mean, long seed) { this(tf, mean, STDDEV_DEFAULT, seed); @@ -74,7 +75,7 @@ public RandomNormal(Ops tf, double mean, long seed) { * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public RandomNormal(Ops tf, double mean, double stddev, long seed) { super(tf); @@ -85,10 +86,16 @@ public RandomNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type)) { + throw new IllegalArgumentException("Tensor type must be numeric: " + type.getSimpleName()); + } long[] seeds = {seed, 0}; - Operand distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); - Operand op = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), type)); - return tf.math.add(op, tf.dtypes.cast(tf.constant(mean), type)); + @SuppressWarnings("unchecked") + Class nType = (Class) type; + Operand distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), nType); + Operand op = + tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), distOp.type())); + return cast(tf, tf.math.add(op, tf.dtypes.cast(tf.constant(mean), distOp.type())), type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java index 787af15f709..3c7d394d538 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java @@ -20,6 +20,9 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import static org.tensorflow.framework.utils.CastHelper.cast; /** * Initializer that generates tensors with a uniform distribution. 
@@ -33,10 +36,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class RandomUniform extends BaseInitializer { +public class RandomUniform extends BaseInitializer { public static final double MINVAL_DEFAULT = -0.05; public static final double MAXVAL_DEFAULT = 0.05; @@ -46,12 +47,12 @@ public class RandomUniform extends BaseInitializer { private final long seed; /** - * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and - * {@link #MAXVAL_DEFAULT} for the maxval + * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and {@link + * #MAXVAL_DEFAULT} for the maxval * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public RandomUniform(Ops tf, long seed) { this(tf, MINVAL_DEFAULT, MAXVAL_DEFAULT, seed); @@ -64,7 +65,7 @@ public RandomUniform(Ops tf, long seed) { * @param minval Lower bound of the range of random values to generate (inclusive). * @param maxval Upper bound of the range of random values to generate (exclusive). * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public RandomUniform(Ops tf, double minval, double maxval, long seed) { super(tf); @@ -75,28 +76,34 @@ public RandomUniform(Ops tf, double minval, double maxval, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - Operand distOp; + public Operand call(Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type)) { + throw new IllegalArgumentException("Tensor type must be numeric: " + type.getSimpleName()); + } + @SuppressWarnings("unchecked") + Class nType = (Class) type; + Operand distOp; if (TIntegral.class.isAssignableFrom(type)) { RandomUniformInt.Options options = RandomUniformInt.seed(this.seed); distOp = tf.random.randomUniformInt( dims, - tf.dtypes.cast(tf.constant(this.minval), type), - tf.dtypes.cast(tf.constant(this.maxval), type), + cast(tf, tf.constant(this.minval), nType), + cast(tf, tf.constant(this.maxval), nType), options); } else { long[] seeds = {seed, 0}; - distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); + distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), nType); if (this.minval == 0) { if (this.maxval != 1.0) { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval), type)); + distOp = tf.math.mul(distOp, cast(tf, tf.constant(this.maxval), distOp.type())); } } else { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval - this.minval), type)); - distOp = tf.math.add(distOp, tf.dtypes.cast(tf.constant(this.minval), type)); + distOp = + tf.math.mul(distOp, cast(tf, tf.constant(this.maxval - this.minval), distOp.type())); + distOp = tf.math.add(distOp, cast(tf, tf.constant(this.minval), distOp.type())); } } - return distOp; + return cast(tf, distOp, type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java index 
d3cfec26338..7969e6988fa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java @@ -17,7 +17,10 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import static org.tensorflow.framework.utils.CastHelper.cast; /** * Initializer that generates a truncated normal distribution. @@ -31,10 +34,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class TruncatedNormal extends BaseInitializer { +public class TruncatedNormal extends BaseInitializer { public static final double MEAN_DEFAULT = 0.0; public static final double STDDEV_DEFAULT = 0.05; @@ -49,7 +50,7 @@ public class TruncatedNormal extends BaseInitializer { * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public TruncatedNormal(Ops tf, long seed) { this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); @@ -62,7 +63,7 @@ public TruncatedNormal(Ops tf, long seed) { * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. 
*/ public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { super(tf); @@ -73,11 +74,19 @@ public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - long[] seeds = {seed,0}; - Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); - return tf.math.add( - tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)), - tf.dtypes.cast(tf.constant(mean), type)); + public Operand call(Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type)) { + throw new IllegalArgumentException("Tensor type must be numeric: " + type.getSimpleName()); + } + @SuppressWarnings("unchecked") + Class nType = (Class) type; + long[] seeds = {seed, 0}; + Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), nType); + return cast( + tf, + tf.math.add( + tf.math.mul(distOp, cast(tf, tf.constant(stddev), distOp.type())), + cast(tf, tf.constant(mean), distOp.type())), + type); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java index 5d951450505..583d5680c57 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java @@ -19,13 +19,16 @@ import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import static org.tensorflow.framework.utils.CastHelper.cast; /** * Initializer capable of adapting its scale to the shape of weights tensors. * - *

<p>With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from
- * a truncated/untruncated normal distribution with a mean of zero and a standard deviation (after
+ * <p>With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from a
+ * truncated/untruncated normal distribution with a mean of zero and a standard deviation (after
 * truncation, if used) stddev = Math.sqrt(scale / n), where n is:
 *
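
Usage note (illustration only, not part of the patch): after this change the activation and initializer classes are raw, and the element type is carried by call() instead of by the class. Below is a minimal sketch of what that buys a caller, assuming the post-patch Sigmoid and RandomNormal signatures shown in the diffs above; the GenericsAfterCleanup class and its main() scaffolding are hypothetical.

import org.tensorflow.Graph;
import org.tensorflow.Operand;
import org.tensorflow.framework.activations.Sigmoid;
import org.tensorflow.framework.initializers.RandomNormal;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;

public class GenericsAfterCleanup {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);

      // The class is no longer parameterized; one instance now serves any
      // supported tensor type because call(...) declares the generic itself.
      Sigmoid sigmoid = new Sigmoid(tf);
      Operand<TFloat32> x32 = tf.constant(new float[] {-20f, 0f, 20f});
      Operand<TFloat64> x64 = tf.constant(new double[] {-20.0, 0.0, 20.0});
      Operand<TFloat32> y32 = sigmoid.call(x32); // <T> inferred as TFloat32
      Operand<TFloat64> y64 = sigmoid.call(x64); // same instance, different <T>

      // Initializers follow the same pattern: the element type is bound per
      // call by the Class<T> argument rather than at construction time.
      RandomNormal initializer = new RandomNormal(tf, 1001L);
      Operand<TFloat32> weights =
          initializer.call(tf.constant(Shape.of(2, 2)), TFloat32.class);
    }
  }
}

Binding the type per call rather than per instance matches how Ops itself works, and it is what lets the framework drop the class-level type parameters throughout this patch without losing compile-time type safety at the call sites.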