From dc94953db2b884fdc8d0208cf01a5e4231b3c332 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 26 Mar 2021 17:54:15 -0400 Subject: [PATCH 01/31] Moved high level tf.nn ops to framework. Moved tf.raw.nn Ops to tf.nn. Changed generation to generate SoftmaxCrossEntropyWithLogits and SparseSoftmaxCrossEntropyWithLogits to core NNOps (tf.nn). --- ...pi_def_SoftmaxCrossEntropyWithLogits.pbtxt | 2 +- ..._SparseSoftmaxCrossEntropyWithLogits.pbtxt | 2 +- .../annotations/org/tensorflow/op/NnOps.java | 175 +++--------------- .../org/tensorflow/op/NnRawOps.java | 84 --------- .../SoftmaxCrossEntropyWithLogits.java | 8 +- .../SparseSoftmaxCrossEntropyWithLogits.java | 8 +- .../op/nn/SigmoidCrossEntropyWithLogits.java | 14 +- .../op/nn/SoftmaxCrossEntropyWithLogits.java | 44 +++-- .../SparseSoftmaxCrossEntropyWithLogits.java | 47 +++-- 9 files changed, 107 insertions(+), 277 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/{raw => }/SoftmaxCrossEntropyWithLogits.java (94%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/{raw => }/SparseSoftmaxCrossEntropyWithLogits.java (94%) rename {tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow => tensorflow-framework/src/main/java/org/tensorflow/framework}/op/nn/SigmoidCrossEntropyWithLogits.java (91%) rename {tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow => tensorflow-framework/src/main/java/org/tensorflow/framework}/op/nn/SoftmaxCrossEntropyWithLogits.java (87%) rename {tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow => tensorflow-framework/src/main/java/org/tensorflow/framework}/op/nn/SparseSoftmaxCrossEntropyWithLogits.java (83%) diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt index 5dba2164cd6..e064562c0f2 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "SoftmaxCrossEntropyWithLogits" endpoint { - name: "nn.raw.SoftmaxCrossEntropyWithLogits" + name: "nn.SoftmaxCrossEntropyWithLogits" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt index cf80ff77565..7627d5f6074 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "SparseSoftmaxCrossEntropyWithLogits" endpoint { - name: "nn.raw.SparseSoftmaxCrossEntropyWithLogits" + name: "nn.SparseSoftmaxCrossEntropyWithLogits" } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java index 4f724578d14..0269d387859 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java @@ -83,7 +83,6 @@ import 
org.tensorflow.op.nn.Relu; import org.tensorflow.op.nn.Relu6; import org.tensorflow.op.nn.Selu; -import org.tensorflow.op.nn.SigmoidCrossEntropyWithLogits; import org.tensorflow.op.nn.Softmax; import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.op.nn.Softsign; @@ -103,8 +102,6 @@ * @see {@link Ops} */ public final class NnOps { - public final NnRawOps raw; - private final Scope scope; private final Ops ops; @@ -112,7 +109,6 @@ public final class NnOps { NnOps(Ops ops) { this.scope = ops.scope(); this.ops = ops; - raw = new NnRawOps(ops); } /** @@ -1795,56 +1791,6 @@ public Selu selu(Operand features) { return Selu.create(scope, features); } - /** - * Computes sigmoid cross entropy given logits. - * - *

Measures the probability error in discrete classification tasks in which each class is - * independent and not mutually exclusive. For instance, one could perform multilabel - * classification where a picture can contain both an elephant and a dog at the same time. - * - *

For brevity, let x = logits, z = labels. The logistic loss in - * pseudo-code is - * - *

-   *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
-   *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
-   *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
-   *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
-   *   = (1 - z) * x + log(1 + exp(-x))
-   *   = x - x * z + log(1 + exp(-x))
-   *  
- * - *

For x < 0, to avoid overflow in exp(-x), we reformulate the above - * - *

-   *  x - x * z + log(1 + exp(-x))
-   *   = log(exp(x)) - x * z + log(1 + exp(-x))
-   *   = - x * z + log(1 + exp(x))
-   *  
- * - *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent - * formulation - * - *

-   *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
-   *  
- * - *

logits and labels must have the same type and shape. - * - *

- * - * @param scope The TensorFlow scope - * @param labels the labels - * @param logits the logits of type float32 or float64 - * @param the type of labels and logits - * @return the component-wise logistic losses. - * @throws IllegalArgumentException if logits' and labels' do not have the same shape - */ - public Operand sigmoidCrossEntropyWithLogits(Operand labels, - Operand logits) { - return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); - } - /** * Computes softmax activations. *

@@ -1861,54 +1807,20 @@ public Softmax softmax(Operand logits) { } /** - * Computes softmax cross entropy between logits and labels. - * - *

Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - *

NOTE: - * - *

While the classes are mutually exclusive, their probabilities need not be. All that is - * required is that each row of labels is a valid probability distribution. If they - * are not, the computation of the gradient will be incorrect. - * - *

If using exclusive labels (wherein one and only one class is true at a time), - * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} - * - *

Usage: - * - *

-   *    Operand<TFloat32> logits =
-   *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
-   *    Operand<TFloat32> labels =
-   *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
-   *    Operand<TFloat32> output =
-   *        tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
-   *    // output Shape = [2]
-   *    // dataType = FLOAT (1)
-   *    // values { 0.169846, 0.824745 }
-   *  
- * - *

Backpropagation will happen into both logits and labels. To - * disallow backpropagation into labels, pass label tensors through - * tf.stopGradient before feeding it to this function. + * Computes softmax cross entropy cost and gradients to backpropagate. + *

+ * Inputs are the logits, not probabilities. * - * @param scope current scope - * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] - * , each row of labels[i] must be a valid probability distribution. - * @param logits Per-label activations, typically a linear output. These activation energies are - * interpreted as unnormalized log probabilities. - * @param axis The class dimension. -1 is the last dimension. - * @param the number type of the operands - * @return the softmax cross entropy loss. Its type is the same as logits and its - * shape is the same as labels except that it does not have the last dimension of - * labels. + * @param data type for {@code loss()} output + * @param features batch_size x num_classes matrix + * @param labels batch_size x num_classes matrix + * The caller must ensure that each batch of labels represents a valid + * probability distribution. + * @return a new instance of SoftmaxCrossEntropyWithLogits */ - public Operand softmaxCrossEntropyWithLogits( - Operand labels, Operand logits, int axis) { - return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + public SoftmaxCrossEntropyWithLogits softmaxCrossEntropyWithLogits( + Operand features, Operand labels) { + return SoftmaxCrossEntropyWithLogits.create(scope, features, labels); } /** @@ -2100,51 +2012,24 @@ public SpaceToDepth spaceToDepth(Operand input, Long blo } /** - * Computes sparse softmax cross entropy between logits and labels. - * - *

Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - *

NOTE: - * - *

For this operation, the probability of a given label is considered exclusive. That is, soft - * classes are not allowed, and the labels vector must provide a single specific - * index for the true class for each row of logits (each minibatch entry). For soft - * softmax classification with a probability distribution for each entry, {@link - * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. - * - *

WARNING: - * - *

This op expects unscaled logits, since it performs a softmax on logits - * internally for efficiency. Do not call this op with the output of softmax, - * as it will produce incorrect results. - * - *

A common use case is to have logits of shape [batchSize, numClasses] and have - * labels of shape [batchSize], but higher dimensions are supported, in which case - * the dim-th dimension is assumed to be of size numClasses. - * logits must have the dataType of TFloat16, TFloat32 - * , or TFloat64, and labels must have the dtype of TInt32 - * or TInt64. - * - * @param scope current scope - * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r - * is rank of labels and result) and the dataType is TInt32 - * or TInt64. Each entry in labels must be an index in [0, - * numClasses). Other values will raise an exception when this op is run on CPU, and - * return NaN for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., - * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, - * or TFloat64. These activation energies are interpreted as unnormalized log - * probabilities. - * @return A Tensor of the same shape as labels and of the same type as - * logits with the softmax cross entropy loss. - * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank - * of the labels is not equal to the rank of the logits minus one. - */ - public Operand sparseSoftmaxCrossEntropyWithLogits( - Operand labels, Operand logits) { - return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); + * Computes softmax cross entropy cost and gradients to backpropagate. + *

+ * Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept + * a matrix of label probabilities, but rather a single label per row + * of features. This label is considered to have probability 1.0 for the + * given row. + *

+ * Inputs are the logits, not probabilities. + * + * @param data type for {@code loss()} output + * @param features batch_size x num_classes matrix + * @param labels batch_size vector with values in [0, num_classes). + * This is the label for the given minibatch entry. + * @return a new instance of SparseSoftmaxCrossEntropyWithLogits + */ + public SparseSoftmaxCrossEntropyWithLogits sparseSoftmaxCrossEntropyWithLogits( + Operand features, Operand labels) { + return SparseSoftmaxCrossEntropyWithLogits.create(scope, features, labels); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java deleted file mode 100644 index 13c6baa651a..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== -// -// This class has been generated, DO NOT EDIT! -// -package org.tensorflow.op; - -import org.tensorflow.Operand; -import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits; -import org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits; -import org.tensorflow.types.family.TNumber; - -/** - * An API for building {@code nn.raw} operations as {@link Op Op}s - * - * @see {@link Ops} - */ -public final class NnRawOps { - private final Scope scope; - - private final Ops ops; - - NnRawOps(Ops ops) { - this.scope = ops.scope(); - this.ops = ops; - } - - /** - * Computes softmax cross entropy cost and gradients to backpropagate. - *

- * Inputs are the logits, not probabilities. - * - * @param data type for {@code loss()} output - * @param features batch_size x num_classes matrix - * @param labels batch_size x num_classes matrix - * The caller must ensure that each batch of labels represents a valid - * probability distribution. - * @return a new instance of SoftmaxCrossEntropyWithLogits - */ - public SoftmaxCrossEntropyWithLogits softmaxCrossEntropyWithLogits( - Operand features, Operand labels) { - return SoftmaxCrossEntropyWithLogits.create(scope, features, labels); - } - - /** - * Computes softmax cross entropy cost and gradients to backpropagate. - *

- * Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept - * a matrix of label probabilities, but rather a single label per row - * of features. This label is considered to have probability 1.0 for the - * given row. - *

- * Inputs are the logits, not probabilities. - * - * @param data type for {@code loss()} output - * @param features batch_size x num_classes matrix - * @param labels batch_size vector with values in [0, num_classes). - * This is the label for the given minibatch entry. - * @return a new instance of SparseSoftmaxCrossEntropyWithLogits - */ - public SparseSoftmaxCrossEntropyWithLogits sparseSoftmaxCrossEntropyWithLogits( - Operand features, Operand labels) { - return SparseSoftmaxCrossEntropyWithLogits.create(scope, features, labels); - } - - /** - * Get the parent {@link Ops} object. - */ - public final Ops ops() { - return ops; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java similarity index 94% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java index 8032a4c2512..5d3ab3c1100 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.nn.raw; +package org.tensorflow.op.nn; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -34,7 +34,7 @@ * * @param data type for {@code loss()} output */ -@Operator(group = "nn.raw") +@Operator(group = "nn") public final class SoftmaxCrossEntropyWithLogits extends RawOp { /** @@ -53,7 +53,7 @@ public static SoftmaxCrossEntropyWithLogits create(Scope opBuilder.addInput(features.asOutput()); opBuilder.addInput(labels.asOutput()); opBuilder = scope.apply(opBuilder); - return new SoftmaxCrossEntropyWithLogits(opBuilder.build()); + return new SoftmaxCrossEntropyWithLogits<>(opBuilder.build()); } /** @@ -80,6 +80,6 @@ private SoftmaxCrossEntropyWithLogits(Operation operation) { super(operation); int outputIdx = 0; loss = operation.output(outputIdx++); - backprop = operation.output(outputIdx++); + backprop = operation.output(outputIdx); } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java similarity index 94% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 67650760b1c..794beab4ded 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! 
-package org.tensorflow.op.nn.raw; +package org.tensorflow.op.nn; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -39,7 +39,7 @@ * * @param data type for {@code loss()} output */ -@Operator(group = "nn.raw") +@Operator(group = "nn") public final class SparseSoftmaxCrossEntropyWithLogits extends RawOp { /** @@ -57,7 +57,7 @@ public static SparseSoftmaxCrossEntropyWithLogits create( opBuilder.addInput(features.asOutput()); opBuilder.addInput(labels.asOutput()); opBuilder = scope.apply(opBuilder); - return new SparseSoftmaxCrossEntropyWithLogits(opBuilder.build()); + return new SparseSoftmaxCrossEntropyWithLogits<>(opBuilder.build()); } /** @@ -84,6 +84,6 @@ private SparseSoftmaxCrossEntropyWithLogits(Operation operation) { super(operation); int outputIdx = 0; loss = operation.output(outputIdx++); - backprop = operation.output(outputIdx++); + backprop = operation.output(outputIdx); } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java similarity index 91% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java index 92c413f7e52..b55385839d3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java @@ -1,4 +1,4 @@ -package org.tensorflow.op.nn; +package org.tensorflow.framework.op.nn; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; @@ -8,11 +8,17 @@ import org.tensorflow.op.core.Select; import org.tensorflow.op.core.ZerosLike; import org.tensorflow.op.dtypes.Cast; -import org.tensorflow.op.math.*; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Exp; +import org.tensorflow.op.math.GreaterEqual; +import org.tensorflow.op.math.Log1p; +import org.tensorflow.op.math.Mul; +import org.tensorflow.op.math.Neg; +import org.tensorflow.op.math.Sub; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -@Operator(group = "nn") +//@Operator(group = "nn") public class SigmoidCrossEntropyWithLogits { /** @@ -60,7 +66,7 @@ public class SigmoidCrossEntropyWithLogits { * @return the component-wise logistic losses. 
* @throws IllegalArgumentException if logits' and labels' do not have the same shape */ - @Endpoint(name = "sigmoidCrossEntropyWithLogits") + //@Endpoint(name = "sigmoidCrossEntropyWithLogits") public static Operand sigmoidCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits) { if (!isCompatible(labels.shape(), logits.shape())) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java similarity index 87% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java index ddeacbea4d4..0f5b8197f1e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -1,11 +1,15 @@ -package org.tensorflow.op.nn; +package org.tensorflow.framework.op.nn; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.core.*; +import org.tensorflow.op.core.Concat; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Range; +import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Slice; import org.tensorflow.op.dtypes.Cast; import org.tensorflow.op.linalg.Transpose; import org.tensorflow.op.math.Sub; @@ -14,12 +18,11 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; import java.util.Arrays; import java.util.List; -@Operator(group = "nn") +// @Operator(group = "nn") public class SoftmaxCrossEntropyWithLogits { /** @@ -68,6 +71,7 @@ public class SoftmaxCrossEntropyWithLogits { * shape is the same as labels except that it does not have the last dimension of * labels. */ + @SuppressWarnings("unchecked") @Endpoint(name = "softmaxCrossEntropyWithLogits") public static Operand softmaxCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits, int axis) { @@ -78,7 +82,9 @@ public static Operand softmaxCrossEntr } if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { - Operand result = softmaxCrossEntropyWithLogits(scope, + Operand result = + softmaxCrossEntropyWithLogits( + scope, Cast.create(scope, labels, TFloat32.class), Cast.create(scope, logits, TFloat32.class), axis); @@ -86,10 +92,8 @@ public static Operand softmaxCrossEntr } if (logits.asOutput().type() != labels.asOutput().type()) { - return softmaxCrossEntropyWithLogits(scope, - Cast.create(scope, labels, logits.asOutput().type()), - logits, - axis); + return softmaxCrossEntropyWithLogits( + scope, Cast.create(scope, labels, logits.asOutput().type()), logits, axis); } Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); @@ -101,13 +105,20 @@ public static Operand softmaxCrossEntr labels = moveDimToEnd(scope, labels, axis, inputRank); } + Operand tLabels; + if (labels.type() != logits.type()) { + tLabels = Cast.create(scope, labels, logits.type()); + } else { + // Unchecked warning checked in if statement. 
+ tLabels = (Operand) labels; + } + Shape inputShape = logits.shape(); logits = flattenOuterDims(scope, logits); - labels = flattenOuterDims(scope, labels); + tLabels = flattenOuterDims(scope, tLabels); - org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits.create( - scope, logits, (Operand)labels); + org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits.create(scope, logits, tLabels); /* cannot use generic on cost, because cost may be recast later. */ Operand cost = smax.loss(); Operand outputShape = @@ -119,6 +130,9 @@ public static Operand softmaxCrossEntr cost = Reshape.create(scope, cost, outputShape); if (scope.env().isGraph() && !shape.hasUnknownDimension()) { long[] array = shape.asArray(); + if (array == null) { + array = new long[0]; + } long[] newArray = new long[array.length - 1]; if (axis < 0) { axis = shape.numDimensions() + axis; @@ -153,7 +167,7 @@ private static Operand flattenOuterDims(Scope scope, Oper boolean productValid = true; for (int i = ndims - 2; i >= 0; i--) { long d = shape.size(i); - if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) { + if (d == Shape.UNKNOWN_SIZE) { productValid = false; break; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java similarity index 83% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 54b32bb5c63..64faa7c5d70 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -1,11 +1,10 @@ -package org.tensorflow.op.nn; +package org.tensorflow.framework.op.nn; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.AssertThat; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Reshape; @@ -22,7 +21,7 @@ import java.util.Collections; import java.util.List; -@Operator(group = "nn") +// @Operator(group = "nn") public class SparseSoftmaxCrossEntropyWithLogits { /** @@ -63,19 +62,24 @@ public class SparseSoftmaxCrossEntropyWithLogits { * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, * or TFloat64. These activation energies are interpreted as unnormalized log * probabilities. + * @param the data type for the labels + * @param the data tyoe for the loss and logits. * @return A Tensor of the same shape as labels and of the same type as * logits with the softmax cross entropy loss. * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank * of the labels is not equal to the rank of the logits minus one. 
*/ + @SuppressWarnings("unchecked") @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits") - public static Operand sparseSoftmaxCrossEntropyWithLogits( - Scope scope, Operand labels, Operand logits) { + public static + Operand sparseSoftmaxCrossEntropyWithLogits( + Scope scope, Operand labels, Operand logits) { scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); - /** cannot use generics on preciseLogits as it may be recast later */ - Operand preciseLogits = logits; + Operand preciseLogits; if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { preciseLogits = Cast.create(scope, logits, TFloat32.class); + } else { + preciseLogits = logits; } Shape labelsStaticShape = labels.shape(); org.tensorflow.op.core.Shape labelsShape = @@ -108,14 +112,16 @@ public static Operand sparseSoftmaxCrossE } // Check if no reshapes are required. if (logitsShape.numDimensions() == 2) { - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create( scope, preciseLogits, labels); - Operand loss = smax.loss(); - if (logits.asOutput().type() == TFloat16.class) { - loss = Cast.create(scope, loss, TFloat16.class); + Operand cost = smax.loss(); + if (cost.type() != logits.type()) { + return Cast.create(scope, cost, logits.type()); + } else { + // Unchecked cast already checked with previous if + return (Operand) cost; } - return loss; } List shapeChecks = new ArrayList<>(); @@ -145,14 +151,17 @@ public static Operand sparseSoftmaxCrossE preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses)); labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1)); scope.withControlDependencies(shapeChecks); - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( + // call raw op + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create( scope, preciseLogits, labels); - Operand cost = smax.loss(); + Operand cost = smax.loss(); cost = Reshape.create(scope, cost, labelsShape); - if (logits.asOutput().type() == TFloat16.class) { - cost = Cast.create(scope, cost, TFloat16.class); + if (cost.type() != logits.type()) { + return Cast.create(scope, cost, logits.type()); + } else { + // Unchecked cast already checked with previous if + return (Operand) cost; } - return cost; } } From 1878b609d82996c3376b28c5d1e7338dfc6e80f1 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 26 Mar 2021 18:02:55 -0400 Subject: [PATCH 02/31] Added FrameworkOps analogous to Ops. Added NnOps and SetOps as groups. Fixed MetricsHelper and Losses to use the bew FrameworkOps. Moved SetsOps to framework.op. 
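
A sketch of the resulting call pattern (illustrative only; assumes tf is an
existing Ops instance and labels/logits are Operand<TFloat32> operands):

    FrameworkOps fops = FrameworkOps.create(tf);
    Operand<TFloat32> loss = fops.nn.sigmoidCrossEntropyWithLogits(labels, logits);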
--- .../tensorflow/framework/losses/Losses.java | 17 +- .../framework/metrics/impl/MetricsHelper.java | 4 +- .../tensorflow/framework/op/FrameworkOps.java | 136 ++++++++++++ .../org/tensorflow/framework/op/NnOps.java | 197 ++++++++++++++++++ .../{metrics/impl => op}/SetsOps.java | 64 +++--- .../SparseSoftmaxCrossEntropyWithLogits.java | 3 +- .../{SetsOpsTest.java => SetOpsTest.java} | 18 +- 7 files changed, 398 insertions(+), 41 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java rename tensorflow-framework/src/main/java/org/tensorflow/framework/{metrics/impl => op}/SetsOps.java (75%) rename tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/{SetsOpsTest.java => SetOpsTest.java} (86%) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 9aa94cf7fcf..aa5fa4ada6d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.op.core.ReduceAll; import org.tensorflow.op.core.ReduceMax; import org.tensorflow.op.core.ReduceSum; @@ -181,7 +182,8 @@ public static Operand binaryCrossentropy( */ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { - if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); + FrameworkOps fop = FrameworkOps.create(tf); + if (fromLogits) { return fop.nn.sigmoidCrossEntropyWithLogits(target, output);} /* TODO - skip this logic for now. 
It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { @@ -191,7 +193,7 @@ private static Operand binaryCrossentropyHelper( // TODO if (output.op().numInputess() != 1) // TODO throw new IllegalArgumentException("output can only have 1 output"); // TODO output = output.op().inout(0); - // TODO return tf.nn.sigmoidCrossEntropyWithLogits(target, output); + // TODO return fop.nn.sigmoidCrossEntropyWithLogits(target, output); // TODO} } */ @@ -235,6 +237,7 @@ public static Operand categoricalCrossentropy( boolean fromLogits, float labelSmoothing, int axis) { + FrameworkOps fop = FrameworkOps.create(tf); Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -245,7 +248,7 @@ public static Operand categoricalCrossentropy( tLabels = smoothCategoricalLabels(tf, tLabels, labelSmoothing); } if (fromLogits) { - return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis); + return fop.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis); } /* TODO if (!(predictions instanceof Variable) && (!tf.scope().env().isEager())) { @@ -255,7 +258,7 @@ public static Operand categoricalCrossentropy( if (predictions.op().numOutputs() != 1) throw new IllegalArgumentException("output can only have 1 output"); predictions = predictions.op().output(0); - return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); + return fop.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); } } */ @@ -516,6 +519,7 @@ public static Operand sparseCategoricalCrossentropy( boolean fromLogits, int axis) { Class predictionType = predictions.type(); + FrameworkOps fop = FrameworkOps.create(tf); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); @@ -568,9 +572,8 @@ public static Operand sparseCategoricalCrossentropy( tf.constant( new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)})); } - - @SuppressWarnings("unchecked") - Operand loss = tf.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); + + Operand loss = fop.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); if (updateShape && predictionsRank >= 3) { Shape newShape = predictionsShape.take(predictionsShape.numDimensions() - 1); loss = tf.reshape(loss, tf.constant(newShape)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 8a352322f52..a82e1760d1f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -16,6 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -174,12 +175,13 @@ private static Operand canBroadcastNonscalarShapes( private static Operand canBroadcastDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("canBroadcastDims"); + FrameworkOps fops = FrameworkOps.create(tf); Operand valuesShape2d = 
tf.expandDims(valuesShape, tf.constant(-1)); Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2D = tf.expandDims(weightsShape, tf.constant(-1)); - Operand diffResult = SetsOps.difference(tf, weightsShape2D, validDims); + Operand diffResult = fops.sets.difference(weightsShape2D, validDims); Operand numInvalidDims = tf.size(diffResult); return tf.math.equal(tf.constant(0), numInvalidDims); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java new file mode 100644 index 00000000000..cecbecfed15 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.DeviceSpec; +import org.tensorflow.EagerSession; +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; + +/** + * An API for building framework operations as {@link Op Op}s + * + *

These are higher level ops that may invoke core ops. Higher level Ops may perform the + * operation solely in the TensorFlow framework or do preprocessing of the Operands before invoking + * a core level Op. + */ +public class FrameworkOps { + public final Ops coreOps; + private final Scope scope; + + public final NnOps nn; + public final SetsOps sets; + + /** + * Creates a FrameworkOps instance with the provided scope + * + * @param scope the scope + */ + private FrameworkOps(Scope scope) { + this.coreOps = Ops.create(scope.env()); + this.scope = scope; + nn = new NnOps(this); + sets = new SetsOps(this); + } + + /** + * Creates a FrameworkOps instance based on the provided Core Ops + * + * @param coreOps The TensorFlow Core Ops + */ + private FrameworkOps(Ops coreOps) { + this.coreOps = coreOps; + this.scope = coreOps.scope(); + nn = new NnOps(this); + sets = new SetsOps(this); + } + + + /** Returns the current {@link Scope scope} of this API */ + public final Scope scope() { + return scope; + } + + /** + * Gets the core Ops + * + * @return coreOps + */ + public final Ops coreOps() { + return coreOps; + } + + /** + * Returns an API that builds operations with the provided name prefix. + * + *

{@link Scope#withSubScope(String)}
+   */
+  public FrameworkOps withSubScope(String childScopeName) {
+    return new FrameworkOps(scope.withSubScope(childScopeName));
+  }
+
+  /**
+   * Returns an API that uses the provided name for an op.
+   *
+   *

{@link Scope#withName(String)} + */ + public FrameworkOps withName(String opName) { + return new FrameworkOps(scope.withName(opName)); + } + + /** + * Returns an API that places the created operations on the device(s) matching the provided spec. + * + *

{@link Scope#withDevice(DeviceSpec)} + */ + public FrameworkOps withDevice(DeviceSpec deviceSpec) { + return new FrameworkOps(scope.withDevice(deviceSpec)); + } + + /** + * Returns an API that adds operations to the graph with the provided control dependencies. + * + *

{@link Scope#withControlDependencies(Iterable)} + */ + public FrameworkOps withControlDependencies(Iterable controls) { + return new FrameworkOps(scope.withControlDependencies(controls)); + } + + /** Creates an API for building operations in the provided execution environment */ + public static FrameworkOps create(ExecutionEnvironment env) { + return new FrameworkOps(new Scope(env)); + } + + /** + * Creates an API for building operations in the default eager execution environment + * + *

Invoking this method is equivalent to {@code + * FrameworkOps.create(EagerSession.getDefault())}. + */ + public static FrameworkOps create() { + return new FrameworkOps(new Scope(EagerSession.getDefault())); + } + + /** + * Creates an API for building operations in the default eager execution environment + * + * @param coreOps the TensorFlow core Ops + */ + public static FrameworkOps create(Ops coreOps) { + return new FrameworkOps(coreOps); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java new file mode 100644 index 00000000000..4054f3ddbb5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java @@ -0,0 +1,197 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.Operand; +import org.tensorflow.framework.op.nn.SigmoidCrossEntropyWithLogits; +import org.tensorflow.framework.op.nn.SoftmaxCrossEntropyWithLogits; +import org.tensorflow.framework.op.nn.SparseSoftmaxCrossEntropyWithLogits; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.types.family.TNumber; + +/** + * An API for building {@code nn} operations as {@link Op Op}s + * + *

These are higher level ops that may invoke core ops. Higher level Ops may perform the + * operation solely in the TensorFlow framework or do preprocessing of the Operands before invoking + * a core level Op. + * + *

{@link FrameworkOps}
+ */
+public class NnOps {
+    private final Scope scope;
+
+    private final FrameworkOps frameworkOps;
+
+    /**
+     * Creates Framework {@code nn} Operations
+     * @param frameworkOps the TensorFlow framework Ops
+     */
+    NnOps(FrameworkOps frameworkOps) {
+        this.scope = frameworkOps.scope();
+        this.frameworkOps = frameworkOps;
+    }
+
+    /**
+     * Computes sigmoid cross entropy given logits.
+     *
+     *

Measures the probability error in discrete classification tasks in which each class is + * independent and not mutually exclusive. For instance, one could perform multilabel + * classification where a picture can contain both an elephant and a dog at the same time. + * + *

For brevity, let x = logits, z = labels. The logistic loss in + * pseudo-code is + * + *

+     *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+     *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+     *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+     *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+     *   = (1 - z) * x + log(1 + exp(-x))
+     *   = x - x * z + log(1 + exp(-x))
+     *  
+ * + *

For x < 0, to avoid overflow in exp(-x), we reformulate the above + * + *

+     *  x - x * z + log(1 + exp(-x))
+     *   = log(exp(x)) - x * z + log(1 + exp(-x))
+     *   = - x * z + log(1 + exp(x))
+     *  
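+     * The last step uses log(exp(x)) + log(1 + exp(-x)) = log(exp(x) + 1).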
+ * + *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent + * formulation + * + *

+     *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
+     *  
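+     * For example, at x = 0, z = 1, both the original and the stable form give log(2), about 0.693.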
+ * + *

logits and labels must have the same type and shape. + * + *
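+     * A minimal usage sketch (illustrative only; assumes {@code fops} is a
+     * {@code FrameworkOps} instance and {@code labels} and {@code logits} are
+     * {@code Operand&lt;TFloat32&gt;} operands of the same shape):
+     *
+     *    Operand&lt;TFloat32&gt; loss =
+     *        fops.nn.sigmoidCrossEntropyWithLogits(labels, logits);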

+ * + * @param labels the labels + * @param logits the logits of type float32 or float64 + * @param the type of labels and logits + * @return the component-wise logistic losses. + * @throws IllegalArgumentException if logits' and labels' do not have the same shape + */ + public Operand sigmoidCrossEntropyWithLogits(Operand labels, + Operand logits) { + return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); + } + + /** + * Computes softmax cross entropy between logits and labels. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of labels is a valid probability distribution. If they + * are not, the computation of the gradient will be incorrect. + * + *

If using exclusive labels (wherein one and only one class is true at a time), + * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} + * + *

Usage: + * + *
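+     *    // Illustrative setup: the framework ops wrap an existing core Ops instance "tf".
+     *    FrameworkOps fops = FrameworkOps.create(tf);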

+     *    Operand<TFloat32> logits =
+     *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
+     *    Operand<TFloat32> labels =
+     *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
+     *    Operand<TFloat32> output =
+     *        fops.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+     *    // output Shape = [2]
+     *    // dataType = FLOAT (1)
+     *    // values { 0.169846, 0.824745 }
+     *  
+ * + *

Backpropagation will happen into both logits and labels. To + * disallow backpropagation into labels, pass label tensors through + * tf.stopGradient before feeding it to this function. + * + * @param labels Each vector along the class dimension should hold a valid probability + * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] + * , each row of labels[i] must be a valid probability distribution. + * @param logits Per-label activations, typically a linear output. These activation energies are + * interpreted as unnormalized log probabilities. + * @param axis The class dimension. -1 is the last dimension. + * @param the number type of the operands + * @return the softmax cross entropy loss. Its type is the same as logits and its + * shape is the same as labels except that it does not have the last dimension of + * labels. + */ + public Operand softmaxCrossEntropyWithLogits( + Operand labels, Operand logits, int axis) { + return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + } + + /** + * Computes sparse softmax cross entropy between logits and labels. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

For this operation, the probability of a given label is considered exclusive. That is, soft
+     * classes are not allowed, and the labels vector must provide a single specific
+     * index for the true class for each row of logits (each minibatch entry). For soft
+     * softmax classification with a probability distribution for each entry, see {@link
+     * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}.
+     *
+     *

WARNING: + * + *

This op expects unscaled logits, since it performs a softmax on logits + * internally for efficiency. Do not call this op with the output of softmax, + * as it will produce incorrect results. + * + *
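+     * A minimal sketch (illustrative only; assumes {@code fops} is a {@code FrameworkOps}
+     * instance, {@code labels} an integer operand of shape [batchSize], and {@code logits}
+     * an {@code Operand&lt;TFloat32&gt;} of shape [batchSize, numClasses]):
+     *
+     *    Operand&lt;TFloat32&gt; loss =
+     *        fops.nn.sparseSoftmaxCrossEntropyWithLogits(labels, logits);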

A common use case is to have logits of shape [batchSize, numClasses] and have + * labels of shape [batchSize], but higher dimensions are supported, in which case + * the dim-th dimension is assumed to be of size numClasses. + * logits must have the dataType of TFloat16, TFloat32 + * , or TFloat64, and labels must have the dtype of TInt32 + * or TInt64. + * + * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r + * is rank of labels and result) and the dataType is TInt32 + * or TInt64. Each entry in labels must be an index in [0, + * numClasses). Other values will raise an exception when this op is run on CPU, and + * return NaN for corresponding loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., + * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, + * or TFloat64. These activation energies are interpreted as unnormalized log + * probabilities. + * @param The data type for the labels + * @param The data type for the logits and loss + * @return the loss + * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank + * of the labels is not equal to the rank of the logits minus one. + */ + + public Operand sparseSoftmaxCrossEntropyWithLogits( + Operand labels, Operand logits) { + return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); + } + + +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java similarity index 75% rename from tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java index 467dea19b57..d7833cdbb06 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java @@ -12,26 +12,40 @@ See the License for the specific language governing permissions and limitations under the License. =======================================================================*/ -package org.tensorflow.framework.metrics.impl; +package org.tensorflow.framework.op; import org.tensorflow.Operand; -import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; import org.tensorflow.op.SparseOps; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.dtypes.Cast; import org.tensorflow.op.sparse.DenseToDenseSetOperation; +import org.tensorflow.op.sparse.SparseToDense; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - /** Implementation of set operations */ public class SetsOps { + private final Scope scope; + + private final FrameworkOps frameworkOps; + + /** + * Creates Framework {@code nn} Operations + * + * @param frameworkOps the TensorFLow framework Ops + */ + SetsOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } + /** * Computes set difference of elements in last dimension of a and b with * aMinusB set to true. * *

All but the last dimension of a and b must match * - * @param tf the TensorFlow Ops * @param a The first operand representing set a * @param b The other operand representing set b * @param the data type for the sets @@ -39,8 +53,8 @@ public class SetsOps { * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ - public static Operand difference(Ops tf, Operand a, Operand b) { - return difference(tf, a, b, true); + public Operand difference(Operand a, Operand b) { + return difference(a, b, true); } /** @@ -48,7 +62,6 @@ public static Operand difference(Ops tf, Operand a, Op * *

All but the last dimension of a and b must match * - * @param tf the TensorFlow Ops * @param a The first operand representing set a * @param b The other operand representing set b * @param aMinusB whether to subtract b from a, vs vice versa. @@ -57,15 +70,13 @@ public static Operand difference(Ops tf, Operand a, Op * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ - public static Operand difference( - Ops tf, Operand a, Operand b, boolean aMinusB) { - return setOperation(tf, a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A); + public Operand difference(Operand a, Operand b, boolean aMinusB) { + return setOperation(a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A); } /** * Computes set union of elements in last dimension of a and b. * - * @param tf the TensorFlow Ops * @param a The first operand representing set a * @param b The other operand representing set b * @param the data type for the sets @@ -73,14 +84,13 @@ public static Operand difference( * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ - public static Operand union(Ops tf, Operand a, Operand b) { - return setOperation(tf, a, b, Operation.UNION); + public Operand union(Operand a, Operand b) { + return setOperation(a, b, Operation.UNION); } /** * Computes set intersection of elements in last dimension of a and b. * - * @param tf the TensorFlow Ops * @param a The first operand representing set a * @param b The other operand representing set b * @param the data type for the sets @@ -88,14 +98,13 @@ public static Operand union(Ops tf, Operand a, Operand * last dimension the * same. Elements along the last dimension contain the results of the set * operation. */ - public static Operand intersection(Ops tf, Operand a, Operand b) { - return setOperation(tf, a, b, Operation.INTERSECTION); + public Operand intersection(Operand a, Operand b) { + return setOperation(a, b, Operation.INTERSECTION); } /** * Compute set operation of elements in last dimension of a and b. * - * @param tf the TensorFlow Ops * @param a The first set operation operand * @param b The other et operation operand * @param setOperation The set operation to perform, {@link Operation}. @@ -104,18 +113,23 @@ public static Operand intersection(Ops tf, Operand a, * last dimension the same. Elements along the last dimension contain the results of the set * operation. 
*/ - public static Operand setOperation( - Ops tf, Operand a, Operand b, Operation setOperation) { + public Operand setOperation( + Operand a, Operand b, Operation setOperation) { DenseToDenseSetOperation setOperationResult = - tf.sparse.denseToDenseSetOperation( - a, b, setOperation.getSetOperation(), DenseToDenseSetOperation.validateIndices(true)); - - return tf.sparse.sparseToDense( + DenseToDenseSetOperation.create( + scope, + a, + b, + setOperation.getSetOperation(), + DenseToDenseSetOperation.validateIndices(true)); + + return SparseToDense.create( + scope, setOperationResult.resultIndices(), setOperationResult.resultShape(), setOperationResult.resultValues(), - cast(tf, tf.constant(0), a.type())); + Cast.create(scope, Constant.scalarOf(scope, 0), a.type())); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 64faa7c5d70..75766cf9bfb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -64,8 +64,7 @@ public class SparseSoftmaxCrossEntropyWithLogits { * probabilities. * @param the data type for the labels * @param the data tyoe for the loss and logits. - * @return A Tensor of the same shape as labels and of the same type as - * logits with the softmax cross entropy loss. + * @return the loss * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank * of the labels is not equal to the rank of the logits minus one. */ diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetOpsTest.java similarity index 86% rename from tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetOpsTest.java index eceff2797f8..e10f016bd94 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetOpsTest.java @@ -2,6 +2,8 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.framework.op.FrameworkOps; +import org.tensorflow.framework.op.SetsOps; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -15,7 +17,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; -class SetsOpsTest { +class SetOpsTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -28,6 +30,7 @@ public void testSetIntersectionMultirow2() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 5}}); int[][] expected = new int[][] {{1, 9}, {0, 0}}; @@ -35,7 +38,7 @@ public void testSetIntersectionMultirow2() { for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); + Operand intersection 
= fops.sets.intersection(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } @@ -49,6 +52,7 @@ public void testSetIntersectionDuplicates2d() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{1, 1, 3}}); Operand b = tf.constant(new int[][] {{1, 1}}); int[][] expected = {{1}}; @@ -56,7 +60,7 @@ public void testSetIntersectionDuplicates2d() { for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); + Operand intersection = fops.sets.intersection(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); @@ -72,6 +76,7 @@ public void testDenseSetDifferenceMultirow2d() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}}); Operand b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}}); @@ -81,14 +86,14 @@ public void testDenseSetDifferenceMultirow2d() { int[][] expected = {{5, 9, 0}, {3, 4, 5}}; // a- b Shape expectedShape = Shape.of(2, 3); - Operand intersection = SetsOps.difference(tf, aa, bb); + Operand intersection = fops.sets.difference(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); // b - a expected = new int[][] {{2, 6}, {1, 2}}; expectedShape = Shape.of(2, 2); - intersection = SetsOps.difference(tf, aa, bb, false); + intersection = fops.sets.difference(aa, bb, false); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); @@ -103,6 +108,7 @@ public void testDenseUnionMultirow2d() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 2}}); int[][] expected = new int[][] {{5, 0}, {3, 4}}; @@ -111,7 +117,7 @@ public void testDenseUnionMultirow2d() { Operand bb = cast(tf, b, type); Shape expectedShape = Shape.of(2, 2); // a- b - Operand intersection = SetsOps.difference(tf, aa, bb); + Operand intersection = fops.sets.difference(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } From 9225a48b7119f0fdc163deee9fe15607708a18ca Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 27 Mar 2021 15:21:11 -0400 Subject: [PATCH 03/31] Added FrameworkOps analogous to Ops. Added NnOps and SetOps as groups. Fixed MetricsHelper and Losses to use the bew FrameworkOps. Moved SetsOps to framework.op. 
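Note: a minimal sketch of the new FrameworkOps facade as exercised by the
tests above (assuming an eager session; names as in SetOpsTest):

    Ops tf = Ops.create();
    FrameworkOps fops = FrameworkOps.create(tf);
    Operand<TInt32> a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}});
    Operand<TInt32> b = tf.constant(new int[][] {{1, 9}, {1, 5}});
    // row-wise set intersection: {1, 9} for the first row, empty for the second
    Operand<TInt32> rows = fops.sets.intersection(a, b);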
--- .../src/main/java/org/tensorflow/framework/losses/Losses.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index aa5fa4ada6d..33c8d50409d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -572,7 +572,7 @@ public static Operand sparseCategoricalCrossentropy( tf.constant( new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)})); } - + Operand loss = fop.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); if (updateShape && predictionsRank >= 3) { Shape newShape = predictionsShape.take(predictionsShape.numDimensions() - 1); From caab79bf3c58344bdf675087a25dac399e837462 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 27 Mar 2021 15:36:41 -0400 Subject: [PATCH 04/31] Move l2Normalize to MathOps --- .../tensorflow/framework/losses/Losses.java | 23 ++----- .../tensorflow/framework/op/FrameworkOps.java | 3 + .../org/tensorflow/framework/op/MathOps.java | 67 +++++++++++++++++++ 3 files changed, 74 insertions(+), 19 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 33c8d50409d..398588cee67 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -337,13 +337,14 @@ public static Operand categoricalHinge( */ public static Operand cosineSimilarity( Ops tf, Operand labels, Operand predictions, int[] axis) { + FrameworkOps fops = FrameworkOps.create(tf); Operand tLabels = cast(tf, labels, predictions.type()); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - tLabels = l2Normalize(tf, tLabels, axis); - predictions = l2Normalize(tf, predictions, axis); + tLabels = fops.math.l2Normalize(tLabels, axis); + predictions = fops.math.l2Normalize(predictions, axis); Operand mathMul = tf.math.mul(tLabels, predictions); return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); } @@ -651,23 +652,7 @@ private static Operand smoothCategoricalLabels( return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses)); } - // TODO this was tf.math.l2_normalize in TF Python - /** - * Normalizes along dimension axis using an L2 norm. - * - * @param tf The TensorFlow Ops - * @param x the input - * @param axis Dimension along which to normalize. - * @param the data type for the input and the result - * @return the normalized values based on L2 norm - */ - public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { - Operand squareSum = - tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); - Operand invNorm = - tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); - return tf.math.mul(x, invNorm); - } + /** * Converts binary labels into -1/1. 
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java index cecbecfed15..18fb8ada6b7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java @@ -34,6 +34,7 @@ public class FrameworkOps { public final NnOps nn; public final SetsOps sets; + public final MathOps math; /** * Creates a FrameworkOps instance with the provided scope @@ -45,6 +46,7 @@ private FrameworkOps(Scope scope) { this.scope = scope; nn = new NnOps(this); sets = new SetsOps(this); + math = new MathOps(this); } /** @@ -57,6 +59,7 @@ private FrameworkOps(Ops coreOps) { this.scope = coreOps.scope(); nn = new NnOps(this); sets = new SetsOps(this); + math = new MathOps(this); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java new file mode 100644 index 00000000000..57a18fc63c2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.Maximum; +import org.tensorflow.op.math.Mul; +import org.tensorflow.op.math.Rsqrt; +import org.tensorflow.op.math.Square; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +public class MathOps { + private final Scope scope; + + private final FrameworkOps frameworkOps; + + /** + * Creates Framework {@code nn} Operations + * + * @param frameworkOps the TensorFLow framework Ops + */ + MathOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } + + /** + * Normalizes along dimension axis using an L2 norm. + * + * @param x the input + * @param axis Dimension along which to normalize. 
+ * @param the data type for the input and the result + * @return the normalized values based on L2 norm + */ + public Operand l2Normalize(Operand x, int[] axis) { + Operand squareSum = + ReduceSum.create(scope, + Square.create(scope, x), + Constant.vectorOf(scope, axis), + ReduceSum.keepDims(Boolean.TRUE)); + Operand invNorm = + Rsqrt.create(scope, + Maximum.create(scope, squareSum, + Cast.create(scope, + Constant.scalarOf(scope, 1e-12F), x.type()))); + return Mul.create(scope, x, invNorm); + } +} From bd072f4c56b05c007e91ffe68a025b6a0cf03f77 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 27 Mar 2021 18:50:26 -0400 Subject: [PATCH 05/31] Reformat code, fix javadocs --- .../tensorflow/framework/op/FrameworkOps.java | 76 +++-- .../org/tensorflow/framework/op/MathOps.java | 68 ++-- .../org/tensorflow/framework/op/NnOps.java | 312 +++++++++--------- .../op/nn/SigmoidCrossEntropyWithLogits.java | 14 +- .../op/nn/SoftmaxCrossEntropyWithLogits.java | 3 +- .../SparseSoftmaxCrossEntropyWithLogits.java | 52 +-- 6 files changed, 271 insertions(+), 254 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java index 18fb8ada6b7..c8b234f2c51 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java @@ -30,11 +30,10 @@ */ public class FrameworkOps { public final Ops coreOps; - private final Scope scope; - public final NnOps nn; public final SetsOps sets; public final MathOps math; + private final Scope scope; /** * Creates a FrameworkOps instance with the provided scope @@ -62,8 +61,43 @@ private FrameworkOps(Ops coreOps) { math = new MathOps(this); } + /** + * Creates an API for building operations in the provided execution environment + * + * @param env the exection environment + * @return the FrameworkOps + */ + public static FrameworkOps create(ExecutionEnvironment env) { + return new FrameworkOps(new Scope(env)); + } + + /** + * Creates an API for building operations in the default eager execution environment + * + *

Invoking this method is equivalent to {@code + * FrameworkOps.create(EagerSession.getDefault())}. + * + * @return the FrameworkOps + */ + public static FrameworkOps create() { + return new FrameworkOps(new Scope(EagerSession.getDefault())); + } + + /** + * Creates an API for building operations in the default eager execution environment + * + * @param coreOps the TensorFlow core Ops + * @return the FrameworkOps + */ + public static FrameworkOps create(Ops coreOps) { + return new FrameworkOps(coreOps); + } - /** Returns the current {@link Scope scope} of this API */ + /** + * Returns the current {@link Scope scope} of this API + * + * @return the current {@link Scope scope} of this API + */ public final Scope scope() { return scope; } @@ -81,6 +115,9 @@ public final Ops coreOps() { * Returns an API that builds operations with the provided name prefix. * *

@link Scope#withSubScope(String)} + * + * @param childScopeName the name of the child scope + * @return the FrameworkOps */ public FrameworkOps withSubScope(String childScopeName) { return new FrameworkOps(scope.withSubScope(childScopeName)); @@ -90,6 +127,9 @@ public FrameworkOps withSubScope(String childScopeName) { * Returns an API that uses the provided name for an op. * *

{@link Scope#withName(String)} + * + * @param opName the name of the scope + * @return the FrameworkOps */ public FrameworkOps withName(String opName) { return new FrameworkOps(scope.withName(opName)); @@ -99,6 +139,9 @@ public FrameworkOps withName(String opName) { * Returns an API that places the created operations on the device(s) matching the provided spec. * *

{@link Scope#withDevice(DeviceSpec)} + * + * @param deviceSpec the device specification for the scope + * @return the FrameworkOps */ public FrameworkOps withDevice(DeviceSpec deviceSpec) { return new FrameworkOps(scope.withDevice(deviceSpec)); @@ -108,32 +151,11 @@ public FrameworkOps withDevice(DeviceSpec deviceSpec) { * Returns an API that adds operations to the graph with the provided control dependencies. * *

{@link Scope#withControlDependencies(Iterable)} + * + * @param controls the operations + * @return the FrameworkOps */ public FrameworkOps withControlDependencies(Iterable controls) { return new FrameworkOps(scope.withControlDependencies(controls)); } - - /** Creates an API for building operations in the provided execution environment */ - public static FrameworkOps create(ExecutionEnvironment env) { - return new FrameworkOps(new Scope(env)); - } - - /** - * Creates an API for building operations in the default eager execution environment - * - *

Invoking this method is equivalent to {@code - * FrameworkOps.create(EagerSession.getDefault())}. - */ - public static FrameworkOps create() { - return new FrameworkOps(new Scope(EagerSession.getDefault())); - } - - /** - * Creates an API for building operations in the default eager execution environment - * - * @param coreOps the TensorFlow core Ops - */ - public static FrameworkOps create(Ops coreOps) { - return new FrameworkOps(coreOps); - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java index 57a18fc63c2..5208cde98f3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.op; import org.tensorflow.Operand; -import org.tensorflow.op.Ops; import org.tensorflow.op.Scope; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.ReduceSum; @@ -26,42 +25,41 @@ import org.tensorflow.op.math.Square; import org.tensorflow.types.family.TNumber; -import static org.tensorflow.framework.utils.CastHelper.cast; - public class MathOps { - private final Scope scope; + private final Scope scope; - private final FrameworkOps frameworkOps; + private final FrameworkOps frameworkOps; - /** - * Creates Framework {@code nn} Operations - * - * @param frameworkOps the TensorFLow framework Ops - */ - MathOps(FrameworkOps frameworkOps) { - this.scope = frameworkOps.scope(); - this.frameworkOps = frameworkOps; - } + /** + * Creates Framework {@code nn} Operations + * + * @param frameworkOps the TensorFLow framework Ops + */ + MathOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } - /** - * Normalizes along dimension axis using an L2 norm. - * - * @param x the input - * @param axis Dimension along which to normalize. - * @param the data type for the input and the result - * @return the normalized values based on L2 norm - */ - public Operand l2Normalize(Operand x, int[] axis) { - Operand squareSum = - ReduceSum.create(scope, - Square.create(scope, x), - Constant.vectorOf(scope, axis), - ReduceSum.keepDims(Boolean.TRUE)); - Operand invNorm = - Rsqrt.create(scope, - Maximum.create(scope, squareSum, - Cast.create(scope, - Constant.scalarOf(scope, 1e-12F), x.type()))); - return Mul.create(scope, x, invNorm); - } + /** + * Normalizes along dimension axis using an L2 norm. + * + * @param x the input + * @param axis Dimension along which to normalize. + * @param the data type for the input and the result + * @return the normalized values based on L2 norm + */ + public Operand l2Normalize(Operand x, int[] axis) { + Operand squareSum = + ReduceSum.create( + scope, + Square.create(scope, x), + Constant.vectorOf(scope, axis), + ReduceSum.keepDims(Boolean.TRUE)); + Operand invNorm = + Rsqrt.create( + scope, + Maximum.create( + scope, squareSum, Cast.create(scope, Constant.scalarOf(scope, 1e-12F), x.type()))); + return Mul.create(scope, x, invNorm); + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java index 4054f3ddbb5..0fea3743d95 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java @@ -32,166 +32,164 @@ *

{@link FrameworkOps} */ public class NnOps { - private final Scope scope; + private final Scope scope; - private final FrameworkOps frameworkOps; + private final FrameworkOps frameworkOps; - /** - * Creates Framework {@code nn} Operations - * @param frameworkOps the TensorFLow framework Ops - */ - NnOps(FrameworkOps frameworkOps) { - this.scope = frameworkOps.scope(); - this.frameworkOps = frameworkOps; - } + /** + * Creates Framework {@code nn} Operations + * + * @param frameworkOps the TensorFLow framework Ops + */ + NnOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } - /** - * Computes sigmoid cross entropy given logits. - * - *

Measures the probability error in discrete classification tasks in which each class is - * independent and not mutually exclusive. For instance, one could perform multilabel - * classification where a picture can contain both an elephant and a dog at the same time. - * - *

For brevity, let x = logits, z = labels. The logistic loss in - * pseudo-code is - * - *

-     *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
-     *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
-     *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
-     *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
-     *   = (1 - z) * x + log(1 + exp(-x))
-     *   = x - x * z + log(1 + exp(-x))
-     *  </pre>
-     *
-     * <p>For x < 0, to avoid overflow in exp(-x), we reformulate the above
-     *
-     * <pre>
-     *  x - x * z + log(1 + exp(-x))
-     *   = log(exp(x)) - x * z + log(1 + exp(-x))
-     *   = - x * z + log(1 + exp(x))
-     *  </pre>
-     *
-     * <p>Hence, to ensure stability and avoid overflow, the implementation uses this equivalent
-     * formulation
-     *
-     * <pre>
-     *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
-     *  </pre>
-     *
-     * <p>logits and labels must have the same type and shape.
-     *
-     * <p>
- * - * @param labels the labels - * @param logits the logits of type float32 or float64 - * @param the type of labels and logits - * @return the component-wise logistic losses. - * @throws IllegalArgumentException if logits' and labels' do not have the same shape - */ - public Operand sigmoidCrossEntropyWithLogits(Operand labels, - Operand logits) { - return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); - } - - /** - * Computes softmax cross entropy between logits and labels. - * - *

Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - *

NOTE: - * - *

While the classes are mutually exclusive, their probabilities need not be. All that is - * required is that each row of labels is a valid probability distribution. If they - * are not, the computation of the gradient will be incorrect. - * - *

If using exclusive labels (wherein one and only one class is true at a time), - * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} - * - *

Usage: - * - *

-     *    Operand<TFloat32> logits =
-     *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
-     *    Operand<TFloat32> labels =
-     *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
-     *    Operand<TFloat32> output =
-     *        tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
-     *    // output Shape = [2]
-     *    // dataType = FLOAT (1)
-     *    // values { 0.169846, 0.824745 }
-     *  </pre>
-     *
-     * <p>
Backpropagation will happen into both logits and labels. To - * disallow backpropagation into labels, pass label tensors through - * tf.stopGradient before feeding it to this function. - * - * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] - * , each row of labels[i] must be a valid probability distribution. - * @param logits Per-label activations, typically a linear output. These activation energies are - * interpreted as unnormalized log probabilities. - * @param axis The class dimension. -1 is the last dimension. - * @param the number type of the operands - * @return the softmax cross entropy loss. Its type is the same as logits and its - * shape is the same as labels except that it does not have the last dimension of - * labels. - */ - public Operand softmaxCrossEntropyWithLogits( - Operand labels, Operand logits, int axis) { - return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); - } - - /** - * Computes sparse softmax cross entropy between logits and labels. - * - *

Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - *

NOTE: - * - *

For this operation, the probability of a given label is considered exclusive. That is, soft - * classes are not allowed, and the labels vector must provide a single specific - * index for the true class for each row of logits (each minibatch entry). For soft - * softmax classification with a probability distribution for each entry, {@link - * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. - * - *

WARNING: - * - *

This op expects unscaled logits, since it performs a softmax on logits - * internally for efficiency. Do not call this op with the output of softmax, - * as it will produce incorrect results. - * - *

A common use case is to have logits of shape [batchSize, numClasses] and have - * labels of shape [batchSize], but higher dimensions are supported, in which case - * the dim-th dimension is assumed to be of size numClasses. - * logits must have the dataType of TFloat16, TFloat32 - * , or TFloat64, and labels must have the dtype of TInt32 - * or TInt64. - * - * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r - * is rank of labels and result) and the dataType is TInt32 - * or TInt64. Each entry in labels must be an index in [0, - * numClasses). Other values will raise an exception when this op is run on CPU, and - * return NaN for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., - * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, - * or TFloat64. These activation energies are interpreted as unnormalized log - * probabilities. - * @param The data type for the labels - * @param The data type for the logits and loss - * @return the loss - * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank - * of the labels is not equal to the rank of the logits minus one. - */ - - public Operand sparseSoftmaxCrossEntropyWithLogits( - Operand labels, Operand logits) { - return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); - } + /** + * Computes sigmoid cross entropy given {@code logits}. + * + *

Measures the probability error in discrete classification tasks in which each class is + * independent and not mutually exclusive. For instance, one could perform multilabel + * classification where a picture can contain both an elephant and a dog at the same time. + * + *

For brevity, let {@code x = logits}, {@code z = labels}. The logistic loss in pseudo-code is + * + *

+   *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+   *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+   *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+   *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+   *   = (1 - z) * x + log(1 + exp(-x))
+   *   = x - x * z + log(1 + exp(-x))
+   *  </pre>
+   *
+   * <p>For {@code x < 0}, to avoid overflow in {@code exp(-x)}, we reformulate the above
+   *
+   * <pre>
+   *  x - x * z + log(1 + exp(-x))
+   *   = log(exp(x)) - x * z + log(1 + exp(-x))
+   *   = - x * z + log(1 + exp(x))
+   *  </pre>
+   *
+   * <p>Hence, to ensure stability and avoid overflow, the implementation uses this equivalent
+   * formulation
+   *
+   * <pre>
+   *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
+   *  </pre>
+   *
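+   * <p>As a quick numeric check of the stable form (hypothetical values {@code x = -10}, {@code z = 1}):
+   * {@code max(-10, 0) - (-10 * 1) + log(1 + exp(-10))} ≈ 10.0000454, which matches
+   * {@code -log(sigmoid(-10))} while only ever evaluating {@code exp(-10)}.
+   *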

{@code logits} and {@code labels} must have the same type and shape. + * + *

+ * + * @param labels the labels + * @param logits the logits of type float32 or float64 + * @param the type of labels and logits + * @return the component-wise logistic losses. + * @throws IllegalArgumentException if logits' and labels' do not have the same shape + */ + public Operand sigmoidCrossEntropyWithLogits( + Operand labels, Operand logits) { + return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); + } + /** + * Computes softmax cross entropy between {@code logits} and {@code labels}. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of {@code labels} is a valid probability distribution. If they are + * not, the computation of the gradient will be incorrect. + * + *

If using exclusive {@code labels} (wherein one and only one class is true at a time), see + * {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} + * + *

Usage: + * + *

+   *    Operand<TFloat32> logits =
+   *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
+   *    Operand<TFloat32> labels =
+   *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
+   *    Operand<TFloat32> output =
+   *        tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+   *    // output Shape = [2]
+   *    // dataType = FLOAT (1)
+   *    // values { 0.169846, 0.824745 }
+   *  </pre>
+   *
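+   * <p>A hand check of the first output row above: {@code softmax([4, 2, 1])} ≈
+   * {@code [0.844, 0.114, 0.042]}, and {@code -log(0.844)} ≈ 0.1698, matching the first value in
+   * {@code output}.
+   *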

Backpropagation will happen into both {@code logits} and {@code labels}. To disallow + * backpropagation into {@code labels}, pass label tensors through {@code tf.stopGradient} before + * feeding it to this function. + * + * @param labels Each vector along the class dimension should hold a valid probability + * distribution e.g. for the case in which labels are of shape {@code [batch_size, + * num_classes] }, each row of {@code labels[i]} must be a valid probability distribution. + * @param logits Per-label activations, typically a linear output. These activation energies are + * interpreted as unnormalized log probabilities. + * @param axis The class dimension. -1 is the last dimension. + * @param the number type of the operands + * @param the data type for the labels. + * @return the softmax cross entropy loss. Its type is the same as {@code logits} and its shape is + * the same as {@code labels} except that it does not have the last dimension of {@code + * labels}. + * + */ + public Operand softmaxCrossEntropyWithLogits( + Operand labels, Operand logits, int axis) { + return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + } + /** + * Computes sparse softmax cross entropy between {@code logits} and {@code labels}. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

For this operation, the probability of a given label is considered exclusive. That is, soft + * classes are not allowed, and the {@code labels} vector must provide a single specific index for + * the true class for each row of {@code logits} (each minibatch entry). For soft softmax + * classification with a probability distribution for each entry, {@link + * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. + * + *

WARNING: + * + *

+   * <p>This op expects unscaled logits, since it performs a {@code softmax} on {@code logits}
+   * internally for efficiency. Do not call this op with the output of {@code softmax}, as it will
+   * produce incorrect results.
+   *
A common use case is to have logits of shape {@code [batchSize, numClasses]} and have labels + * of shape {@code [batchSize]}, but higher dimensions are supported, in which case the {@code + * dim}-th dimension is assumed to be of size {@code numClasses}. {@code logits} must have the + * {@code dataType} of {@code TFloat16}, {@code TFloat32} , or {@code TFloat64}, and {@code + * labels} must have the dtype of {@code TInt32} or {@code TInt64}. + * + * @param labels {@code Tensor} of shape {@code [d_0, d_1, ..., d_{r-1}]} (where {@code r } is + * rank of {@code labels} and result) and the dataType is {@code TInt32} or {@code TInt64}. + * Each entry in {@code labels} must be an index in {@code [0, numClasses)}. Other values will + * raise an exception when this op is run on CPU, and return {@code NaN} for corresponding + * loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape {@code [d_0, d_1, ..., + * d_{r-1}, numClasses]} and dataType of {@code TFloat16}, {@code TFloat32}, or {@code + * TFloat64}. These activation energies are interpreted as unnormalized log probabilities. + * @param The data type for the labels + * @param The data type for the logits and loss + * @return the loss + * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if the rank + * of the labels is not equal to the rank of the logits minus one. + */ + public Operand sparseSoftmaxCrossEntropyWithLogits( + Operand labels, Operand logits) { + return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits( + scope, labels, logits); + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java index b55385839d3..fc3f7739363 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java @@ -3,8 +3,6 @@ import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.Select; import org.tensorflow.op.core.ZerosLike; import org.tensorflow.op.dtypes.Cast; @@ -18,17 +16,17 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -//@Operator(group = "nn") +// @Operator(group = "nn") public class SigmoidCrossEntropyWithLogits { /** - * Computes sigmoid cross entropy given logits. + * Computes sigmoid cross entropy given {@code logits}. * *

Measures the probability error in discrete classification tasks in which each class is * independent and not mutually exclusive. For instance, one could perform multilabel * classification where a picture can contain both an elephant and a dog at the same time. * - *

For brevity, let x = logits, z = labels. The logistic loss in + *

For brevity, let {@code x = logits}, {@code z = labels}. The logistic loss in * pseudo-code is * *

@@ -40,7 +38,7 @@ public class SigmoidCrossEntropyWithLogits {
    *  = x - x * z + log(1 + exp(-x))
   *  </pre>
   *
-  * <p>For x < 0, to avoid overflow in exp(-x), we reformulate the above
+  * <p>For {@code x < 0}, to avoid overflow in {@code exp(-x)}, we reformulate the above
   *
   * <pre>
    * x - x * z + log(1 + exp(-x))
@@ -55,7 +53,7 @@ public class SigmoidCrossEntropyWithLogits {
    *   max(x, 0) - x * z + log(1 + exp(-abs(x)))
   *  </pre>
   *
-  * <p>logits and labels must have the same type and shape.
+  * <p>{@code logits} and {@code labels} must have the same type and shape.
   *
   * <p>
* @@ -66,7 +64,7 @@ public class SigmoidCrossEntropyWithLogits { * @return the component-wise logistic losses. * @throws IllegalArgumentException if logits' and labels' do not have the same shape */ - //@Endpoint(name = "sigmoidCrossEntropyWithLogits") + // @Endpoint(name = "sigmoidCrossEntropyWithLogits") public static Operand sigmoidCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits) { if (!isCompatible(labels.shape(), logits.shape())) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java index 0f5b8197f1e..7d59941f27a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -66,7 +66,8 @@ public class SoftmaxCrossEntropyWithLogits { * @param logits Per-label activations, typically a linear output. These activation energies are * interpreted as unnormalized log probabilities. * @param axis The class dimension. -1 is the last dimension. - * @param the number type of the operands + * @param the data type for the logits and return operand + * @param the data type for the labels * @return the softmax cross entropy loss. Its type is the same as logits and its * shape is the same as labels except that it does not have the last dimension of * labels. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 75766cf9bfb..0b2d29d6092 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -25,7 +25,7 @@ public class SparseSoftmaxCrossEntropyWithLogits { /** - * Computes sparse softmax cross entropy between logits and labels. + * Computes sparse softmax cross entropy between {@code logits} and {@code labels}. * *

Measures the probability error in discrete classification tasks in which the classes are * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is @@ -34,45 +34,45 @@ public class SparseSoftmaxCrossEntropyWithLogits { *

NOTE: * *

For this operation, the probability of a given label is considered exclusive. That is, soft - * classes are not allowed, and the labels vector must provide a single specific - * index for the true class for each row of logits (each minibatch entry). For soft + * classes are not allowed, and the {@code labels} vector must provide a single specific + * index for the true class for each row of {@code logits} (each minibatch entry). For soft * softmax classification with a probability distribution for each entry, {@link * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. * *

WARNING: * - *

This op expects unscaled logits, since it performs a softmax on logits - * internally for efficiency. Do not call this op with the output of softmax, + *

This op expects unscaled logits, since it performs a {@code softmax} on {@code logits + * } internally for efficiency. Do not call this op with the output of {@code softmax}, * as it will produce incorrect results. * - *

A common use case is to have logits of shape [batchSize, numClasses] and have - * labels of shape [batchSize], but higher dimensions are supported, in which case - * the dim-th dimension is assumed to be of size numClasses. - * logits must have the dataType of TFloat16, TFloat32 - * , or TFloat64, and labels must have the dtype of TInt32 - * or TInt64. + *

A common use case is to have logits of shape {@code [batchSize, numClasses]} and have + * labels of shape {@code [batchSize]}, but higher dimensions are supported, in which case + * the {@code dim}-th dimension is assumed to be of size {@code numClasses}. {@code + * logits} must have the {@code dataType} of {@code TFloat16}, {@code TFloat32} + * , or {@code TFloat64}, and {@code labels} must have the dtype of {@code TInt32} + * or {@code TInt64}. * * @param scope current scope - * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r - * is rank of labels and result) and the dataType is TInt32 - * or TInt64. Each entry in labels must be an index in [0, - * numClasses). Other values will raise an exception when this op is run on CPU, and - * return NaN for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., - * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, - * or TFloat64. These activation energies are interpreted as unnormalized log + * @param labels {@code Tensor} of shape {@code [d_0, d_1, ..., d_{r-1}]} (where {@code r + * } is rank of {@code labels} and result) and the dataType is {@code TInt32} + * or {@code TInt64}. Each entry in {@code labels} must be an index in {@code [0, + * numClasses)}. Other values will raise an exception when this op is run on CPU, and + * return {@code NaN} for corresponding loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape {@code [d_0, d_1, ..., + * d_{r-1}, numClasses]} and dataType of {@code TFloat16}, {@code TFloat32}, + * or {@code TFloat64}. These activation energies are interpreted as unnormalized log * probabilities. - * @param the data type for the labels - * @param the data tyoe for the loss and logits. + * @param the data type for the labels + * @param the data tyoe for the loss and logits. * @return the loss - * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank + * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if the rank * of the labels is not equal to the rank of the logits minus one. */ @SuppressWarnings("unchecked") @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits") public static - Operand sparseSoftmaxCrossEntropyWithLogits( - Scope scope, Operand labels, Operand logits) { + Operand sparseSoftmaxCrossEntropyWithLogits( + Scope scope, Operand labels, Operand logits) { scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); Operand preciseLogits; if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { @@ -119,7 +119,7 @@ Operand sparseSoftmaxCrossEntropyWithLogits( return Cast.create(scope, cost, logits.type()); } else { // Unchecked cast already checked with previous if - return (Operand) cost; + return (Operand) cost; } } @@ -160,7 +160,7 @@ Operand sparseSoftmaxCrossEntropyWithLogits( return Cast.create(scope, cost, logits.type()); } else { // Unchecked cast already checked with previous if - return (Operand) cost; + return (Operand) cost; } } } From d29262b5a3169b6ac7f58890661138910ef6ac4f Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 16 Apr 2021 18:04:30 -0400 Subject: [PATCH 06/31] Add confusionMatrix() method. 
add Unit test --- .../org/tensorflow/framework/op/MathOps.java | 301 +++++++++++++ .../tensorflow/framework/op/MathOpsTest.java | 413 ++++++++++++++++++ 2 files changed, 714 insertions(+) create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java index 5208cde98f3..36f5b692cab 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java @@ -15,16 +15,37 @@ package org.tensorflow.framework.op; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; +import org.tensorflow.op.core.AssertThat; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Identity; +import org.tensorflow.op.core.OnesLike; +import org.tensorflow.op.core.Range; +import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.ReduceAll; +import org.tensorflow.op.core.ReduceMax; import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.op.core.ScatterNd; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.op.core.Stack; +import org.tensorflow.op.core.Zeros; import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.GreaterEqual; +import org.tensorflow.op.math.Less; import org.tensorflow.op.math.Maximum; import org.tensorflow.op.math.Mul; import org.tensorflow.op.math.Rsqrt; import org.tensorflow.op.math.Square; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; +import java.util.Arrays; +import java.util.Collections; + public class MathOps { private final Scope scope; @@ -62,4 +83,284 @@ public Operand l2Normalize(Operand x, int[] axis) { scope, squareSum, Cast.create(scope, Constant.scalarOf(scope, 1e-12F), x.type()))); return Mul.create(scope, x, invNorm); } + + /** + * Computes the confusion matrix from predictions and labels. + * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid + * labels for a given classification task. Both prediction and labels must be 1-D arrays of the + * same shape in order for this function to work. + * + *

If `num_classes` is `None`, then `num_classes` will be set to one plus the maximum value in + * either predictions or labels. Class labels are expected to start at 0. For example, if + * `num_classes` is 3, then the possible labels would be `[0, 1, 2]`. + * + *

If `weights` is not `None`, then each prediction contributes its corresponding weight to the + * total value of the confusion matrix cell. + * + *

For example: + * + *

+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
+   *         [[0 0 0 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 0 0 0]
+   *          [0 0 0 0 1]]
+   * </pre>
+   *
+   * <p>
Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param labels 1-D Operand of real labels for the classification task. + * @param predictions 1-D Operand of predictions for a given classification. + * @param Data type of the confusion matrix. + * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion + * matrix, where {@code n} is the number of possible labels in the classification task. + * @throws IllegalArgumentException If both predictions and labels are not 1-D vectors and have + * mismatched shapes, or if {@code weights} is not null and its shape doesn't match {@code + * predictions}. + */ + public Operand confusionMatrix(Operand labels, Operand predictions) { + return confusionMatrix(labels, predictions, null, null, labels.type()); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid + * labels for a given classification task. Both prediction and labels must be 1-D arrays of the + * same shape in order for this function to work. + * + *

If `num_classes` is `None`, then `num_classes` will be set to one plus the maximum value in + * either predictions or labels. Class labels are expected to start at 0. For example, if + * `num_classes` is 3, then the possible labels would be `[0, 1, 2]`. + * + *

If `weights` is not `None`, then each prediction contributes its corresponding weight to the + * total value of the confusion matrix cell. + * + *

For example: + * + *

+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
+   *         [[0 0 0 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 0 0 0]
+   *          [0 0 0 0 1]]
+   * </pre>
+   *
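+   * <p>With {@code weights} (a hypothetical variant of the example above), each matching
+   * prediction adds its weight instead of 1:
+   *
+   * <pre>
+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4}),
+   *         tf.constant(new int[] {10, 20, 30})) ==>
+   *         [[ 0  0  0  0  0]
+   *          [ 0  0 10  0  0]
+   *          [ 0  0 20  0  0]
+   *          [ 0  0  0  0  0]
+   *          [ 0  0  0  0 30]]
+   * </pre>
+   *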

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param labels 1-D Operand of real labels for the classification task. + * @param predictions 1-D Operand of predictions for a given classification. + * @param weights An optional Operand whose shape matches {@code predictions}. + * @param Data type of the confusion matrix. + * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion + * matrix, where {@code n} is the number of possible labels in the classification task. + * @throws IllegalArgumentException If both predictions and labels are not 1-D vectors and have + * mismatched shapes, or if {@code weights} is not null and its shape doesn't match {@code + * predictions}. + */ + public Operand confusionMatrix( + Operand labels, Operand predictions, Operand weights) { + return confusionMatrix(labels, predictions, weights, null, labels.type()); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid + * labels for a given classification task. Both prediction and labels must be 1-D arrays of the + * same shape in order for this function to work. + * + *

If `num_classes` is `None`, then `num_classes` will be set to one plus the maximum value in + * either predictions or labels. Class labels are expected to start at 0. For example, if + * `num_classes` is 3, then the possible labels would be `[0, 1, 2]`. + * + *

If `weights` is not `None`, then each prediction contributes its corresponding weight to the + * total value of the confusion matrix cell. + * + *

For example: + * + *

+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
+   *         [[0 0 0 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 0 0 0]
+   *          [0 0 0 0 1]]
+   * </pre>
+   *
Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param labels 1-D Operand of real labels for the classification task. + * @param predictions 1-D Operand of predictions for a given classification. + * @param weights An optional Operand whose shape matches {@code predictions}. + * @param numClasses The possible number of labels the classification task can have. If this value + * is null, it will be calculated using both predictions and labels. + * @param type Data type of the confusion matrix. + * @param Data type of the confusion matrix. + * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion + * matrix, where {@code n} is the number of possible labels in the classification task. + * @throws IllegalArgumentException If both predictions and labels are not 1-D vectors and have + * mismatched shapes, or if {@code weights} is not null and its shape doesn't match {@code + * predictions}. + */ + public Operand confusionMatrix( + Operand labels, + Operand predictions, + Operand weights, + Operand numClasses, + Class type) { + Scope lScope = scope.withSubScope("confusionMatrix"); + LossTuple tuple = removeSqueezableDimensions(labels, predictions, 0); + Operand lLabels = Cast.create(lScope, tuple.getLabels(), TInt64.class); + Operand lPredictions = Cast.create(lScope, tuple.getTarget(), TInt64.class); + + Operand zero = Constant.scalarOf(lScope, 0L); + Operand one = Constant.scalarOf(lScope, 1L); + + AssertThat labelsNonNegative = + AssertThat.create( + lScope, + ReduceAll.create(lScope, GreaterEqual.create(lScope, lLabels, zero), allAxes(lLabels)), + Collections.singletonList( + Constant.scalarOf(lScope, "labels contains negative values"))); + lLabels = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(labelsNonNegative)), lLabels); + + AssertThat predictionsNonNegative = + AssertThat.create( + lScope, + ReduceAll.create( + lScope, GreaterEqual.create(lScope, lPredictions, zero), allAxes(lPredictions)), + Collections.singletonList( + Constant.scalarOf(lScope, "predictions contains negative values"))); + lPredictions = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(predictionsNonNegative)), + lPredictions); + + Operand lNumClasses; + if (numClasses == null) { + lNumClasses = + Add.create( + lScope, + Maximum.create( + lScope, + ReduceMax.create(lScope, lPredictions, zero), + ReduceMax.create(lScope, lLabels, zero)), + one); + } else { + lNumClasses = Cast.create(lScope, numClasses, TInt64.class); + AssertThat labelsLess = + AssertThat.create( + lScope, + Less.create(lScope, lLabels, lNumClasses), + Collections.singletonList(Constant.scalarOf(lScope, "labels out of bounds"))); + lLabels = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(labelsLess)), lLabels); + + AssertThat predictionsLess = + AssertThat.create( + lScope, + Less.create(lScope, lPredictions, lNumClasses), + Collections.singletonList(Constant.scalarOf(lScope, "predictions out of bounds"))); + lPredictions = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(predictionsLess)), + lPredictions); + } + + if (weights != null) { + if (!predictions.shape().isCompatibleWith(weights.shape())) { + throw new IllegalArgumentException( + String.format( + "predictions.shape() [%s], is not compatible with weights.shape() [ %s].", + predictions.shape(), weights.shape())); + } + } + + Operand shape = Stack.create(lScope, 
Arrays.asList(lNumClasses, lNumClasses)); + Operand indices = + Stack.create(lScope, Arrays.asList(lLabels, lPredictions), Stack.axis(1L)); + Operand values = weights == null ? OnesLike.create(lScope, predictions) : weights; + Operand zeroMatrix = Zeros.create(lScope, Cast.create(lScope, shape, TInt32.class), type); + + return ScatterNd.create(lScope, indices, values, shape); + } + + /** + * Squeeze last dim if ranks differ from expected by exactly 1. + * + * @param labels Label values, a Operand whose dimensions match predictions + * . + * @param predictions Predicted values, a Tensor of arbitrary dimensions. + * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). + * @param the data type for the labels, predictions and result + * @return labels and predictions, possibly with last dim squeezed. + */ + public LossTuple removeSqueezableDimensions( + Operand labels, Operand predictions, int expectedRankDiff) { + Scope lScope = scope.withSubScope("removeSqueezableDimensions"); + Shape predictionsShape = predictions.shape(); + int predictionsRank = predictionsShape.numDimensions(); + Shape labelsShape = labels.shape(); + int labelsRank = labelsShape.numDimensions(); + + if (predictionsRank != Shape.UNKNOWN_SIZE || labelsRank != Shape.UNKNOWN_SIZE) { + // Use static rank. + int rankDiff = predictionsRank - labelsRank; + if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.size(-1), 1)) { + predictions = Squeeze.create(lScope, predictions); + } else if (rankDiff == expectedRankDiff - 1 && Shape.isCompatible(labelsShape.size(-1), 1)) { + labels = Squeeze.create(lScope, labels); + } + return new LossTuple<>(labels, predictions); + } + // Use dynamic rank. + + // TODO: hold for lazy select feature, + // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { + /* + * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze + * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), + * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); * + */ + predictions = + Squeeze.create(lScope, predictions, Squeeze.axis(Collections.singletonList(-1L))); + } + if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1)) { + /* + * TODO, if we ever get a select that does lazy evaluation labels = tf.select( + * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), tf.squeeze(labels, + * Squeeze.axis(Arrays.asList(-1L))), predictions ); * + */ + labels = Squeeze.create(lScope, labels, Squeeze.axis(Collections.singletonList(-1L))); + } + return new LossTuple<>(labels, predictions); + } + + public Operand allAxes(Operand op) { + int rank = op.shape().numDimensions(); + if (rank != Shape.UNKNOWN_SIZE) { + int[] axes = new int[rank]; + for (int i = 0; i < rank; i++) { + axes[i] = i; + } + return Constant.vectorOf(scope, axes); + } else { + return Range.create( + scope, Constant.scalarOf(scope, 0), Rank.create(scope, op), Constant.scalarOf(scope, 1)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java new file mode 100644 index 00000000000..326e3cdc2d1 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java @@ -0,0 +1,413 @@ +package org.tensorflow.framework.op; + +import 
org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +class MathOpsTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + double[][][] array = + new double[][][] { + { + {4.17021990e-01, 7.20324516e-01, 1.14374816e-04}, + {3.02332580e-01, 1.46755889e-01, 9.23385918e-02}, + {1.86260208e-01, 3.45560730e-01, 3.96767467e-01}, + {5.38816750e-01, 4.19194520e-01, 6.85219526e-01}, + {2.04452246e-01, 8.78117442e-01, 2.73875929e-02}, + {6.70467496e-01, 4.17304814e-01, 5.58689833e-01}, + {1.40386939e-01, 1.98101491e-01, 8.00744593e-01} + }, + { + {9.68261600e-01, 3.13424170e-01, 6.92322612e-01}, + {8.76389146e-01, 8.94606650e-01, 8.50442126e-02}, + {3.90547849e-02, 1.69830427e-01, 8.78142476e-01}, + {9.83468369e-02, 4.21107620e-01, 9.57889557e-01}, + {5.33165276e-01, 6.91877127e-01, 3.15515637e-01}, + {6.86500907e-01, 8.34625661e-01, 1.82882771e-02}, + {7.50144303e-01, 9.88861084e-01, 7.48165667e-01} + }, + { + {2.80443996e-01, 7.89279342e-01, 1.03226006e-01}, + {4.47893530e-01, 9.08595502e-01, 2.93614149e-01}, + {2.87775338e-01, 1.30028576e-01, 1.93669572e-02}, + {6.78835511e-01, 2.11628109e-01, 2.65546650e-01}, + {4.91573155e-01, 5.33625446e-02, 5.74117601e-01}, + {1.46728575e-01, 5.89305520e-01, 6.99758351e-01}, + {1.02334432e-01, 4.14055973e-01, 6.94400132e-01} + }, + { + {4.14179265e-01, 4.99534607e-02, 5.35896420e-01}, + {6.63794637e-01, 5.14889121e-01, 9.44594741e-01}, + {5.86555064e-01, 9.03401911e-01, 1.37474701e-01}, + {1.39276341e-01, 8.07391286e-01, 3.97676826e-01}, + {1.65354192e-01, 9.27508593e-01, 3.47765863e-01}, + {7.50812113e-01, 7.25997984e-01, 8.83306086e-01}, + {6.23672187e-01, 7.50942409e-01, 3.48898351e-01} + }, + { + {2.69927889e-01, 8.95886242e-01, 4.28091198e-01}, + {9.64840055e-01, 6.63441479e-01, 6.21695697e-01}, + {1.14745975e-01, 9.49489236e-01, 4.49912131e-01}, + {5.78389585e-01, 4.08136815e-01, 2.37026975e-01}, + {9.03379500e-01, 5.73679507e-01, 2.87032709e-03}, + {6.17144942e-01, 3.26644897e-01, 5.27058125e-01}, + {8.85942101e-01, 3.57269764e-01, 9.08535123e-01} + }, + { + {6.23360097e-01, 1.58212427e-02, 9.29437220e-01}, + {6.90896928e-01, 9.97322857e-01, 1.72340512e-01}, + {1.37135744e-01, 9.32595491e-01, 6.96818173e-01}, + {6.60001710e-02, 7.55463064e-01, 7.53876209e-01}, + {9.23024535e-01, 7.11524785e-01, 1.24270961e-01}, + {1.98801346e-02, 2.62109861e-02, 2.83064879e-02}, + {2.46211067e-01, 8.60027969e-01, 5.38831055e-01} + }, + { + {5.52821994e-01, 8.42030883e-01, 1.24173313e-01}, + {2.79183686e-01, 5.85759282e-01, 9.69595730e-01}, + {5.61030209e-01, 1.86472889e-02, 8.00632656e-01}, + {2.32974276e-01, 8.07105184e-01, 3.87860656e-01}, + {8.63541842e-01, 7.47121632e-01, 5.56240261e-01}, + {1.36455223e-01, 5.99176884e-02, 1.21343456e-01}, + {4.45518792e-02, 1.07494131e-01, 2.25709334e-01} + }, + { + {7.12988973e-01, 5.59717000e-01, 1.25559801e-02}, + {7.19742775e-02, 9.67276335e-01, 5.68100452e-01}, + {2.03293234e-01, 2.52325743e-01, 7.43825853e-01}, + {1.95429474e-01, 5.81358910e-01, 9.70019996e-01}, + {8.46828818e-01, 2.39847764e-01, 4.93769705e-01}, + {6.19955719e-01, 8.28980923e-01, 1.56791389e-01}, + {1.85762029e-02, 7.00221434e-02, 4.86345112e-01} + }, + { + {6.06329441e-01, 5.68851411e-01, 3.17362398e-01}, + {9.88616168e-01, 5.79745233e-01, 3.80141169e-01}, + {5.50948203e-01, 7.45334446e-01, 
6.69232905e-01}, + {2.64919549e-01, 6.63348362e-02, 3.70084196e-01}, + {6.29717529e-01, 2.10174009e-01, 7.52755582e-01}, + {6.65364787e-02, 2.60315090e-01, 8.04754555e-01}, + {1.93434283e-01, 6.39460862e-01, 5.24670303e-01} + }, + { + {9.24807966e-01, 2.63296783e-01, 6.59610927e-02}, + {7.35065937e-01, 7.72178054e-01, 9.07815874e-01}, + {9.31972086e-01, 1.39515726e-02, 2.34362081e-01}, + {6.16778374e-01, 9.49016333e-01, 9.50176120e-01}, + {5.56653202e-01, 9.15606380e-01, 6.41566217e-01}, + {3.90007704e-01, 4.85990673e-01, 6.04310513e-01}, + {5.49547911e-01, 9.26181436e-01, 9.18733418e-01} + }, + { + {3.94875616e-01, 9.63262558e-01, 1.73955664e-01}, + {1.26329526e-01, 1.35079160e-01, 5.05662143e-01}, + {2.15248056e-02, 9.47970212e-01, 8.27115476e-01}, + {1.50189810e-02, 1.76196262e-01, 3.32063586e-01}, + {1.30996838e-01, 8.09490681e-01, 3.44736665e-01}, + {9.40107465e-01, 5.82014203e-01, 8.78831983e-01}, + {8.44734430e-01, 9.05392289e-01, 4.59880263e-01} + }, + { + {5.46346843e-01, 7.98603594e-01, 2.85718858e-01}, + {4.90253508e-01, 5.99110305e-01, 1.55332759e-02}, + {5.93481421e-01, 4.33676362e-01, 8.07360530e-01}, + {3.15244794e-01, 8.92888725e-01, 5.77857196e-01}, + {1.84010208e-01, 7.87929237e-01, 6.12031162e-01}, + {5.39092720e-02, 4.20193672e-01, 6.79068863e-01}, + {9.18601751e-01, 4.02024889e-04, 9.76759136e-01} + }, + { + {3.76580328e-01, 9.73783553e-01, 6.04716122e-01}, + {8.28845799e-01, 5.74711502e-01, 6.28076196e-01}, + {2.85576284e-01, 5.86833358e-01, 7.50021756e-01}, + {8.58313859e-01, 7.55082190e-01, 6.98057234e-01}, + {8.64479423e-01, 3.22681010e-01, 6.70788765e-01}, + {4.50873941e-01, 3.82102758e-01, 4.10811365e-01}, + {4.01479572e-01, 3.17383945e-01, 6.21919394e-01} + }, + { + {4.30247277e-01, 9.73802090e-01, 6.77800894e-01}, + {1.98569894e-01, 4.26701009e-01, 3.43346238e-01}, + {7.97638834e-01, 8.79998267e-01, 9.03841972e-01}, + {6.62719786e-01, 2.70208269e-01, 2.52366692e-01}, + {8.54897916e-01, 5.27714670e-01, 8.02161098e-01}, + {5.72488546e-01, 7.33142555e-01, 5.19011617e-01}, + {7.70883918e-01, 5.68857968e-01, 4.65709865e-01} + }, + { + {3.42688918e-01, 6.82093501e-02, 3.77924174e-01}, + {7.96260759e-02, 9.82817113e-01, 1.81612849e-01}, + {8.11858714e-01, 8.74961674e-01, 6.88413262e-01}, + {5.69494426e-01, 1.60971433e-01, 4.66880023e-01}, + {3.45172048e-01, 2.25039959e-01, 5.92511892e-01}, + {3.12269837e-01, 9.16305542e-01, 9.09635544e-01}, + {2.57118285e-01, 1.10891297e-01, 1.92962736e-01} + }, + { + {4.99584168e-01, 7.28585660e-01, 2.08194435e-01}, + {2.48033553e-01, 8.51671875e-01, 4.15848732e-01}, + {6.16685092e-01, 2.33666137e-01, 1.01967260e-01}, + {5.15857041e-01, 4.77140993e-01, 1.52671650e-01}, + {6.21806204e-01, 5.44010103e-01, 6.54137373e-01}, + {1.44545540e-01, 7.51527846e-01, 2.22049147e-01}, + {5.19351840e-01, 7.85296023e-01, 2.23304275e-02} + }, + { + {3.24362457e-01, 8.72922361e-01, 8.44709635e-01}, + {5.38440585e-01, 8.66608262e-01, 9.49805975e-01}, + {8.26407015e-01, 8.54115427e-01, 9.87434015e-02}, + {6.51304305e-01, 7.03516960e-01, 6.10240817e-01}, + {7.99615264e-01, 3.45712192e-02, 7.70238757e-01}, + {7.31728613e-01, 2.59698391e-01, 2.57069290e-01}, + {6.32303298e-01, 3.45297456e-01, 7.96588659e-01} + }, + { + {4.46146220e-01, 7.82749414e-01, 9.90471780e-01}, + {3.00248325e-01, 1.43005833e-01, 9.01308417e-01}, + {5.41559398e-01, 9.74740386e-01, 6.36604428e-01}, + {9.93912995e-01, 5.46070814e-01, 5.26425958e-01}, + {1.35427907e-01, 3.55705172e-01, 2.62185670e-02}, + {1.60395175e-01, 7.45637178e-01, 3.03996895e-02}, + {3.66543084e-01, 
8.62346232e-01, 6.92677736e-01} + }, + { + {6.90942168e-01, 1.88636795e-01, 4.41904277e-01}, + {5.81577420e-01, 9.89751697e-01, 2.03906223e-01}, + {2.47732908e-01, 2.62173086e-01, 7.50172436e-01}, + {4.56975341e-01, 5.69294393e-02, 5.08516252e-01}, + {2.11960167e-01, 7.98604250e-01, 2.97331393e-01}, + {2.76060123e-02, 5.93432426e-01, 8.43840420e-01}, + {3.81016135e-01, 7.49858320e-01, 5.11141479e-01} + }, + { + {5.40951788e-01, 9.59434330e-01, 8.03960919e-01}, + {3.23230661e-02, 7.09387243e-01, 4.65001494e-01}, + {9.47548926e-01, 2.21432731e-01, 2.67072022e-01}, + {8.14739615e-02, 4.28618819e-01, 1.09018765e-01}, + {6.33786738e-01, 8.02963257e-01, 6.96800470e-01}, + {7.66211390e-01, 3.42454106e-01, 8.45851481e-01}, + {4.28768784e-01, 8.24009895e-01, 6.26496136e-01} + } + }; + + double[][][] expectedArray = { + { + {3.45350616e-02, 5.96526116e-02, 9.47178160e-06}, + {2.50372272e-02, 1.21533722e-02, 7.64688430e-03}, + {1.54248644e-02, 2.86171008e-02, 3.28577124e-02}, + {4.46213149e-02, 3.47149745e-02, 5.67454435e-02}, + {1.69314109e-02, 7.27199987e-02, 2.26806314e-03}, + {5.55237755e-02, 3.45584825e-02, 4.62670736e-02}, + {1.16259372e-02, 1.64054818e-02, 6.63124844e-02} + }, + { + {8.01851526e-02, 2.59557609e-02, 5.73336743e-02}, + {7.25768730e-02, 7.40855262e-02, 7.04281079e-03}, + {3.23426444e-03, 1.40642561e-02, 7.27220699e-02}, + {8.14444851e-03, 3.48734073e-02, 7.93262124e-02}, + {4.41532955e-02, 5.72967827e-02, 2.61289626e-02}, + {5.68515584e-02, 6.91182911e-02, 1.51451665e-03}, + {6.21220917e-02, 8.18910673e-02, 6.19582348e-02} + }, + { + {2.32245550e-02, 6.53630048e-02, 8.54850933e-03}, + {3.70916426e-02, 7.52439946e-02, 2.43152231e-02}, + {2.38316897e-02, 1.07681248e-02, 1.60384597e-03}, + {5.62167615e-02, 1.75256692e-02, 2.19908543e-02}, + {4.07089069e-02, 4.41914052e-03, 4.75447029e-02}, + {1.21511100e-02, 4.88024652e-02, 5.79494536e-02}, + {8.47467501e-03, 3.42894346e-02, 5.75057231e-02} + }, + { + {3.42996456e-02, 4.13682219e-03, 4.43794727e-02}, + {5.49711734e-02, 4.26397808e-02, 7.82252178e-02}, + {4.85746935e-02, 7.48138949e-02, 1.13847647e-02}, + {1.15339644e-02, 6.68629184e-02, 3.29330191e-02}, + {1.36935636e-02, 7.68102556e-02, 2.87997164e-02}, + {6.21773973e-02, 6.01224527e-02, 7.31496885e-02}, + {5.16484901e-02, 6.21881858e-02, 2.88935024e-02} + }, + { + {2.23536789e-02, 7.41914958e-02, 3.54517400e-02}, + {7.99018070e-02, 5.49419262e-02, 5.14848121e-02}, + {9.50251892e-03, 7.86305517e-02, 3.72588076e-02}, + {4.78984788e-02, 3.37992460e-02, 1.96290389e-02}, + {7.48120397e-02, 4.75084223e-02, 2.37701897e-04}, + {5.11079468e-02, 2.70506144e-02, 4.36475389e-02}, + {7.33679906e-02, 2.95867678e-02, 7.52389953e-02} + }, + { + {5.16226478e-02, 1.31021289e-03, 7.69699737e-02}, + {5.72156087e-02, 8.25918168e-02, 1.42721254e-02}, + {1.13566946e-02, 7.72315189e-02, 5.77059686e-02}, + {5.46570681e-03, 6.25625551e-02, 6.24311455e-02}, + {7.64389113e-02, 5.89238741e-02, 1.02913165e-02}, + {1.64634397e-03, 2.17062421e-03, 2.34416011e-03}, + {2.03896053e-02, 7.12219477e-02, 4.46224995e-02} + }, + { + {4.57811356e-02, 6.97315410e-02, 1.02832299e-02}, + {2.31201854e-02, 4.85087894e-02, 8.02956372e-02}, + {4.64608893e-02, 1.54424773e-03, 6.63032085e-02}, + {1.92934200e-02, 6.68392256e-02, 3.21201086e-02}, + {7.15129450e-02, 6.18717745e-02, 4.60642166e-02}, + {1.13003375e-02, 4.96199494e-03, 1.00488793e-02}, + {3.68949817e-03, 8.90196767e-03, 1.86917856e-02} + }, + { + {5.90451285e-02, 4.63521369e-02, 1.03980501e-03}, + {5.96044352e-03, 8.01035613e-02, 4.70464006e-02}, + {1.68354288e-02, 
2.08959840e-02, 6.15988411e-02}, + {1.61842033e-02, 4.81443815e-02, 8.03307742e-02}, + {7.01288804e-02, 1.98626388e-02, 4.08908091e-02}, + {5.13407178e-02, 6.86508343e-02, 1.29844472e-02}, + {1.53836084e-03, 5.79878036e-03, 4.02759537e-02} + }, + { + {5.02122790e-02, 4.71085906e-02, 2.62818988e-02}, + {8.18707868e-02, 4.80107442e-02, 3.14808302e-02}, + {4.56259623e-02, 6.17237724e-02, 5.54215349e-02}, + {2.19389219e-02, 5.49342157e-03, 3.06479763e-02}, + {5.21491282e-02, 1.74052510e-02, 6.23383410e-02}, + {5.51012019e-03, 2.15576105e-02, 6.66445568e-02}, + {1.60189737e-02, 5.29560074e-02, 4.34497967e-02} + }, + { + {7.65866041e-02, 2.18045339e-02, 5.46247046e-03}, + {6.08734004e-02, 6.39467835e-02, 7.51794279e-02}, + {7.71798939e-02, 1.15537888e-03, 1.94083489e-02}, + {5.10775894e-02, 7.85913840e-02, 7.86874294e-02}, + {4.60984148e-02, 7.58245885e-02, 5.31303585e-02}, + {3.22979130e-02, 4.02465984e-02, 5.00450842e-02}, + {4.55099978e-02, 7.67003447e-02, 7.60835484e-02} + }, + { + {3.27010415e-02, 7.97711685e-02, 1.44058811e-02}, + {1.04617933e-02, 1.11863809e-02, 4.18756641e-02}, + {1.78254500e-03, 7.85047561e-02, 6.84963465e-02}, + {1.24377478e-03, 1.45914331e-02, 2.74993554e-02}, + {1.08483098e-02, 6.70367777e-02, 2.85488572e-02}, + {7.78536126e-02, 4.81986478e-02, 7.27791712e-02}, + {6.99554384e-02, 7.49787241e-02, 3.80843058e-02} + }, + { + {4.52449061e-02, 6.61351755e-02, 2.36613862e-02}, + {4.05996218e-02, 4.96144369e-02, 1.28636532e-03}, + {4.91482876e-02, 3.59142683e-02, 6.68603703e-02}, + {2.61065327e-02, 7.39432648e-02, 4.78543900e-02}, + {1.52385337e-02, 6.52511939e-02, 5.06844558e-02}, + {4.46441676e-03, 3.47977169e-02, 5.62360846e-02}, + {7.60726482e-02, 3.32930977e-05, 8.08888674e-02} + }, + { + {3.11859436e-02, 8.06424469e-02, 5.00786714e-02}, + {6.86396435e-02, 4.75938842e-02, 5.20132035e-02}, + {2.36495789e-02, 4.85977381e-02, 6.21119440e-02}, + {7.10799918e-02, 6.25310168e-02, 5.78085780e-02}, + {7.15905875e-02, 2.67223511e-02, 5.55503815e-02}, + {3.73384580e-02, 3.16432752e-02, 3.40207368e-02}, + {3.32479365e-02, 2.62836833e-02, 5.15033379e-02} + }, + { + {3.56302932e-02, 8.06439817e-02, 5.61310798e-02}, + {1.64442733e-02, 3.53366137e-02, 2.84337122e-02}, + {6.60552830e-02, 7.28757605e-02, 7.48503357e-02}, + {5.48821613e-02, 2.23768987e-02, 2.08993759e-02}, + {7.07971081e-02, 4.37019095e-02, 6.64297864e-02}, + {4.74097952e-02, 6.07141182e-02, 4.29811813e-02}, + {6.38396144e-02, 4.71091345e-02, 3.85670736e-02} + }, + { + {2.83792764e-02, 5.64865675e-03, 3.12972330e-02}, + {6.59411587e-03, 8.13905448e-02, 1.50400000e-02}, + {6.72328845e-02, 7.24586621e-02, 5.70099279e-02}, + {4.71618399e-02, 1.33306114e-02, 3.86639796e-02}, + {2.85849143e-02, 1.86363515e-02, 4.90679964e-02}, + {2.58601662e-02, 7.58824944e-02, 7.53301233e-02}, + {2.12928709e-02, 9.18329880e-03, 1.59799233e-02} + }, + { + {4.13723253e-02, 6.03367463e-02, 1.72413141e-02}, + {2.05405317e-02, 7.05299526e-02, 3.44378985e-02}, + {5.10698669e-02, 1.93507168e-02, 8.44426826e-03}, + {4.27199379e-02, 3.95137258e-02, 1.26432776e-02}, + {5.14939614e-02, 4.50513922e-02, 5.41714206e-02}, + {1.19703254e-02, 6.22366704e-02, 1.83886718e-02}, + {4.30093557e-02, 6.50331303e-02, 1.84926135e-03} + }, + { + {2.68615987e-02, 7.22897798e-02, 6.99533820e-02}, + {4.45901640e-02, 7.17668831e-02, 7.86567777e-02}, + {6.84376806e-02, 7.07323104e-02, 8.17728881e-03}, + {5.39368056e-02, 5.82607202e-02, 5.05361930e-02}, + {6.62189573e-02, 2.86296452e-03, 6.37861863e-02}, + {6.05970249e-02, 2.15065386e-02, 2.12888140e-02}, + 
{5.23632653e-02, 2.85952985e-02, 6.59683123e-02} + }, + { + {3.69469412e-02, 6.48222342e-02, 8.20244551e-02}, + {2.48646215e-02, 1.18428171e-02, 7.46405274e-02}, + {4.48484421e-02, 8.07216838e-02, 5.27194552e-02}, + {8.23094398e-02, 4.52220477e-02, 4.35951874e-02}, + {1.12152621e-02, 2.94571985e-02, 2.17125192e-03}, + {1.32828895e-02, 6.17488436e-02, 2.51750532e-03}, + {3.03547252e-02, 7.14139268e-02, 5.73630854e-02} + }, + { + {5.72193563e-02, 1.56216780e-02, 3.65956500e-02}, + {4.81624752e-02, 8.19648281e-02, 1.68861933e-02}, + {2.05156356e-02, 2.17114780e-02, 6.21244237e-02}, + {3.78437378e-02, 4.71452763e-03, 4.21120226e-02}, + {1.75531674e-02, 6.61352351e-02, 2.46230606e-02}, + {2.28615105e-03, 4.91442308e-02, 6.98814020e-02}, + {3.15532871e-02, 6.20984100e-02, 4.23294269e-02} + }, + { + {4.47981246e-02, 7.94541389e-02, 6.65788352e-02}, + {2.67678709e-03, 5.87468557e-02, 3.85084115e-02}, + {7.84698650e-02, 1.83376241e-02, 2.21171752e-02}, + {6.74714567e-03, 3.54954340e-02, 9.02822800e-03}, + {5.24861142e-02, 6.64962158e-02, 5.77045009e-02}, + {6.34526685e-02, 2.83598304e-02, 7.00479448e-02}, + {3.55078541e-02, 6.82391599e-02, 5.18823527e-02} + } + }; + + @Test + public void testL2Normalize() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + Operand input = tf.constant(array); + Operand result = fops.math.l2Normalize(tf.constant(array), new int[]{ 0,1,2}); + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testConfusionMatrix() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + long[] labels = new long[] {2, 0, 2, 2, 0, 1}; + long[] predictions = new long[] {0, 0, 2, 2, 0, 2}; + Operand result = + fops.math.confusionMatrix(tf.constant(labels), tf.constant(predictions)); + long[][] expected = + new long[][] { + {2, 0, 0}, + {0, 0, 1}, + {1, 0, 2} + }; + session.evaluate(tf.constant(expected), result); + } + } +} From e0a4a26d3a90100311c45a686ef19e0f1ebdbf19 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 26 Mar 2021 17:54:15 -0400 Subject: [PATCH 07/31] Moved high level tf.nn ops to framework. Moved tf.raw.nn Ops to tf.nn. Changed generation to generate SoftmaxCrossEntropyWithLogits and SparseSoftmaxCrossEntropyWithLogits to core NNOps (tf.nn). 
--- ...pi_def_SoftmaxCrossEntropyWithLogits.pbtxt | 2 +- ..._SparseSoftmaxCrossEntropyWithLogits.pbtxt | 2 +- .../annotations/org/tensorflow/op/NnOps.java | 175 +++--------------- .../org/tensorflow/op/NnRawOps.java | 83 --------- .../SoftmaxCrossEntropyWithLogits.java | 59 +++--- .../SparseSoftmaxCrossEntropyWithLogits.java | 64 +++---- .../op/nn/SigmoidCrossEntropyWithLogits.java | 14 +- .../op/nn/SoftmaxCrossEntropyWithLogits.java | 44 +++-- .../SparseSoftmaxCrossEntropyWithLogits.java | 47 +++-- 9 files changed, 150 insertions(+), 340 deletions(-) delete mode 100644 tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/{raw => }/SoftmaxCrossEntropyWithLogits.java (82%) rename tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/{raw => }/SparseSoftmaxCrossEntropyWithLogits.java (79%) rename {tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow => tensorflow-framework/src/main/java/org/tensorflow/framework}/op/nn/SigmoidCrossEntropyWithLogits.java (91%) rename {tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow => tensorflow-framework/src/main/java/org/tensorflow/framework}/op/nn/SoftmaxCrossEntropyWithLogits.java (87%) rename {tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow => tensorflow-framework/src/main/java/org/tensorflow/framework}/op/nn/SparseSoftmaxCrossEntropyWithLogits.java (83%) diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt index 5dba2164cd6..e064562c0f2 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SoftmaxCrossEntropyWithLogits.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "SoftmaxCrossEntropyWithLogits" endpoint { - name: "nn.raw.SoftmaxCrossEntropyWithLogits" + name: "nn.SoftmaxCrossEntropyWithLogits" } } diff --git a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt index cf80ff77565..7627d5f6074 100644 --- a/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt +++ b/tensorflow-core/tensorflow-core-api/src/bazel/api_def/api_def_SparseSoftmaxCrossEntropyWithLogits.pbtxt @@ -1,6 +1,6 @@ op { graph_op_name: "SparseSoftmaxCrossEntropyWithLogits" endpoint { - name: "nn.raw.SparseSoftmaxCrossEntropyWithLogits" + name: "nn.SparseSoftmaxCrossEntropyWithLogits" } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java index 8958b4fe2ff..1cf8b910297 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java @@ -83,7 +83,6 @@ import org.tensorflow.op.nn.Relu; import org.tensorflow.op.nn.Relu6; import org.tensorflow.op.nn.Selu; -import org.tensorflow.op.nn.SigmoidCrossEntropyWithLogits; import org.tensorflow.op.nn.Softmax; import org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.op.nn.Softsign; @@ -103,8 +102,6 @@ * @see {@link Ops} */ 
public final class NnOps { - public final NnRawOps raw; - private final Scope scope; private final Ops ops; @@ -112,7 +109,6 @@ public final class NnOps { NnOps(Ops ops) { this.scope = ops.scope(); this.ops = ops; - raw = new NnRawOps(ops); } /** @@ -1797,56 +1793,6 @@ public Selu selu(Operand features) { return Selu.create(scope, features); } - /** - * Computes sigmoid cross entropy given logits. - * - *

Measures the probability error in discrete classification tasks in which each class is - * independent and not mutually exclusive. For instance, one could perform multilabel - * classification where a picture can contain both an elephant and a dog at the same time. - * - *

For brevity, let x = logits, z = labels. The logistic loss in - * pseudo-code is - * - *

-   *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
-   *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
-   *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
-   *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
-   *   = (1 - z) * x + log(1 + exp(-x))
-   *   = x - x * z + log(1 + exp(-x))
-   *  
- * - *

For x < 0, to avoid overflow in exp(-x), we reformulate the above - * - *

-   *  x - x * z + log(1 + exp(-x))
-   *   = log(exp(x)) - x * z + log(1 + exp(-x))
-   *   = - x * z + log(1 + exp(x))
-   *  
- * - *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent - * formulation - * - *

-   *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
-   *  
- * - *

logits and labels must have the same type and shape. - * - *
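A minimal sketch of that stable formulation, using only core ops (the helper name and the surrounding Ops/Operand variables are illustrative assumptions, not the moved implementation itself):

    // Stable sigmoid cross entropy: max(x, 0) - x * z + log(1 + exp(-abs(x))).
    // exp is only ever applied to -abs(x) <= 0, so it may underflow to 0 but cannot overflow.
    static <T extends TNumber> Operand<T> stableSigmoidCrossEntropy(
        Ops tf, Operand<T> labels, Operand<T> logits) {
      Operand<T> relu = tf.math.maximum(logits, tf.zerosLike(logits)); // max(x, 0)
      Operand<T> negAbs = tf.math.neg(tf.math.abs(logits));            // -abs(x)
      return tf.math.add(
          tf.math.sub(relu, tf.math.mul(logits, labels)),              // max(x, 0) - x * z
          tf.math.log1p(tf.math.exp(negAbs)));                         // log(1 + exp(-abs(x)))
    }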

- * - * @param scope The TensorFlow scope - * @param labels the labels - * @param logits the logits of type float32 or float64 - * @param the type of labels and logits - * @return the component-wise logistic losses. - * @throws IllegalArgumentException if logits' and labels' do not have the same shape - */ - public Operand sigmoidCrossEntropyWithLogits(Operand labels, - Operand logits) { - return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); - } - /** * Computes softmax activations. * For each batch {@code i} and class {@code j} we have @@ -1864,54 +1810,20 @@ public Softmax softmax(Operand logits) { } /** - * Computes softmax cross entropy between logits and labels. - * - *

Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - *

NOTE: - * - *

While the classes are mutually exclusive, their probabilities need not be. All that is - * required is that each row of labels is a valid probability distribution. If they - * are not, the computation of the gradient will be incorrect. - * - *

If using exclusive labels (wherein one and only one class is true at a time), - * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} - * - *

Usage: - * - *

-   *    Operand<TFloat32> logits =
-   *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
-   *    Operand<TFloat32> labels =
-   *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
-   *    Operand<TFloat32> output =
-   *        tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
-   *    // output Shape = [2]
-   *    // dataType = FLOAT (1)
-   *    // values { 0.169846, 0.824745 }
-   *  
- * - *

Backpropagation will happen into both logits and labels. To - * disallow backpropagation into labels, pass label tensors through - * tf.stopGradient before feeding it to this function. + * Computes softmax cross entropy cost and gradients to backpropagate. + *

+ * Inputs are the logits, not probabilities. * - * @param scope current scope - * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] - * , each row of labels[i] must be a valid probability distribution. - * @param logits Per-label activations, typically a linear output. These activation energies are - * interpreted as unnormalized log probabilities. - * @param axis The class dimension. -1 is the last dimension. - * @param the number type of the operands - * @return the softmax cross entropy loss. Its type is the same as logits and its - * shape is the same as labels except that it does not have the last dimension of - * labels. + * @param data type for {@code loss()} output + * @param features batch_size x num_classes matrix + * @param labels batch_size x num_classes matrix + * The caller must ensure that each batch of labels represents a valid + * probability distribution. + * @return a new instance of SoftmaxCrossEntropyWithLogits */ - public Operand softmaxCrossEntropyWithLogits( - Operand labels, Operand logits, int axis) { - return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + public SoftmaxCrossEntropyWithLogits softmaxCrossEntropyWithLogits( + Operand features, Operand labels) { + return SoftmaxCrossEntropyWithLogits.create(scope, features, labels); } /** @@ -2098,51 +2010,24 @@ public SpaceToDepth spaceToDepth(Operand input, Long blo } /** - * Computes sparse softmax cross entropy between logits and labels. - * - *
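For call sites migrating off the removed helper, a hypothetical before/after sketch (assumes an Ops tf plus TFloat32 logits and labels already in scope, and the usual org.tensorflow.op.nn imports):

    // Before: the framework-style helper returned the loss directly.
    // Operand<TFloat32> loss = tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);

    // After: the raw kernel, now generated under tf.nn, takes features (the logits)
    // first and exposes both kernel outputs.
    SoftmaxCrossEntropyWithLogits<TFloat32> smax =
        tf.nn.softmaxCrossEntropyWithLogits(logits, labels);
    Operand<TFloat32> loss = smax.loss();     // per-example loss, shape [batchSize]
    Operand<TFloat32> grad = smax.backprop(); // gradients, shape [batchSize, numClasses]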

Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - *

NOTE: - * - *

For this operation, the probability of a given label is considered exclusive. That is, soft - * classes are not allowed, and the labels vector must provide a single specific - * index for the true class for each row of logits (each minibatch entry). For soft - * softmax classification with a probability distribution for each entry, see {@link - * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. - *

WARNING: - * - *

This op expects unscaled logits, since it performs a softmax on logits - * internally for efficiency. Do not call this op with the output of softmax, - * as it will produce incorrect results. - * - *

A common use case is to have logits of shape [batchSize, numClasses] and have - * labels of shape [batchSize], but higher dimensions are supported, in which case - * the dim-th dimension is assumed to be of size numClasses. - * logits must have the dataType of TFloat16, TFloat32 - * , or TFloat64, and labels must have the dtype of TInt32 - * or TInt64. - * - * @param scope current scope - * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r - * is rank of labels and result) and the dataType is TInt32 - * or TInt64. Each entry in labels must be an index in [0, - * numClasses). Other values will raise an exception when this op is run on CPU, and - * return NaN for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., - * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, - * or TFloat64. These activation energies are interpreted as unnormalized log - * probabilities. - * @return A Tensor of the same shape as labels and of the same type as - * logits with the softmax cross entropy loss. - * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank - * of the labels is not equal to the rank of the logits minus one. - */ - public Operand sparseSoftmaxCrossEntropyWithLogits( - Operand labels, Operand logits) { - return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); + * Computes softmax cross entropy cost and gradients to backpropagate. + *

+ * Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept + * a matrix of label probabilities, but rather a single label per row + * of features. This label is considered to have probability 1.0 for the + * given row. + *

+ * Inputs are the logits, not probabilities. + * + * @param data type for {@code loss()} output + * @param features batch_size x num_classes matrix + * @param labels batch_size vector with values in [0, num_classes). + * This is the label for the given minibatch entry. + * @return a new instance of SparseSoftmaxCrossEntropyWithLogits + */ + public SparseSoftmaxCrossEntropyWithLogits sparseSoftmaxCrossEntropyWithLogits( + Operand features, Operand labels) { + return SparseSoftmaxCrossEntropyWithLogits.create(scope, features, labels); } /** diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java deleted file mode 100644 index c287459c460..00000000000 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnRawOps.java +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================== -// -// This class has been generated, DO NOT EDIT! -// -package org.tensorflow.op; - -import org.tensorflow.Operand; -import org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits; -import org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits; -import org.tensorflow.types.family.TNumber; - -/** - * An API for building {@code nn.raw} operations as {@link Op Op}s - * - * @see {@link Ops} - */ -public final class NnRawOps { - private final Scope scope; - - private final Ops ops; - - NnRawOps(Ops ops) { - this.scope = ops.scope(); - this.ops = ops; - } - - /** - * Computes softmax cross entropy cost and gradients to backpropagate. - * Inputs are the logits, not probabilities. - * - * @param data type for {@code loss} output - * @param features batch_size x num_classes matrix - * @param labels batch_size x num_classes matrix - * The caller must ensure that each batch of labels represents a valid - * probability distribution. - * @param data type for {@code SoftmaxCrossEntropyWithLogits} output and operands - * @return a new instance of SoftmaxCrossEntropyWithLogits - */ - public SoftmaxCrossEntropyWithLogits softmaxCrossEntropyWithLogits( - Operand features, Operand labels) { - return SoftmaxCrossEntropyWithLogits.create(scope, features, labels); - } - - /** - * Computes softmax cross entropy cost and gradients to backpropagate. - * Unlike {@code SoftmaxCrossEntropyWithLogits}, this operation does not accept - * a matrix of label probabilities, but rather a single label per row - * of features. This label is considered to have probability 1.0 for the - * given row. - *
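A comparable hypothetical sketch for the sparse endpoint, with one TInt32 class index per row of features:

    Operand<TFloat32> logits =
        tf.constant(new float[][] {{4.0f, 2.0f, 1.0f}, {0.0f, 5.0f, 1.0f}});
    Operand<TInt32> labels = tf.constant(new int[] {0, 1}); // true class per row
    SparseSoftmaxCrossEntropyWithLogits<TFloat32> smax =
        tf.nn.sparseSoftmaxCrossEntropyWithLogits(logits, labels);
    Operand<TFloat32> loss = smax.loss(); // shape [2]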

Inputs are the logits, not probabilities. - * - * @param data type for {@code loss} output - * @param features batch_size x num_classes matrix - * @param labels batch_size vector with values in [0, num_classes). - * This is the label for the given minibatch entry. - * @param data type for {@code SparseSoftmaxCrossEntropyWithLogits} output and operands - * @return a new instance of SparseSoftmaxCrossEntropyWithLogits - */ - public SparseSoftmaxCrossEntropyWithLogits sparseSoftmaxCrossEntropyWithLogits( - Operand features, Operand labels) { - return SparseSoftmaxCrossEntropyWithLogits.create(scope, features, labels); - } - - /** - * Get the parent {@link Ops} object. - */ - public final Ops ops() { - return ops; - } -} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java similarity index 82% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java index 331933979c7..5d3ab3c1100 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.nn.raw; +package org.tensorflow.op.nn; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -29,68 +29,57 @@ /** * Computes softmax cross entropy cost and gradients to backpropagate. + *

* Inputs are the logits, not probabilities. - * - * @param data type for {@code loss} output + * + * @param data type for {@code loss()} output */ -@Operator( - group = "nn.raw" -) +@Operator(group = "nn") public final class SoftmaxCrossEntropyWithLogits extends RawOp { - /** - * The name of this op, as known by TensorFlow core engine - */ - public static final String OP_NAME = "SoftmaxCrossEntropyWithLogits"; - - private Output loss; - - private Output backprop; - - private SoftmaxCrossEntropyWithLogits(Operation operation) { - super(operation); - int outputIdx = 0; - loss = operation.output(outputIdx++); - backprop = operation.output(outputIdx++); - } - + /** * Factory method to create a class wrapping a new SoftmaxCrossEntropyWithLogits operation. - * + * * @param scope current scope * @param features batch_size x num_classes matrix * @param labels batch_size x num_classes matrix * The caller must ensure that each batch of labels represents a valid * probability distribution. - * @param data type for {@code SoftmaxCrossEntropyWithLogits} output and operands * @return a new instance of SoftmaxCrossEntropyWithLogits */ - @Endpoint( - describeByClass = true - ) - public static SoftmaxCrossEntropyWithLogits create(Scope scope, - Operand features, Operand labels) { + @Endpoint(describeByClass = true) + public static SoftmaxCrossEntropyWithLogits create(Scope scope, Operand features, Operand labels) { OperationBuilder opBuilder = scope.env().opBuilder("SoftmaxCrossEntropyWithLogits", scope.makeOpName("SoftmaxCrossEntropyWithLogits")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(labels.asOutput()); opBuilder = scope.apply(opBuilder); return new SoftmaxCrossEntropyWithLogits<>(opBuilder.build()); } - + /** - * Gets loss. * Per example loss (batch_size vector). - * @return loss. */ public Output loss() { return loss; } - + /** - * Gets backprop. * backpropagated gradients (batch_size x num_classes matrix). - * @return backprop. */ public Output backprop() { return backprop; } + + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "SoftmaxCrossEntropyWithLogits"; + + private Output loss; + private Output backprop; + + private SoftmaxCrossEntropyWithLogits(Operation operation) { + super(operation); + int outputIdx = 0; + loss = operation.output(outputIdx++); + backprop = operation.output(outputIdx); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java similarity index 79% rename from tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java rename to tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 8c48cd0db4d..794beab4ded 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/raw/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -15,7 +15,7 @@ // This class has been generated, DO NOT EDIT! -package org.tensorflow.op.nn.raw; +package org.tensorflow.op.nn; import org.tensorflow.Operand; import org.tensorflow.Operation; @@ -29,71 +29,61 @@ /** * Computes softmax cross entropy cost and gradients to backpropagate. 
- * Unlike {@code SoftmaxCrossEntropyWithLogits}, this operation does not accept + *

+ * Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept * a matrix of label probabilities, but rather a single label per row * of features. This label is considered to have probability 1.0 for the * given row. - *

Inputs are the logits, not probabilities. - * - * @param data type for {@code loss} output + *

+ * Inputs are the logits, not probabilities. + * + * @param data type for {@code loss()} output */ -@Operator( - group = "nn.raw" -) +@Operator(group = "nn") public final class SparseSoftmaxCrossEntropyWithLogits extends RawOp { - /** - * The name of this op, as known by TensorFlow core engine - */ - public static final String OP_NAME = "SparseSoftmaxCrossEntropyWithLogits"; - - private Output loss; - - private Output backprop; - - private SparseSoftmaxCrossEntropyWithLogits(Operation operation) { - super(operation); - int outputIdx = 0; - loss = operation.output(outputIdx++); - backprop = operation.output(outputIdx++); - } - + /** * Factory method to create a class wrapping a new SparseSoftmaxCrossEntropyWithLogits operation. - * + * * @param scope current scope * @param features batch_size x num_classes matrix * @param labels batch_size vector with values in [0, num_classes). * This is the label for the given minibatch entry. - * @param data type for {@code SparseSoftmaxCrossEntropyWithLogits} output and operands * @return a new instance of SparseSoftmaxCrossEntropyWithLogits */ - @Endpoint( - describeByClass = true - ) - public static SparseSoftmaxCrossEntropyWithLogits create(Scope scope, - Operand features, Operand labels) { + @Endpoint(describeByClass = true) + public static SparseSoftmaxCrossEntropyWithLogits create(Scope scope, Operand features, Operand labels) { OperationBuilder opBuilder = scope.env().opBuilder("SparseSoftmaxCrossEntropyWithLogits", scope.makeOpName("SparseSoftmaxCrossEntropyWithLogits")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(labels.asOutput()); opBuilder = scope.apply(opBuilder); return new SparseSoftmaxCrossEntropyWithLogits<>(opBuilder.build()); } - + /** - * Gets loss. * Per example loss (batch_size vector). - * @return loss. */ public Output loss() { return loss; } - + /** - * Gets backprop. * backpropagated gradients (batch_size x num_classes matrix). - * @return backprop. 
*/ public Output backprop() { return backprop; } + + /** The name of this op, as known by TensorFlow core engine */ + public static final String OP_NAME = "SparseSoftmaxCrossEntropyWithLogits"; + + private Output loss; + private Output backprop; + + private SparseSoftmaxCrossEntropyWithLogits(Operation operation) { + super(operation); + int outputIdx = 0; + loss = operation.output(outputIdx++); + backprop = operation.output(outputIdx); + } } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java similarity index 91% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java index 92c413f7e52..b55385839d3 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SigmoidCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java @@ -1,4 +1,4 @@ -package org.tensorflow.op.nn; +package org.tensorflow.framework.op.nn; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; @@ -8,11 +8,17 @@ import org.tensorflow.op.core.Select; import org.tensorflow.op.core.ZerosLike; import org.tensorflow.op.dtypes.Cast; -import org.tensorflow.op.math.*; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Exp; +import org.tensorflow.op.math.GreaterEqual; +import org.tensorflow.op.math.Log1p; +import org.tensorflow.op.math.Mul; +import org.tensorflow.op.math.Neg; +import org.tensorflow.op.math.Sub; import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -@Operator(group = "nn") +//@Operator(group = "nn") public class SigmoidCrossEntropyWithLogits { /** @@ -60,7 +66,7 @@ public class SigmoidCrossEntropyWithLogits { * @return the component-wise logistic losses. 
* @throws IllegalArgumentException if logits' and labels' do not have the same shape */ - @Endpoint(name = "sigmoidCrossEntropyWithLogits") + //@Endpoint(name = "sigmoidCrossEntropyWithLogits") public static Operand sigmoidCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits) { if (!isCompatible(labels.shape(), logits.shape())) { diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java similarity index 87% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java index ddeacbea4d4..0f5b8197f1e 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -1,11 +1,15 @@ -package org.tensorflow.op.nn; +package org.tensorflow.framework.op.nn; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.core.*; +import org.tensorflow.op.core.Concat; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Range; +import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Slice; import org.tensorflow.op.dtypes.Cast; import org.tensorflow.op.linalg.Transpose; import org.tensorflow.op.math.Sub; @@ -14,12 +18,11 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TNumber; -import org.tensorflow.types.family.TType; import java.util.Arrays; import java.util.List; -@Operator(group = "nn") +// @Operator(group = "nn") public class SoftmaxCrossEntropyWithLogits { /** @@ -68,6 +71,7 @@ public class SoftmaxCrossEntropyWithLogits { * shape is the same as labels except that it does not have the last dimension of * labels. */ + @SuppressWarnings("unchecked") @Endpoint(name = "softmaxCrossEntropyWithLogits") public static Operand softmaxCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits, int axis) { @@ -78,7 +82,9 @@ public static Operand softmaxCrossEntr } if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { - Operand result = softmaxCrossEntropyWithLogits(scope, + Operand result = + softmaxCrossEntropyWithLogits( + scope, Cast.create(scope, labels, TFloat32.class), Cast.create(scope, logits, TFloat32.class), axis); @@ -86,10 +92,8 @@ public static Operand softmaxCrossEntr } if (logits.asOutput().type() != labels.asOutput().type()) { - return softmaxCrossEntropyWithLogits(scope, - Cast.create(scope, labels, logits.asOutput().type()), - logits, - axis); + return softmaxCrossEntropyWithLogits( + scope, Cast.create(scope, labels, logits.asOutput().type()), logits, axis); } Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); @@ -101,13 +105,20 @@ public static Operand softmaxCrossEntr labels = moveDimToEnd(scope, labels, axis, inputRank); } + Operand tLabels; + if (labels.type() != logits.type()) { + tLabels = Cast.create(scope, labels, logits.type()); + } else { + // Unchecked warning checked in if statement. 
+ tLabels = (Operand) labels; + } + + Shape inputShape = logits.shape(); logits = flattenOuterDims(scope, logits); - labels = flattenOuterDims(scope, labels); + tLabels = flattenOuterDims(scope, tLabels); - org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.raw.SoftmaxCrossEntropyWithLogits.create( - scope, logits, (Operand)labels); + org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits.create(scope, logits, tLabels); /* cannot use generic on cost, because cost may be recast later. */ Operand cost = smax.loss(); Operand outputShape = @@ -119,6 +130,9 @@ public static Operand softmaxCrossEntr cost = Reshape.create(scope, cost, outputShape); if (scope.env().isGraph() && !shape.hasUnknownDimension()) { long[] array = shape.asArray(); + if (array == null) { + array = new long[0]; + } long[] newArray = new long[array.length - 1]; if (axis < 0) { axis = shape.numDimensions() + axis; @@ -153,7 +167,7 @@ private static Operand flattenOuterDims(Scope scope, Oper boolean productValid = true; for (int i = ndims - 2; i >= 0; i--) { long d = shape.size(i); - if (d == org.tensorflow.ndarray.Shape.UNKNOWN_SIZE) { + if (d == Shape.UNKNOWN_SIZE) { productValid = false; break; } diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java similarity index 83% rename from tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 54b32bb5c63..64faa7c5d70 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -1,11 +1,10 @@ -package org.tensorflow.op.nn; +package org.tensorflow.framework.op.nn; import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Scope; import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.AssertThat; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Reshape; @@ -22,7 +21,7 @@ import java.util.Collections; import java.util.List; -@Operator(group = "nn") +// @Operator(group = "nn") public class SparseSoftmaxCrossEntropyWithLogits { /** @@ -63,19 +62,24 @@ public class SparseSoftmaxCrossEntropyWithLogits { * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, * or TFloat64. These activation energies are interpreted as unnormalized log * probabilities. + * @param the data type for the labels + * @param the data type for the loss and logits. * @return A Tensor of the same shape as labels and of the same type as * logits with the softmax cross entropy loss. * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank * of the labels is not equal to the rank of the logits minus one. 
*/ + @SuppressWarnings("unchecked") @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits") - public static Operand sparseSoftmaxCrossEntropyWithLogits( - Scope scope, Operand labels, Operand logits) { + public static + Operand sparseSoftmaxCrossEntropyWithLogits( + Scope scope, Operand labels, Operand logits) { scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); - /** cannot use generics on preciseLogits as it may be recast later */ - Operand preciseLogits = logits; + Operand preciseLogits; if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { preciseLogits = Cast.create(scope, logits, TFloat32.class); + } else { + preciseLogits = logits; } Shape labelsStaticShape = labels.shape(); org.tensorflow.op.core.Shape labelsShape = @@ -108,14 +112,16 @@ public static Operand sparseSoftmaxCrossE } // Check if no reshapes are required. if (logitsShape.numDimensions() == 2) { - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create( scope, preciseLogits, labels); - Operand loss = smax.loss(); - if (logits.asOutput().type() == TFloat16.class) { - loss = Cast.create(scope, loss, TFloat16.class); + Operand cost = smax.loss(); + if (cost.type() != logits.type()) { + return Cast.create(scope, cost, logits.type()); + } else { + // Unchecked cast already checked with previous if + return (Operand) cost; } - return loss; } List shapeChecks = new ArrayList<>(); @@ -145,14 +151,17 @@ public static Operand sparseSoftmaxCrossE preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses)); labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1)); scope.withControlDependencies(shapeChecks); - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.raw.SparseSoftmaxCrossEntropyWithLogits.create( + // call raw op + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create( scope, preciseLogits, labels); - Operand cost = smax.loss(); + Operand cost = smax.loss(); cost = Reshape.create(scope, cost, labelsShape); - if (logits.asOutput().type() == TFloat16.class) { - cost = Cast.create(scope, cost, TFloat16.class); + if (cost.type() != logits.type()) { + return Cast.create(scope, cost, logits.type()); + } else { + // Unchecked cast already checked with previous if + return (Operand) cost; } - return cost; } } From 28db4df34beab9f73557145b2858aac8feb36fc0 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Fri, 26 Mar 2021 18:02:55 -0400 Subject: [PATCH 08/31] Added FrameworkOps analogous to Ops. Added NnOps and SetOps as groups. Fixed MetricsHelper and Losses to use the new FrameworkOps. Moved SetsOps to framework.op. 
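As a sketch of the intended call pattern after this change (illustrative only; assumes an existing Ops instance tf and suitably shaped/typed operands, mirroring the Losses.java changes below):

    FrameworkOps fops = FrameworkOps.create(tf); // framework ops wrap a core Ops instance
    Operand<TFloat32> bce = fops.nn.sigmoidCrossEntropyWithLogits(labels, logits);
    Operand<TFloat32> cce = fops.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
    Operand<TFloat32> scce = fops.nn.sparseSoftmaxCrossEntropyWithLogits(intLabels, logits);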
--- .../tensorflow/framework/losses/Losses.java | 17 +- .../framework/metrics/impl/MetricsHelper.java | 685 ++---------------- .../framework/metrics/impl/SetsOps.java | 147 ---- .../tensorflow/framework/op/FrameworkOps.java | 136 ++++ .../org/tensorflow/framework/op/NnOps.java | 197 +++++ .../org/tensorflow/framework/op/SetsOps.java | 161 ++++ .../SparseSoftmaxCrossEntropyWithLogits.java | 3 +- .../{SetsOpsTest.java => SetOpsTest.java} | 18 +- 8 files changed, 559 insertions(+), 805 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java rename tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/{SetsOpsTest.java => SetOpsTest.java} (86%) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 9aa94cf7fcf..aa5fa4ada6d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -19,6 +19,7 @@ import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.op.core.ReduceAll; import org.tensorflow.op.core.ReduceMax; import org.tensorflow.op.core.ReduceSum; @@ -181,7 +182,8 @@ public static Operand binaryCrossentropy( */ private static Operand binaryCrossentropyHelper( Ops tf, Operand target, Operand output, boolean fromLogits) { - if (fromLogits) return tf.nn.sigmoidCrossEntropyWithLogits(target, output); + FrameworkOps fop = FrameworkOps.create(tf); + if (fromLogits) { return fop.nn.sigmoidCrossEntropyWithLogits(target, output);} /* TODO - skip this logic for now. 
It requires walking back the inputs which is not yet possible if (!(output instanceof Variable) && (!tf.scope().env().isEager())) { @@ -191,7 +193,7 @@ private static Operand binaryCrossentropyHelper( // TODO if (output.op().numInputess() != 1) // TODO throw new IllegalArgumentException("output can only have 1 output"); // TODO output = output.op().inout(0); - // TODO return tf.nn.sigmoidCrossEntropyWithLogits(target, output); + // TODO return fop.nn.sigmoidCrossEntropyWithLogits(target, output); // TODO} } */ @@ -235,6 +237,7 @@ public static Operand categoricalCrossentropy( boolean fromLogits, float labelSmoothing, int axis) { + FrameworkOps fop = FrameworkOps.create(tf); Class predictionType = predictions.type(); Operand tLabels = cast(tf, labels, predictionType); LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); @@ -245,7 +248,7 @@ public static Operand categoricalCrossentropy( tLabels = smoothCategoricalLabels(tf, tLabels, labelSmoothing); } if (fromLogits) { - return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis); + return fop.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis); } /* TODO if (!(predictions instanceof Variable) && (!tf.scope().env().isEager())) { @@ -255,7 +258,7 @@ public static Operand categoricalCrossentropy( if (predictions.op().numOutputs() != 1) throw new IllegalArgumentException("output can only have 1 output"); predictions = predictions.op().output(0); - return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); + return fop.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); } } */ @@ -516,6 +519,7 @@ public static Operand sparseCategoricalCrossentropy( boolean fromLogits, int axis) { Class predictionType = predictions.type(); + FrameworkOps fop = FrameworkOps.create(tf); Operand epsilonConst = cast(tf, tf.constant(EPSILON), predictionType); Operand one = cast(tf, tf.constant(1), predictionType); Operand oneMinusEpsilonConst = tf.math.sub(one, epsilonConst); @@ -568,9 +572,8 @@ public static Operand sparseCategoricalCrossentropy( tf.constant( new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)})); } - - @SuppressWarnings("unchecked") - Operand loss = tf.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); + + Operand loss = fop.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); if (updateShape && predictionsRank >= 3) { Shape newShape = predictionsShape.take(predictionsShape.numDimensions() - 1); loss = tf.reshape(loss, tf.constant(newShape)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index 40336233d21..a82e1760d1f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,36 +15,21 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; -import org.tensorflow.framework.losses.impl.LossTuple; -import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; -import org.tensorflow.framework.utils.SparseTensor; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.op.core.Assign; -import org.tensorflow.op.core.OneHot; -import 
org.tensorflow.op.core.Rank; -import org.tensorflow.op.core.Squeeze; -import org.tensorflow.op.core.Stack; -import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; -import org.tensorflow.op.nn.TopK; import org.tensorflow.types.TBool; -import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; -import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicLong; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -59,8 +44,8 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values - * } + * Asserts that the sampleWeights can be broadcast to the same shape as values + * * *

In losses and metrics, limited weight broadcasting is supported. Weights must be either * scalar, or the same rank as the target values, with each dimension either 1, or the same as the @@ -69,11 +54,11 @@ public class MetricsHelper { * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} - * can be broadcast to {@code values} + * @return Operation with control dependencies to ensure sampleWeight + * can be broadcast to values * @param the type of Operand - * @throws NotBroadcastableException If static checks determine {@code sampleWeights} has an - * incorrect shape that prohibit broadcasting to {@code values} + * @throws NotBroadcastableException If static checks determine sampleWeights has an + * incorrect shape that prohibit broadcasting to values */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -94,7 +79,7 @@ public static Op assertBroadcastable( && !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") - .withControlDependencies(java.util.Collections.EMPTY_LIST) + .withControlDependencies(Collections.EMPTY_LIST) .noOp(); } if (weightsRankStatic != valuesRankStatic) { @@ -104,8 +89,8 @@ public static Op assertBroadcastable( ASSERT_BROADCAST_ERROR_PREFIX, valuesRankStatic, weightsRankStatic, - valuesShapeStatic, - weightsShapeStatic)); + valuesShapeStatic.toString(), + weightsShapeStatic.toString())); } for (int i = 0; i < valuesRankStatic; i++) { @@ -116,8 +101,8 @@ public static Op assertBroadcastable( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", ASSERT_BROADCAST_ERROR_PREFIX, i, - valuesShapeStatic, - weightsShapeStatic)); + valuesShapeStatic.toString(), + weightsShapeStatic.toString())); } } return tf.withSubScope("staticDimsCheckSuccess") @@ -190,24 +175,25 @@ private static Operand canBroadcastNonscalarShapes( private static Operand canBroadcastDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("canBroadcastDims"); + FrameworkOps fops = FrameworkOps.create(tf); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2D = tf.expandDims(weightsShape, tf.constant(-1)); - Operand diffResult = SetsOps.difference(tf, weightsShape2D, validDims); + Operand diffResult = fops.sets.difference(weightsShape2D, validDims); Operand numInvalidDims = tf.size(diffResult); return tf.math.equal(tf.constant(0), numInvalidDims); } /** - * Broadcast {@code weights} to the same shape as {@code values}. + * Broadcast weights to the same shape as values. * * @param tf the TensorFlow ops - * @param weights Operand whose shape is broadcastable to {@code values}. + * @param weights Operand whose shape is broadcastable to values. * @param values Operand of any shape * @param the type of Operands - * @return {@code weights} broadcast to {@code values} shape + * @return weights broadcast to values shape */ public static Operand broadcastWeights( Ops tf, Operand weights, Operand values) { @@ -228,473 +214,11 @@ public static Operand broadcastWeights( return ctf.math.mul(weights, tf.onesLike(values)); } - /** - * Checks that all the Symbolic Shapes are consistent. 
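A small illustration of the two helpers above (hypothetical shapes; weights of shape (1, 3) are broadcastable to values of shape (2, 3) because every weight dimension is either 1 or matches the corresponding values dimension):

    Operand<TFloat32> values = tf.constant(new float[][] {{1f, 2f, 3f}, {4f, 5f, 6f}});
    Operand<TFloat32> weights = tf.constant(new float[][] {{0.2f, 0.3f, 0.5f}}); // (1, 3)
    Operand<TFloat32> broadcast =
        MetricsHelper.broadcastWeights(tf, weights, values); // shape (2, 3)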
- * - * @param tf the TensorFlow Ops - * @param symbols the list of Symbolic Shapes - * @param message the error message if the shapes are not consistent. - * @return a list of Operands to check the consistency of the symbolic shapes ready to add to a - * control dependency. - */ - public static List assertShapes( - Ops tf, List> symbols, String message) { - List updateOperations = new ArrayList<>(); - // check that the symbolic shape rank matches the operands rank. - symbols.forEach( - symbol -> { - Operand operand = symbol.getOperand(); - int rank = symbol.rank(); - Rank tfRank = tf.rank(operand); - Op assertion = - tf.withSubScope("assertShapes-1") - .assertThat( - tf.math.equal(tfRank, tf.constant(rank)), - Collections.singletonList(tf.constant(message))); - updateOperations.add(assertion); - }); - - Map> dict = new HashMap<>(); - - // check that each operand's dimension size equals the corresponding symbolic shape's dimensions - // size - symbols.forEach( - symbol -> { - AtomicLong ll = new AtomicLong(); - symbol - .getSymbols() - .forEach( - s -> { - Operand size = dict.get(s); - if (size == null) { - // save size for later checks - size = - tf.shape.size(symbol.getOperand(), tf.constant(ll.get()), TInt64.class); - dict.put(s, size); - } - Op assertion = - tf.withSubScope("assertShapes-2") - .assertThat( - tf.math.equal( - tf.shape.size( - symbol.getOperand(), - tf.constant(ll.getAndIncrement()), - TInt64.class), - size), - Collections.singletonList(tf.constant(message))); - updateOperations.add(assertion); - }); - }); - - return updateOperations; - } + // aliases for mean /** - * Returns an op to update the given confusion matrix variables. - * - *

For every pair of values in {@code labels} and {@code predictions}: - * - *

-   * TRUE_POSITIVES:  {@code labels} == true and {@code predictions} > thresholds
-   * FALSE_POSITIVES: {@code labels} == true and {@code predictions} <= thresholds
-   * TRUE_NEGATIVES:  {@code labels} == false and {@code predictions} <= thresholds
-   * FALSE_NEGATIVE:  {@code labels} == false and {@code predictions} > thresholds
-   * 
- * - *

The results will be weighted and added together. When multiple thresholds are provided, we - * will repeat the same for every threshold. - * - *

For estimation of these metrics over a stream of data, the function creates an `update_op` - * operation that updates the given variables. - * - *

{@code labels}, {@code predictions}, and {@code sampleWeight} tensors are - * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code - * sampleWeight} is then broadcast to the shape of {@code predictions}. - * - * @param tf the TensorFlow Ops - * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding variables to update as values. If {@code multiLabel}, then the variable - * shapes are (T, D), where T is the number of thresholds and D is the number of classes - * (after slicing by {@code classIndex}, if provided). If {@code multiLabels}, then - * the variable shapes are (T). - * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and - * corresponding initializer Operands to for {@code variablesToUpdate}. - * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the - * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra - * dimension of size 1 that would be squeezed. - * @param predictions the predictions shape (N, Cx, P1?) - * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} is used when - * topK is set - * @param topK optional, indicates that only the top k predictions should be considered. Applied - * before possibly slicing by {@code classIndex}. - * @param classIndex optional, limits the prediction and labels to the specified class. This is an - * integer index into the first dimension of Cx. - * @param sampleWeight optional {@code Tensor} that is aligned with labels and predictions as - * explained above. Use weights of 0 to mask values. - * @param multiLabel indicates whether multidimensional prediction/labels should be treated as - * multilabel responses, or flattened into a single label. When true, the values of {@code - * variablesToUpdate} must have a second dimension equal to the number of labels and - * predictions per example, and those tensors must not be RaggedTensors. - * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied - * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES - * without explicit multilabel handling (i.e. when the data is to be flattened). Must have - * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex - * } is provided, then the final dimension of Dx is 1. These weights will be broadcast - * across the 0th dimension (the examples dimension) of {@code predictions}. May be null. - * Must be null if {@code multiLabel}. - * @param the data type for the variables - * @throws IllegalArgumentException If {@code predictions} and {@code labels} have - * mismatched shapes, or if {@code sampleWeight} is not null and its shape - * doesn't match {@code predictions}, or if {@code multiLabel && labelWeights != null}.. - * @return an op to update the given confusion matrix variables. 
- */ - @SuppressWarnings({"unchecked", "rawtypes"}) - public static List updateConfusionMatrixVariables( - Ops tf, - Map> variablesToUpdate, - Map> varInitializers, - Operand labels, - Operand predictions, - Operand thresholds, - Integer topK, - Integer classIndex, - Operand sampleWeight, - boolean multiLabel, - Operand labelWeights) { - if (multiLabel && labelWeights != null) - throw new IllegalArgumentException( - "labelWeights for multilabel data should be handled outside of updateConfusionMatrixVariables when multiLabel is true."); - - if (variablesToUpdate == null || variablesToUpdate.isEmpty()) { - return Collections.EMPTY_LIST; - } - - Operand tLabels = labels; - Operand tPredictions = predictions; - Operand tSampleWeight = sampleWeight; - - // We will tile data for threshold comparisons. We want a cross product of thresholds and - // predictions/labels: - // In the multilabel case, we want a data shape of (T, N, D). - // else (T, ND). - // where - // T is numThresholds (the size of the 0th dimension of thresholds) - // N is the number of examples (the 0th dimension of labels and predictions) - // Dx == Cx except that if classIndex != null, - // then the last dimension of Dx is size 1 - // D is the product of all Dx - // ND is N * D - - // size of the 0th dimension of thresholds - // reshape to scalar for operations later. - Operand numThresholds = - tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); - - // if multilabel, then (rank(thresholds) == 1) - // else true - Operand oneThresh; - if (multiLabel) { - oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds)); - } else { - // TODO handle Ragged Tensors???? - // [y_pred, - // y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true], - // sampleWeights) - oneThresh = tf.constant(true); - } - - List controlOps = new ArrayList<>(); - Operand axes = allAxes(tf, tPredictions); - controlOps.add( - tf.withSubScope("updateConfusionMatrixVariables-1") - .assertThat( - tf.reduceAll( - tf.math.greaterEqual( - tPredictions, cast(tf, tf.constant(0), tPredictions.type())), - axes), - Collections.singletonList(tf.constant("predictions must be >= 0")))); - controlOps.add( - tf.withSubScope("updateConfusionMatrixVariables-2") - .assertThat( - tf.reduceAll( - tf.math.lessEqual(tPredictions, cast(tf, tf.constant(1), tPredictions.type())), - axes), - Collections.singletonList(tf.constant("predictions must be <= 1")))); - - LossTuple result = - LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions, tSampleWeight); - tPredictions = result.getTarget(); // shape (N, Cx) - tLabels = result.getLabels(); // shape (N, Cx) - tSampleWeight = result.getSampleWeights(); // broadcastable to (N, Dx) - - if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) - throw new IllegalArgumentException( - String.format( - "Shapes %s and %s are incompatible)", - tPredictions.shape().toString(), tLabels.shape().toString())); - - if (topK != null) { - tPredictions = filterTopK(tf, tPredictions, topK); - } - - if (classIndex != null) { - // Slice to new shapes (N, Dx) - tLabels = tf.squeeze(tf.gather(tLabels, - tf.constant(new int[] {classIndex}), tf.constant(-1)), - Squeeze.axis(Collections.singletonList(1L))); - tPredictions = tf.squeeze(tf.gather(tPredictions, - tf.constant(new int[] {classIndex}), tf.constant(-1)), - Squeeze.axis(Collections.singletonList(1L))); - } - org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); - - Operand numExamples = - tf.reshape(tf.shape.size(tPredictions, 
tf.constant(0)), tf.constant(Shape.scalar())); - - // number of labels (and predictions) per example (after possibly slicing by classIndex) - // In the notation we are using for comments, this is D. - Operand numLabels = - tf.select( - tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), - tf.constant(1), - tf.reduceProd( - // take all but the first dimension - tf.shape.takeLast( - predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), - tf.constant(0))); - - // threshLabelTile == numLabels except in one case: - // if multilabel and rank(thresholds) != 1, then threshLabelTile is 1 - Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); - - // if multilabel, then shape (1, N, Dx) - // else shape (1, ND), - Operand predictionsExtraDim; - Operand labelsExtraDim; - - if (multiLabel) { - predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0)); - labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0)); - } else { - predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1))); - labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1))); - } - - // the shape of each thresholds tile - // if multilabel, then [T, 1, -1] - // else [T, -1] - List> threshPretileShape; - - // the tiling multiples for thresholds - // We want to repeat the thresholds for each data position. - // if multilabel, then [1, N, threshLabelTile]. (threshLabelTile is typically numLabels) - // else [1, ND] - List> threshTiles; - - // tiling multiples for predictionsExtraDim and labelsExtraDim - // We want to repeat the predictions and labels for each threshold. - // If multilabel, then [T, 1, 1] - // else [T, 1] - List> dataTiles; - - if (multiLabel) { - threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); - threshTiles = Arrays.asList(tf.constant(1), numExamples, threshLabelTile); - dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1)); - } else { - threshPretileShape = - Arrays.asList(tf.reshape(numThresholds, tf.constant(Shape.scalar())), tf.constant(-1)); - Operand mul = tf.math.mul(numExamples, numLabels); - threshTiles = Arrays.asList(tf.constant(1), mul); - dataTiles = Arrays.asList(numThresholds, tf.constant(1)); - } - - // if multilabel, then shape (T, 1, T*) - // else shape (T, T*) - // where T* is the product of all threshold dimension sizes beyond 0 - Operand thresholdsReshaped = - tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); - - Operand threshTilesShape = tf.stack(threshTiles); - - // if multilabel, then - // if thresholds has rank > 1, then shape (T, N, T*) - // else shape (T, N, D) - // else shape (T, ND) - Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); - - Operand dataTilesShape = tf.stack(dataTiles); - - // if multilabel, then shape (T, N, D) - // else (T, ND) - Operand predsTiled = tf.tile(predictionsExtraDim, dataTilesShape); - - // Compare predictions and threshold. 
- Operand predIsPos = tf.math.greater(predsTiled, threshTiled); - // Tile labels by number of thresholds - Operand labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles)); - Operand weightsTiled; - if (tSampleWeight != null) { - tSampleWeight = tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); - // if multilabel, then - // reshape tSampleWeight to (1, N, threshLabelTile) - // tile the result into shape (T, N, threshLabelTile) - // where threshLabelTile is typically D - // else - // reshape tSampleWeight to (1, ND) - // tile the result into shape (T, ND) - weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape); - } else { - weightsTiled = null; - } - - if (labelWeights != null) { - // Change shape to (1, Dx). - Operand lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0)); - - // Broadcast to shape (N, Dx). - lLabelWeights = tf.broadcastTo(lLabelWeights, tPredictions); - - // If multilabel: shape (T, N, D) - // else: shape (T, ND) - Operand labelWeightsTiled = - tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles)); - - if (weightsTiled == null) { - weightsTiled = labelWeightsTiled; - } else { - weightsTiled = tf.math.mul(weightsTiled, labelWeightsTiled); - } - } - - Map loopVars = new HashMap<>(); - loopVars.put(ConfusionMatrixEnum.TRUE_POSITIVES, new Operand[] {labelIsPos, predIsPos}); - Variable updateTN = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES); - Variable updateFP = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES); - Variable updateFN = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES); - - Operand predIsNeg = null; - Operand labelIsNeg; - if (updateFN != null || updateTN != null) { - predIsNeg = tf.math.logicalNot(predIsPos); - loopVars.put(ConfusionMatrixEnum.FALSE_NEGATIVES, new Operand[] {labelIsPos, predIsNeg}); - } - - if (updateFP != null || updateTN != null) { - labelIsNeg = tf.math.logicalNot(labelIsPos); - loopVars.put(ConfusionMatrixEnum.FALSE_POSITIVES, new Operand[] {labelIsNeg, predIsPos}); - if (updateTN != null) { - loopVars.put(ConfusionMatrixEnum.TRUE_NEGATIVES, new Operand[] {labelIsNeg, predIsNeg}); - } - } - - final Operand weightsTiledF = weightsTiled; - loopVars - .keySet() - .forEach( - (c) -> { - if (variablesToUpdate.containsKey(c)) { - Operand[] op = loopVars.get(c); - // op[0] = label, op[1] == prediction - controlOps.add( - weightedAssignAdd( - tf, - op[0], - op[1], - weightsTiledF, - variablesToUpdate.get(c), - varInitializers.get(c))); - } - }); - - return controlOps; - } - - /** - * Creates an Operand that adds the values by taking the logical and of labels and predictions to - * the specified confusion matrix variable. - * - * @param tf The TensorFlow Ops - * @param labels the labels - * @param predictions the predictions - * @param weights the weights applied to the logical and result, may be null - * @param variable the variable to update - * @param initializer the variable initializer to be applied to the variable, may be null. - * @param the data type for the variable. - * @return an Operand that updates the variable. 
- */ - private static Operand weightedAssignAdd( - Ops tf, - Operand labels, - Operand predictions, - Operand weights, - Variable variable, - Assign initializer) { - Class type = variable.type(); - Operand labelAndPred = cast(tf, tf.math.logicalAnd(labels, predictions), type); - - if (weights != null) { - labelAndPred = tf.math.mul(labelAndPred, weights); - } - // if multilabel: - // sum across examples, leaving shape (T, D) - // else: - // sum across ND, leaving shape (T) - Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); - Operand assignAdd; - if (initializer != null) { - Ops tfc = - tf.withSubScope("weightedAssignAdd") - .withControlDependencies(Collections.singletonList(initializer)); - assignAdd = tfc.assignAdd(variable, valueSum); - } else { - assignAdd = tf.assignAdd(variable, valueSum); - } - return assignAdd; - } - - /** - * Filters top-k values in the last dim of x and set the rest to NEG_INF. - * - *

Used for computing top-k prediction values in dense labels (which has the same shape as - * predictions) for recall and precision top-k metrics. - * - * @param tf The TensorFlow Ops - * @param x the tensor with any dimensions to filter - * @param topK the number of values to keep. - * @param the data type for x and the return value. - * @return the topK prediction values. - */ - private static Operand filterTopK(Ops tf, Operand x, int topK) { - Class type = x.type(); - Shape xShape = x.shape(); - // top has the same rank as x; the last dimension becomes indices of the topK features. - TopK top = tf.nn.topK(x, tf.constant(topK), TopK.sorted(false)); - // oneHot has an additional dimension: the one-hot representation of each topK index. - OneHot oneHot = - tf.oneHot( - top.indices(), - cast(tf, tf.constant(xShape.size(xShape.numDimensions() - 1)), TInt32.class), - tf.constant(1), - tf.constant(0), - OneHot.axis(-1L)); - // Sum the one-hot representations along the last dimension of x. - Operand topKMask = cast(tf, tf.reduceSum(oneHot, tf.constant(-2)), type); - - // x * top_k_mask + NEG_INF * (1 - top_k_mask) - Operand add1 = tf.math.mul(x, topKMask); - Operand add2 = - tf.math.mul( - cast(tf, tf.constant(NEG_INF), type), - tf.math.sub(cast(tf, tf.constant(1), type), topKMask)); - return tf.math.add(add1, add2); - } - - // alias for mean - - /** - * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false - * } + * Calculate the mean of the operand, along all axes and keepDims is false + * * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -706,8 +230,8 @@ public static Operand mean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is - * {@code false} + * Calculate the mean of the operand, alongside the specified axis with keepDims is + * false * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -725,12 +249,12 @@ public static Operand mean( * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained * with length 1. * @param the type of the operand - * @return the mean of elements of {@code x}. + * @return the mean of elements of x. */ public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); @@ -742,12 +266,12 @@ public static Operand mean(Ops tf, Operand x, boolean * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained * with length 1. 
* @param the data type of the Operand - * @return the mean of elements of {@code x}. + * @return the mean of elements of x. */ public static Operand mean( Ops tf, Operand x, Operand axes, boolean keepDims) { @@ -757,134 +281,9 @@ public static Operand mean( return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } - public static - LossTuple raggedAssertCompatibleAndGetFlatValues( - Ops tf, Operand labels, Operand predictions) { - // TODO handle ragged Tensors - Operand tLabels = cast(tf, labels, predictions.type()); - return new LossTuple<>(tLabels, predictions); - } - - /** - * Computes the confusion matrix from predictions and labels. - * - *

The matrix columns represent the prediction labels and the rows represent the real labels. - * The confusion matrix is always a 2-D array of shape {@code [n, n]}, where {@code n} is the - * number of valid labels for a given classification task. Both prediction and labels must be 1-D - * arrays of the same shape in order for this function to work. - * - *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum - * value in either predictions or labels. Class labels are expected to start at 0. For example, if - * {@code numClasses}` is 3, then the possible labels would be {@code [0, 1, 2]}. - * - *

If {@code weights} is not null, then each prediction contributes its corresponding weight to - * the total value of the confusion matrix cell. - * - *

For example: - * - *

{@code
-   *     confusion_matrix([1, 2, 4], [2, 2, 4]) ==>
-   *          [[0 0 0 0 0]
-   *           [0 0 1 0 0]
-   *           [0 0 1 0 0]
-   *           [0 0 0 0 0]
-   *           [0 0 0 0 1]]
-   * }
- * - * Note that the possible labels are assumed to be {@code [0, 1, 2, 3,4]}, resulting in a 5x5 - * confusion matrix. - * - * @param tf the TensorFlow Ops - * @param labels 1-D {@code Operand} of real labels for the classification task. - * @param predictions 1-D {@code Operand} of predictions for a given classification. - * @param numClasses The possible number of labels the classification task can have. If this value - * is not provided, it will be calculated using both predictions and labels array. - * @param weights optional weights to be applied to the confusion matrix - * @param type Data type of the confusion matrix. - * @param the type of Operands - * @return A {@code Operand} of type {@code type} with shape {@code [n, n]} - * representing the confusion matrix, where {@code n} is the number of possible labels in - * the classification task. - * @throws IllegalArgumentException If both {@code predictions} and {@code labels} do - * not have compatible shapes, or if {@code weights} is not{@code null} and its - * shape is not compatible with {@code predictions}. - */ - // TODO should this be moved to FramnworkOps under math. - public static Operand confusionMatrix( - Ops tf, - Operand labels, - Operand predictions, - Operand numClasses, - Operand weights, - Class type) { - if (!predictions.shape().isCompatibleWith(labels.shape())) - throw new IllegalArgumentException( - String.format( - "Prediction shape %s is not compatible with labels shape %s", - predictions.shape().toString(), labels.shape().toString())); - tf = tf.withSubScope("confusionMatrix"); - LossTuple ops = LossesHelper.squeezeOrExpandDimensions(tf, predictions, labels, null); - Operand tPredictions = cast(tf, ops.getTarget(), TInt64.class); - Operand tLabels = cast(tf, ops.getLabels(), TInt64.class); - - List labelControls = new ArrayList<>(); - List predictionControls = new ArrayList<>(); - - labelControls.add( - tf.assertThat( - tf.reduceAny(tf.math.greaterEqual(tLabels, tf.constant(0L)), allAxes(tf, tLabels)), - Collections.singletonList(tf.constant("`labels` contains negative values")))); - - predictionControls.add( - tf.assertThat( - tf.reduceAny( - tf.math.greaterEqual(tPredictions, tf.constant(0L)), allAxes(tf, tPredictions)), - Collections.singletonList(tf.constant("`predictions` contains negative values")))); - if (numClasses == null) { - numClasses = - tf.math.maximum( - tf.reduceMax(tPredictions, allAxes(tf, tPredictions)), - tf.reduceMax(tLabels, allAxes(tf, tLabels))); - } else { - labelControls.add( - tf.assertThat( - tf.reduceAny(tf.math.less(tLabels, numClasses), allAxes(tf, tLabels)), - Collections.singletonList(tf.constant("``labels` out of bounds")))); - predictionControls.add( - tf.assertThat( - tf.reduceAny(tf.math.less(tPredictions, numClasses), allAxes(tf, tPredictions)), - Collections.singletonList(tf.constant("``predictions` out of bounds")))); - } - - if (weights != null) { - if (!tPredictions.shape().isCompatibleWith(weights.shape())) { - throw new IllegalArgumentException( - String.format( - "Prediction shape %s is not compatible with weights shape %s", - tPredictions.shape().toString(), weights.shape().toString())); - } - } - - Ops tfc = tf.withSubScope("confusionMatrixLabels").withControlDependencies(labelControls); - tLabels = tfc.identity(tLabels); - - tfc = tf.withSubScope("confusionMatrixPredictions").withControlDependencies(predictionControls); - tPredictions = tfc.identity(tPredictions); - - Operand shape = tf.stack(Arrays.asList(numClasses, numClasses)); - Operand indices = 
tf.stack(Arrays.asList(tLabels, tPredictions), Stack.axis(1L)); - Operand values = - weights == null ? cast(tf, tf.onesLike(tPredictions), type) : cast(tf, weights, type); - SparseTensor cmSparse = new SparseTensor<>(indices, values, shape); - Operand zeroMatrix = tf.zeros(shape, type); - - return tf.sparse.sparseTensorDenseAdd( - cmSparse.getIndices(), cmSparse.getValues(), cmSparse.getDenseShape(), zeroMatrix); - } - /** - * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false - * } + * Calculate the mean of the operand, along all axes and keepDims is false + * * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -895,8 +294,8 @@ public static Operand booleanMean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is - * {@code false} + * Calculate the mean of the operand, alongside the specified axis with keepDims is + * false * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -913,11 +312,11 @@ public static Operand booleanMean( * * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained * with length 1. - * @return the mean of elements of {@code x} containing floating point numbers + * @return the mean of elements of x containing floating point numbers */ public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); @@ -929,11 +328,11 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is - * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes - * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained * with length 1. - * @return the mean of elements of {@code x} containing floating point numbers + * @return the mean of elements of x containing floating point numbers */ public static Operand booleanMean( Ops tf, Operand x, Operand axes, boolean keepDims) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java deleted file mode 100644 index 68157632557..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java +++ /dev/null @@ -1,147 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.metrics.impl; - -import org.tensorflow.Operand; -import org.tensorflow.op.Ops; -import org.tensorflow.op.SparseOps; -import org.tensorflow.op.sparse.DenseToDenseSetOperation; -import org.tensorflow.types.family.TNumber; - -import static org.tensorflow.framework.utils.CastHelper.cast; - -/** Implementation of set operations */ -public class SetsOps { - - /** - * Computes set difference of elements in last dimension of {@code a} and {@code b} with - * {@code aMinusB} set to true. - * - *

All but the last dimension of {@code a} and {@code b} must match - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand difference(Ops tf, Operand a, Operand b) { - return difference(tf, a, b, true); - } - - /** - * Computes set difference of elements in last dimension of {@code a} and {@code b}. - * - *

All but the last dimension of {@code a} and {@code b} must match - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param aMinusB whether to subtract b from a, vs vice versa. - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand difference( - Ops tf, Operand a, Operand b, boolean aMinusB) { - return setOperation(tf, a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A); - } - - /** - * Computes set union of elements in last dimension of {@code a} and {@code b}. - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand union(Ops tf, Operand a, Operand b) { - return setOperation(tf, a, b, Operation.UNION); - } - - /** - * Computes set intersection of elements in last dimension of {@code a} and {@code b}. - * - * @param tf the TensorFlow Ops - * @param a The first operand representing set {@code a} - * @param b The other operand representing set {@code b} - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the * same. Elements along the last dimension contain the results of the set - * operation. - */ - public static Operand intersection(Ops tf, Operand a, Operand b) { - return setOperation(tf, a, b, Operation.INTERSECTION); - } - - /** - * Compute set operation of elements in last dimension of {@code a} and {@code b}. - * - * @param tf the TensorFlow Ops - * @param a The first set operation operand - * @param b The other et operation operand - * @param setOperation The set operation to perform, {@link Operation}. - * @param the data type for the sets - * @return An Operand with the same rank as {@code a} and {@code b}, and all but the - * last dimension the same. Elements along the last dimension contain the results of the set - * operation. 
- */ - public static Operand setOperation( - Ops tf, Operand a, Operand b, Operation setOperation) { - - DenseToDenseSetOperation setOperationResult = - tf.sparse.denseToDenseSetOperation( - a, b, setOperation.getSetOperation(), DenseToDenseSetOperation.validateIndices(true)); - - return tf.sparse.sparseToDense( - setOperationResult.resultIndices(), - setOperationResult.resultShape(), - setOperationResult.resultValues(), - cast(tf, tf.constant(0), a.type())); - } - - /** - * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops - * function {@link SparseOps#denseToDenseSetOperation} - */ - public enum Operation { - A_MINUS_B("a-b"), - B_MINUS_A("b-a"), - INTERSECTION("intersection"), - UNION("union"); - - private final String setOperation; - - Operation(String setOperation) { - this.setOperation = setOperation; - } - - /** - * Gets the set operation String value used to pass as the stringOperation value to {@link - * SparseOps#denseToDenseSetOperation} - * - * @return the set operation String value - */ - public String getSetOperation() { - return setOperation; - } - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java new file mode 100644 index 00000000000..cecbecfed15 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java @@ -0,0 +1,136 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.DeviceSpec; +import org.tensorflow.EagerSession; +import org.tensorflow.ExecutionEnvironment; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.Scope; + +/** + * An API for building framework operations as {@link Op Op}s + * + *

These are higher level ops that may invoke core ops. Higher level Ops may perform the + * operation solely in the TensorFlow framework or do preprocessing of the Operands before invoking + * a core level Op. + */ +public class FrameworkOps { + public final Ops coreOps; + private final Scope scope; + + public final NnOps nn; + public final SetsOps sets; + + /** + * Creates a FrameworkOps instance with the provided scope + * + * @param scope the scope + */ + private FrameworkOps(Scope scope) { + this.coreOps = Ops.create(scope.env()); + this.scope = scope; + nn = new NnOps(this); + sets = new SetsOps(this); + } + + /** + * Creates a FrameworkOps instance based on the provided Core Ops + * + * @param coreOps The TensorFlow Core Ops + */ + private FrameworkOps(Ops coreOps) { + this.coreOps = coreOps; + this.scope = coreOps.scope(); + nn = new NnOps(this); + sets = new SetsOps(this); + } + + + /** Returns the current {@link Scope scope} of this API */ + public final Scope scope() { + return scope; + } + + /** + * Gets the core Ops + * + * @return coreOps + */ + public final Ops coreOps() { + return coreOps; + } + + /** + * Returns an API that builds operations with the provided name prefix. + * + *

{@link Scope#withSubScope(String)}
+   */
+  public FrameworkOps withSubScope(String childScopeName) {
+    return new FrameworkOps(scope.withSubScope(childScopeName));
+  }
+
+  /**
+   * Returns an API that uses the provided name for an op.
+   *
+   * 

{@link Scope#withName(String)} + */ + public FrameworkOps withName(String opName) { + return new FrameworkOps(scope.withName(opName)); + } + + /** + * Returns an API that places the created operations on the device(s) matching the provided spec. + * + *

{@link Scope#withDevice(DeviceSpec)} + */ + public FrameworkOps withDevice(DeviceSpec deviceSpec) { + return new FrameworkOps(scope.withDevice(deviceSpec)); + } + + /** + * Returns an API that adds operations to the graph with the provided control dependencies. + * + *

{@link Scope#withControlDependencies(Iterable)} + */ + public FrameworkOps withControlDependencies(Iterable controls) { + return new FrameworkOps(scope.withControlDependencies(controls)); + } + + /** Creates an API for building operations in the provided execution environment */ + public static FrameworkOps create(ExecutionEnvironment env) { + return new FrameworkOps(new Scope(env)); + } + + /** + * Creates an API for building operations in the default eager execution environment + * + *

Invoking this method is equivalent to {@code + * FrameworkOps.create(EagerSession.getDefault())}. + */ + public static FrameworkOps create() { + return new FrameworkOps(new Scope(EagerSession.getDefault())); + } + + /** + * Creates an API for building operations in the default eager execution environment + * + * @param coreOps the TensorFlow core Ops + */ + public static FrameworkOps create(Ops coreOps) { + return new FrameworkOps(coreOps); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java new file mode 100644 index 00000000000..4054f3ddbb5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java @@ -0,0 +1,197 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.Operand; +import org.tensorflow.framework.op.nn.SigmoidCrossEntropyWithLogits; +import org.tensorflow.framework.op.nn.SoftmaxCrossEntropyWithLogits; +import org.tensorflow.framework.op.nn.SparseSoftmaxCrossEntropyWithLogits; +import org.tensorflow.op.Op; +import org.tensorflow.op.Scope; +import org.tensorflow.types.family.TNumber; + +/** + * An API for building {@code nn} operations as {@link Op Op}s + * + *

These are higher level ops that may invoke core ops. Higher level Ops may perform the + * operation solely in the TensorFlow framework or do preprocessing of the Operands before invoking + * a core level Op. + * + *
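+ * A typical entry point for this group (an editorial sketch, assuming an existing core
+ * {@code Ops} instance {@code tf} and suitably shaped {@code labels} and {@code logits}):
+ *
+ *    FrameworkOps fops = FrameworkOps.create(tf);
+ *    Operand<TFloat32> losses = fops.nn.sigmoidCrossEntropyWithLogits(labels, logits);
+ *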

{@link FrameworkOps}
+ */
+public class NnOps {
+  private final Scope scope;
+
+  private final FrameworkOps frameworkOps;
+
+  /**
+   * Creates Framework {@code nn} Operations
+   *
+   * @param frameworkOps the TensorFlow framework Ops
+   */
+  NnOps(FrameworkOps frameworkOps) {
+    this.scope = frameworkOps.scope();
+    this.frameworkOps = frameworkOps;
+  }
+
+  /**
+   * Computes sigmoid cross entropy given logits.
+   *
+   * 

Measures the probability error in discrete classification tasks in which each class is + * independent and not mutually exclusive. For instance, one could perform multilabel + * classification where a picture can contain both an elephant and a dog at the same time. + * + *

For brevity, let x = logits, z = labels. The logistic loss in + * pseudo-code is + * + *

+     *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+     *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+     *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+     *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+     *   = (1 - z) * x + log(1 + exp(-x))
+     *   = x - x * z + log(1 + exp(-x))
+     *  
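+   * (A quick editorial check of the last line: with x = 0, z = 1 it gives
+   * 0 - 0 + log(1 + exp(0)) = log 2, i.e. -log(sigmoid(0)), as expected.)
+   *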
+ * + *

For x < 0, to avoid overflow in exp(-x), we reformulate the above + * + *

+     *  x - x * z + log(1 + exp(-x))
+     *   = log(exp(x)) - x * z + log(1 + exp(-x))
+     *   = - x * z + log(1 + exp(x))
+     *  
+ * + *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent + * formulation + * + *

+     *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
+     *  
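+   * (An editorial note on why this matters: for x = -100, z = 1 the direct form needs
+   * log(1 + exp(100)), which overflows in single precision, while the expression above
+   * computes 0 - (-100) + log(1 + exp(-100)), approximately 100, the correct loss,
+   * with no overflow.)
+   *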
+ * + *

logits and labels must have the same type and shape. + * + *
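+   * For reference, the stable formulation above can be written directly with core ops
+   * (an editorial sketch, assuming tf is a core Ops instance and x, z are same-shaped
+   * TFloat32 operands holding the logits and labels):
+   *
+   *    Operand<TFloat32> losses =
+   *        tf.math.add(
+   *            tf.math.sub(tf.math.maximum(x, tf.zerosLike(x)), tf.math.mul(x, z)),
+   *            tf.math.log1p(tf.math.exp(tf.math.neg(tf.math.abs(x)))));
+   *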

+ * + * @param labels the labels + * @param logits the logits of type float32 or float64 + * @param the type of labels and logits + * @return the component-wise logistic losses. + * @throws IllegalArgumentException if logits' and labels' do not have the same shape + */ + public Operand sigmoidCrossEntropyWithLogits(Operand labels, + Operand logits) { + return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); + } + + /** + * Computes softmax cross entropy between logits and labels. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of labels is a valid probability distribution. If they + * are not, the computation of the gradient will be incorrect. + * + *

If using exclusive labels (wherein one and only one class is true at a time),
+   * see {@link #sparseSoftmaxCrossEntropyWithLogits}
+   *
+   * 

Usage: + * + *

+     *    FrameworkOps fops = FrameworkOps.create(tf);
+     *    Operand<TFloat32> logits =
+     *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
+     *    Operand<TFloat32> labels =
+     *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
+     *    Operand<TFloat32> output =
+     *        fops.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+     *    // output Shape = [2]
+     *    // dataType = FLOAT (1)
+     *    // values { 0.169846, 0.824745 }
+     *  
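+   * (Editorial check of the values above: the first row is log(exp(4) + exp(2) + exp(1)) - 4
+   * = 4.169846 - 4 = 0.169846; the second row is 0.8 * (s - 5) + 0.2 * (s - 1) with
+   * s = log(exp(0) + exp(5) + exp(1)) = 5.024745, which gives 0.824745.)
+   *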
+ * + *

Backpropagation will happen into both logits and labels. To + * disallow backpropagation into labels, pass label tensors through + * tf.stopGradient before feeding it to this function. + * + * @param labels Each vector along the class dimension should hold a valid probability + * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] + * , each row of labels[i] must be a valid probability distribution. + * @param logits Per-label activations, typically a linear output. These activation energies are + * interpreted as unnormalized log probabilities. + * @param axis The class dimension. -1 is the last dimension. + * @param the number type of the operands + * @return the softmax cross entropy loss. Its type is the same as logits and its + * shape is the same as labels except that it does not have the last dimension of + * labels. + */ + public Operand softmaxCrossEntropyWithLogits( + Operand labels, Operand logits, int axis) { + return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + } + + /** + * Computes sparse softmax cross entropy between logits and labels. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

For this operation, the probability of a given label is considered exclusive. That is, soft
+   * classes are not allowed, and the labels vector must provide a single specific
+   * index for the true class for each row of logits (each minibatch entry). For soft
+   * softmax classification with a probability distribution for each entry, see
+   * {@link #softmaxCrossEntropyWithLogits}.
+   *
+   * 

WARNING: + * + *

This op expects unscaled logits, since it performs a softmax on logits + * internally for efficiency. Do not call this op with the output of softmax, + * as it will produce incorrect results. + * + *
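+   * For illustration (an editorial sketch, assuming fops = FrameworkOps.create(tf) and
+   * the logits from the softmax example above, whose true classes are 0 and 1):
+   *
+   *    Operand<TInt32> labels = tf.constant(new int[] {0, 1});
+   *    Operand<TFloat32> loss =
+   *        fops.nn.sparseSoftmaxCrossEntropyWithLogits(labels, logits);
+   *    // loss shape = [2]: one cross entropy value per example
+   *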

A common use case is to have logits of shape [batchSize, numClasses] and have + * labels of shape [batchSize], but higher dimensions are supported, in which case + * the dim-th dimension is assumed to be of size numClasses. + * logits must have the dataType of TFloat16, TFloat32 + * , or TFloat64, and labels must have the dtype of TInt32 + * or TInt64. + * + * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r + * is rank of labels and result) and the dataType is TInt32 + * or TInt64. Each entry in labels must be an index in [0, + * numClasses). Other values will raise an exception when this op is run on CPU, and + * return NaN for corresponding loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., + * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, + * or TFloat64. These activation energies are interpreted as unnormalized log + * probabilities. + * @param The data type for the labels + * @param The data type for the logits and loss + * @return the loss + * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank + * of the labels is not equal to the rank of the logits minus one. + */ + + public Operand sparseSoftmaxCrossEntropyWithLogits( + Operand labels, Operand logits) { + return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); + } + + +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java new file mode 100644 index 00000000000..d7833cdbb06 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java @@ -0,0 +1,161 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.Operand; +import org.tensorflow.op.Scope; +import org.tensorflow.op.SparseOps; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.sparse.DenseToDenseSetOperation; +import org.tensorflow.op.sparse.SparseToDense; +import org.tensorflow.types.family.TNumber; + +/** Implementation of set operations */ +public class SetsOps { + + private final Scope scope; + + private final FrameworkOps frameworkOps; + + /** + * Creates Framework {@code nn} Operations + * + * @param frameworkOps the TensorFLow framework Ops + */ + SetsOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } + + /** + * Computes set difference of elements in last dimension of a and b with + * aMinusB set to true. + * + *
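+   * For illustration (an editorial sketch; the values come from the SetOpsTest changes
+   * later in this patch, with fops = FrameworkOps.create(tf)):
+   *
+   *    Operand<TInt32> a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}});
+   *    Operand<TInt32> b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}});
+   *    Operand<TInt32> aMinusB = fops.sets.difference(a, b);        // {{5, 9, 0}, {3, 4, 5}}
+   *    Operand<TInt32> bMinusA = fops.sets.difference(a, b, false); // {{2, 6}, {1, 2}}
+   *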

All but the last dimension of a and b must match.
+   *
+   * @param a The first operand representing set a
+   * @param b The other operand representing set b
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the last dimension the
+   *     same. Elements along the last dimension contain the results of the set operation.
+   */
+  public <T extends TNumber> Operand<T> difference(Operand<T> a, Operand<T> b) {
+    return difference(a, b, true);
+  }
+
+  /**
+   * Computes set difference of elements in last dimension of a and b.
+   *
+   * 

All but the last dimension of a and b must match.
+   *
+   * @param a The first operand representing set a
+   * @param b The other operand representing set b
+   * @param aMinusB whether to subtract b from a, vs. vice versa
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the last dimension the
+   *     same. Elements along the last dimension contain the results of the set operation.
+   */
+  public <T extends TNumber> Operand<T> difference(Operand<T> a, Operand<T> b, boolean aMinusB) {
+    return setOperation(a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A);
+  }
+
+  /**
+   * Computes set union of elements in last dimension of a and b.
+   *
+   * @param a The first operand representing set a
+   * @param b The other operand representing set b
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the last dimension the
+   *     same. Elements along the last dimension contain the results of the set operation.
+   */
+  public <T extends TNumber> Operand<T> union(Operand<T> a, Operand<T> b) {
+    return setOperation(a, b, Operation.UNION);
+  }
+
+  /**
+   * Computes set intersection of elements in last dimension of a and b.
+   *
+   * @param a The first operand representing set a
+   * @param b The other operand representing set b
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the last dimension the
+   *     same. Elements along the last dimension contain the results of the set operation.
+   */
+  public <T extends TNumber> Operand<T> intersection(Operand<T> a, Operand<T> b) {
+    return setOperation(a, b, Operation.INTERSECTION);
+  }
+
+  /**
+   * Computes the set operation of elements in last dimension of a and b.
+   *
+   * @param a The first set operation operand
+   * @param b The other set operation operand
+   * @param setOperation The set operation to perform, {@link Operation}.
+   * @param <T> the data type for the sets
+   * @return An Operand with the same rank as a and b, and all but the last dimension the
+   *     same. Elements along the last dimension contain the results of the set operation.
+ */ + public Operand setOperation( + Operand a, Operand b, Operation setOperation) { + + DenseToDenseSetOperation setOperationResult = + DenseToDenseSetOperation.create( + scope, + a, + b, + setOperation.getSetOperation(), + DenseToDenseSetOperation.validateIndices(true)); + + return SparseToDense.create( + scope, + setOperationResult.resultIndices(), + setOperationResult.resultShape(), + setOperationResult.resultValues(), + Cast.create(scope, Constant.scalarOf(scope, 0), a.type())); + } + + /** + * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops + * function {@link SparseOps#denseToDenseSetOperation} + */ + public enum Operation { + A_MINUS_B("a-b"), + B_MINUS_A("b-a"), + INTERSECTION("intersection"), + UNION("union"); + + private final String setOperation; + + Operation(String setOperation) { + this.setOperation = setOperation; + } + + /** + * Gets the set operation String value used to pass as the stringOperation value to {@link + * SparseOps#denseToDenseSetOperation} + * + * @return the set operation String value + */ + public String getSetOperation() { + return setOperation; + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 64faa7c5d70..75766cf9bfb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -64,8 +64,7 @@ public class SparseSoftmaxCrossEntropyWithLogits { * probabilities. * @param the data type for the labels * @param the data tyoe for the loss and logits. - * @return A Tensor of the same shape as labels and of the same type as - * logits with the softmax cross entropy loss. + * @return the loss * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank * of the labels is not equal to the rank of the logits minus one. 
*/ diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetOpsTest.java similarity index 86% rename from tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java rename to tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetOpsTest.java index eceff2797f8..e10f016bd94 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetOpsTest.java @@ -2,6 +2,8 @@ import org.junit.jupiter.api.Test; import org.tensorflow.Operand; +import org.tensorflow.framework.op.FrameworkOps; +import org.tensorflow.framework.op.SetsOps; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -15,7 +17,7 @@ import static org.tensorflow.framework.utils.CastHelper.cast; -class SetsOpsTest { +class SetOpsTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -28,6 +30,7 @@ public void testSetIntersectionMultirow2() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 5}}); int[][] expected = new int[][] {{1, 9}, {0, 0}}; @@ -35,7 +38,7 @@ public void testSetIntersectionMultirow2() { for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); + Operand intersection = fops.sets.intersection(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } @@ -49,6 +52,7 @@ public void testSetIntersectionDuplicates2d() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{1, 1, 3}}); Operand b = tf.constant(new int[][] {{1, 1}}); int[][] expected = {{1}}; @@ -56,7 +60,7 @@ public void testSetIntersectionDuplicates2d() { for (Class type : types) { Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = SetsOps.intersection(tf, aa, bb); + Operand intersection = fops.sets.intersection(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); @@ -72,6 +76,7 @@ public void testDenseSetDifferenceMultirow2d() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}}); Operand b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}}); @@ -81,14 +86,14 @@ public void testDenseSetDifferenceMultirow2d() { int[][] expected = {{5, 9, 0}, {3, 4, 5}}; // a- b Shape expectedShape = Shape.of(2, 3); - Operand intersection = SetsOps.difference(tf, aa, bb); + Operand intersection = fops.sets.difference(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); // b - a expected = new int[][] {{2, 6}, {1, 2}}; 
expectedShape = Shape.of(2, 2); - intersection = SetsOps.difference(tf, aa, bb, false); + intersection = fops.sets.difference(aa, bb, false); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); @@ -103,6 +108,7 @@ public void testDenseUnionMultirow2d() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); Operand b = tf.constant(new int[][] {{1, 9}, {1, 2}}); int[][] expected = new int[][] {{5, 0}, {3, 4}}; @@ -111,7 +117,7 @@ public void testDenseUnionMultirow2d() { Operand bb = cast(tf, b, type); Shape expectedShape = Shape.of(2, 2); // a- b - Operand intersection = SetsOps.difference(tf, aa, bb); + Operand intersection = fops.sets.difference(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } From ba24371189a09094c2941540d65ce100c57caf5e Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 27 Mar 2021 15:21:11 -0400 Subject: [PATCH 09/31] Added FrameworkOps analogous to Ops. Added NnOps and SetOps as groups. Fixed MetricsHelper and Losses to use the bew FrameworkOps. Moved SetsOps to framework.op. --- .../src/main/java/org/tensorflow/framework/losses/Losses.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index aa5fa4ada6d..33c8d50409d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -572,7 +572,7 @@ public static Operand sparseCategoricalCrossentropy( tf.constant( new long[] {-1L, predictionsShape.size(predictionsShape.numDimensions() - 1)})); } - + Operand loss = fop.nn.sparseSoftmaxCrossEntropyWithLogits(iLabels, predictions); if (updateShape && predictionsRank >= 3) { Shape newShape = predictionsShape.take(predictionsShape.numDimensions() - 1); From 4d3f17cf4cff04dee66f2e00756d911eaf12e2bd Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 27 Mar 2021 15:36:41 -0400 Subject: [PATCH 10/31] Move l2Normalize to MathOps --- .../tensorflow/framework/losses/Losses.java | 23 ++----- .../tensorflow/framework/op/FrameworkOps.java | 3 + .../org/tensorflow/framework/op/MathOps.java | 67 +++++++++++++++++++ 3 files changed, 74 insertions(+), 19 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 33c8d50409d..398588cee67 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -337,13 +337,14 @@ public static Operand categoricalHinge( */ public static Operand cosineSimilarity( Ops tf, Operand labels, Operand predictions, int[] axis) { + FrameworkOps fops = FrameworkOps.create(tf); Operand tLabels = cast(tf, labels, predictions.type()); LossTuple lossTuple = LossesHelper.squeezeOrExpandDimensions(tf, tLabels, predictions, null); 
predictions = lossTuple.getTarget(); tLabels = lossTuple.getLabels(); - tLabels = l2Normalize(tf, tLabels, axis); - predictions = l2Normalize(tf, predictions, axis); + tLabels = fops.math.l2Normalize(tLabels, axis); + predictions = fops.math.l2Normalize(predictions, axis); Operand mathMul = tf.math.mul(tLabels, predictions); return tf.reduceSum(mathMul, tf.constant(axis), ReduceSum.keepDims(Boolean.FALSE)); } @@ -651,23 +652,7 @@ private static Operand smoothCategoricalLabels( return tf.math.add(tf.math.mul(labels, oneMinusSmoothing), tf.math.div(smoothing, numClasses)); } - // TODO this was tf.math.l2_normalize in TF Python - /** - * Normalizes along dimension axis using an L2 norm. - * - * @param tf The TensorFlow Ops - * @param x the input - * @param axis Dimension along which to normalize. - * @param the data type for the input and the result - * @return the normalized values based on L2 norm - */ - public static Operand l2Normalize(Ops tf, Operand x, int[] axis) { - Operand squareSum = - tf.reduceSum(tf.math.square(x), tf.constant(axis), ReduceSum.keepDims(Boolean.TRUE)); - Operand invNorm = - tf.math.rsqrt(tf.math.maximum(squareSum, cast(tf, tf.constant(1e-12F), x.type()))); - return tf.math.mul(x, invNorm); - } + /** * Converts binary labels into -1/1. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java index cecbecfed15..18fb8ada6b7 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java @@ -34,6 +34,7 @@ public class FrameworkOps { public final NnOps nn; public final SetsOps sets; + public final MathOps math; /** * Creates a FrameworkOps instance with the provided scope @@ -45,6 +46,7 @@ private FrameworkOps(Scope scope) { this.scope = scope; nn = new NnOps(this); sets = new SetsOps(this); + math = new MathOps(this); } /** @@ -57,6 +59,7 @@ private FrameworkOps(Ops coreOps) { this.scope = coreOps.scope(); nn = new NnOps(this); sets = new SetsOps(this); + math = new MathOps(this); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java new file mode 100644 index 00000000000..57a18fc63c2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/
+package org.tensorflow.framework.op;
+
+import org.tensorflow.Operand;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.core.Constant;
+import org.tensorflow.op.core.ReduceSum;
+import org.tensorflow.op.dtypes.Cast;
+import org.tensorflow.op.math.Maximum;
+import org.tensorflow.op.math.Mul;
+import org.tensorflow.op.math.Rsqrt;
+import org.tensorflow.op.math.Square;
+import org.tensorflow.types.family.TNumber;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+public class MathOps {
+    private final Scope scope;
+
+    private final FrameworkOps frameworkOps;
+
+    /**
+     * Creates Framework {@code nn} Operations
+     *
+     * @param frameworkOps the TensorFLow framework Ops
+     */
+    MathOps(FrameworkOps frameworkOps) {
+        this.scope = frameworkOps.scope();
+        this.frameworkOps = frameworkOps;
+    }
+
+    /**
+     * Normalizes along dimension axis using an L2 norm.
+     *
+     * @param x the input
+     * @param axis Dimension along which to normalize.
+     * @param the data type for the input and the result
+     * @return the normalized values based on L2 norm
+     */
+    public Operand l2Normalize(Operand x, int[] axis) {
+        Operand squareSum =
+            ReduceSum.create(scope,
+                Square.create(scope, x),
+                Constant.vectorOf(scope, axis),
+                ReduceSum.keepDims(Boolean.TRUE));
+        Operand invNorm =
+            Rsqrt.create(scope,
+                Maximum.create(scope, squareSum,
+                    Cast.create(scope,
+                        Constant.scalarOf(scope, 1e-12F), x.type())));
+        return Mul.create(scope, x, invNorm);
+    }
+}

From 9e07483e233df90c0ec7be793b5c7f11700933bf Mon Sep 17 00:00:00 2001
From: Jim Clarke
Date: Sat, 27 Mar 2021 18:50:26 -0400
Subject: [PATCH 11/31] Reformat code, fix javadocs

---
 .../tensorflow/framework/op/FrameworkOps.java |  76 +++--
 .../org/tensorflow/framework/op/MathOps.java  |  68 ++--
 .../org/tensorflow/framework/op/NnOps.java    | 312 +++++++++---------
 .../op/nn/SigmoidCrossEntropyWithLogits.java  |  14 +-
 .../op/nn/SoftmaxCrossEntropyWithLogits.java  |   3 +-
 .../SparseSoftmaxCrossEntropyWithLogits.java  |  52 +--
 6 files changed, 271 insertions(+), 254 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java
index 18fb8ada6b7..c8b234f2c51 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java
@@ -30,11 +30,10 @@
  */
 public class FrameworkOps {
   public final Ops coreOps;
-  private final Scope scope;
-
   public final NnOps nn;
   public final SetsOps sets;
   public final MathOps math;
+  private final Scope scope;

   /**
    * Creates a FrameworkOps instance with the provided scope
@@ -62,8 +61,43 @@ private FrameworkOps(Ops coreOps) {
     this.scope = coreOps.scope();
     nn = new NnOps(this);
     sets = new SetsOps(this);
     math = new MathOps(this);
   }

+  /**
+   * Creates an API for building operations in the provided execution environment
+   *
+   * @param env the execution environment
+   * @return the FrameworkOps
+   */
+  public static FrameworkOps create(ExecutionEnvironment env) {
+    return new FrameworkOps(new Scope(env));
+  }
+
+  /**
+   * Creates an API for building operations in the default eager execution environment
+   *
+   *

Invoking this method is equivalent to {@code
+   * FrameworkOps.create(EagerSession.getDefault())}.
+   *
+   * @return the FrameworkOps
+   */
+  public static FrameworkOps create() {
+    return new FrameworkOps(new Scope(EagerSession.getDefault()));
+  }
+
+  /**
+   * Creates an API for building operations using the provided core {@code Ops}
+   *
+   * @param coreOps the TensorFlow core Ops
+   * @return the FrameworkOps
+   */
+  public static FrameworkOps create(Ops coreOps) {
+    return new FrameworkOps(coreOps);
+  }

-  /** Returns the current {@link Scope scope} of this API */
+  /**
+   * Returns the current {@link Scope scope} of this API
+   *
+   * @return the current {@link Scope scope} of this API
+   */
   public final Scope scope() {
     return scope;
   }
@@ -81,6 +115,9 @@ public final Ops coreOps() {
    * Returns an API that builds operations with the provided name prefix.
    *
    *

{@link Scope#withSubScope(String)}
+   *
+   * @param childScopeName the name of the child scope
+   * @return the FrameworkOps
    */
   public FrameworkOps withSubScope(String childScopeName) {
     return new FrameworkOps(scope.withSubScope(childScopeName));
   }
@@ -90,6 +127,9 @@ public FrameworkOps withSubScope(String childScopeName) {
    * Returns an API that uses the provided name for an op.
    *
    *

{@link Scope#withName(String)} + * + * @param opName the name of the scope + * @return the FrameworkOps */ public FrameworkOps withName(String opName) { return new FrameworkOps(scope.withName(opName)); @@ -99,6 +139,9 @@ public FrameworkOps withName(String opName) { * Returns an API that places the created operations on the device(s) matching the provided spec. * *

{@link Scope#withDevice(DeviceSpec)} + * + * @param deviceSpec the device specification for the scope + * @return the FrameworkOps */ public FrameworkOps withDevice(DeviceSpec deviceSpec) { return new FrameworkOps(scope.withDevice(deviceSpec)); @@ -108,32 +151,11 @@ public FrameworkOps withDevice(DeviceSpec deviceSpec) { * Returns an API that adds operations to the graph with the provided control dependencies. * *

{@link Scope#withControlDependencies(Iterable)} + * + * @param controls the operations + * @return the FrameworkOps */ public FrameworkOps withControlDependencies(Iterable controls) { return new FrameworkOps(scope.withControlDependencies(controls)); } - - /** Creates an API for building operations in the provided execution environment */ - public static FrameworkOps create(ExecutionEnvironment env) { - return new FrameworkOps(new Scope(env)); - } - - /** - * Creates an API for building operations in the default eager execution environment - * - *

Invoking this method is equivalent to {@code
-   * FrameworkOps.create(EagerSession.getDefault())}.
-   */
-  public static FrameworkOps create() {
-    return new FrameworkOps(new Scope(EagerSession.getDefault()));
-  }
-
-  /**
-   * Creates an API for building operations in the default eager execution environment
-   *
-   * @param coreOps the TensorFlow core Ops
-   */
-  public static FrameworkOps create(Ops coreOps) {
-    return new FrameworkOps(coreOps);
-  }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java
index 57a18fc63c2..5208cde98f3 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java
@@ -15,7 +15,6 @@
 package org.tensorflow.framework.op;

 import org.tensorflow.Operand;
-import org.tensorflow.op.Ops;
 import org.tensorflow.op.Scope;
 import org.tensorflow.op.core.Constant;
 import org.tensorflow.op.core.ReduceSum;
@@ -26,42 +25,41 @@
 import org.tensorflow.op.math.Square;
 import org.tensorflow.types.family.TNumber;

-import static org.tensorflow.framework.utils.CastHelper.cast;
-
 public class MathOps {
-    private final Scope scope;
+  private final Scope scope;

-    private final FrameworkOps frameworkOps;
+  private final FrameworkOps frameworkOps;

-    /**
-     * Creates Framework {@code nn} Operations
-     *
-     * @param frameworkOps the TensorFLow framework Ops
-     */
-    MathOps(FrameworkOps frameworkOps) {
-        this.scope = frameworkOps.scope();
-        this.frameworkOps = frameworkOps;
-    }
+  /**
+   * Creates Framework {@code math} Operations
+   *
+   * @param frameworkOps the TensorFlow framework Ops
+   */
+  MathOps(FrameworkOps frameworkOps) {
+    this.scope = frameworkOps.scope();
+    this.frameworkOps = frameworkOps;
+  }

-    /**
-     * Normalizes along dimension axis using an L2 norm.
-     *
-     * @param x the input
-     * @param axis Dimension along which to normalize.
-     * @param the data type for the input and the result
-     * @return the normalized values based on L2 norm
-     */
-    public Operand l2Normalize(Operand x, int[] axis) {
-        Operand squareSum =
-            ReduceSum.create(scope,
-                Square.create(scope, x),
-                Constant.vectorOf(scope, axis),
-                ReduceSum.keepDims(Boolean.TRUE));
-        Operand invNorm =
-            Rsqrt.create(scope,
-                Maximum.create(scope, squareSum,
-                    Cast.create(scope,
-                        Constant.scalarOf(scope, 1e-12F), x.type())));
-        return Mul.create(scope, x, invNorm);
-    }
+  /**
+   * Normalizes along dimension axis using an L2 norm.
+   *
+   * @param x the input
+   * @param axis Dimension along which to normalize.
+   * @param the data type for the input and the result
+   * @return the normalized values based on L2 norm
+   */
+  public Operand l2Normalize(Operand x, int[] axis) {
+    Operand squareSum =
+        ReduceSum.create(
+            scope,
+            Square.create(scope, x),
+            Constant.vectorOf(scope, axis),
+            ReduceSum.keepDims(Boolean.TRUE));
+    Operand invNorm =
+        Rsqrt.create(
+            scope,
+            Maximum.create(
+                scope, squareSum, Cast.create(scope, Constant.scalarOf(scope, 1e-12F), x.type())));
+    return Mul.create(scope, x, invNorm);
+  }
 }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java
index 4054f3ddbb5..0fea3743d95 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java
@@ -32,166 +32,164 @@
  *

{@link FrameworkOps} */
 public class NnOps {
-    private final Scope scope;
+  private final Scope scope;

-    private final FrameworkOps frameworkOps;
+  private final FrameworkOps frameworkOps;

-    /**
-     * Creates Framework {@code nn} Operations
-     * @param frameworkOps the TensorFLow framework Ops
-     */
-    NnOps(FrameworkOps frameworkOps) {
-        this.scope = frameworkOps.scope();
-        this.frameworkOps = frameworkOps;
-    }
+  /**
+   * Creates Framework {@code nn} Operations
+   *
+   * @param frameworkOps the TensorFlow framework Ops
+   */
+  NnOps(FrameworkOps frameworkOps) {
+    this.scope = frameworkOps.scope();
+    this.frameworkOps = frameworkOps;
+  }

-    /**
-     * Computes sigmoid cross entropy given logits.
-     *
-     *

Measures the probability error in discrete classification tasks in which each class is - * independent and not mutually exclusive. For instance, one could perform multilabel - * classification where a picture can contain both an elephant and a dog at the same time. - * - *

For brevity, let x = logits, z = labels. The logistic loss in - * pseudo-code is - * - *

-     *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
-     *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
-     *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
-     *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
-     *   = (1 - z) * x + log(1 + exp(-x))
-     *   = x - x * z + log(1 + exp(-x))
-     *  
- * - *

For x < 0, to avoid overflow in exp(-x), we reformulate the above - * - *

-     *  x - x * z + log(1 + exp(-x))
-     *   = log(exp(x)) - x * z + log(1 + exp(-x))
-     *   = - x * z + log(1 + exp(x))
-     *  
- * - *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent - * formulation - * - *

-     *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
-     *  
- * - *

logits and labels must have the same type and shape. - * - *

- * - * @param labels the labels - * @param logits the logits of type float32 or float64 - * @param the type of labels and logits - * @return the component-wise logistic losses. - * @throws IllegalArgumentException if logits' and labels' do not have the same shape - */ - public Operand sigmoidCrossEntropyWithLogits(Operand labels, - Operand logits) { - return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); - } - - /** - * Computes softmax cross entropy between logits and labels. - * - *

Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - *

NOTE: - * - *

While the classes are mutually exclusive, their probabilities need not be. All that is - * required is that each row of labels is a valid probability distribution. If they - * are not, the computation of the gradient will be incorrect. - * - *

If using exclusive labels (wherein one and only one class is true at a time), - * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} - * - *

Usage: - * - *

-     *    Operand<TFloat32> logits =
-     *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
-     *    Operand<TFloat32> labels =
-     *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
-     *    Operand<TFloat32> output =
-     *        tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
-     *    // output Shape = [2]
-     *    // dataType = FLOAT (1)
-     *    // values { 0.169846, 0.824745 }
-     *  
- * - *

Backpropagation will happen into both logits and labels. To - * disallow backpropagation into labels, pass label tensors through - * tf.stopGradient before feeding it to this function. - * - * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] - * , each row of labels[i] must be a valid probability distribution. - * @param logits Per-label activations, typically a linear output. These activation energies are - * interpreted as unnormalized log probabilities. - * @param axis The class dimension. -1 is the last dimension. - * @param the number type of the operands - * @return the softmax cross entropy loss. Its type is the same as logits and its - * shape is the same as labels except that it does not have the last dimension of - * labels. - */ - public Operand softmaxCrossEntropyWithLogits( - Operand labels, Operand logits, int axis) { - return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); - } - - /** - * Computes sparse softmax cross entropy between logits and labels. - * - *

Measures the probability error in discrete classification tasks in which the classes are - * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is - * labeled with one and only one label: an image can be a dog or a truck, but not both. - * - *

NOTE: - * - *

For this operation, the probability of a given label is considered exclusive. That is, soft - * classes are not allowed, and the labels vector must provide a single specific - * index for the true class for each row of logits (each minibatch entry). For soft - * softmax classification with a probability distribution for each entry, {@link - * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. - * - *

WARNING: - * - *

This op expects unscaled logits, since it performs a softmax on logits - * internally for efficiency. Do not call this op with the output of softmax, - * as it will produce incorrect results. - * - *

A common use case is to have logits of shape [batchSize, numClasses] and have - * labels of shape [batchSize], but higher dimensions are supported, in which case - * the dim-th dimension is assumed to be of size numClasses. - * logits must have the dataType of TFloat16, TFloat32 - * , or TFloat64, and labels must have the dtype of TInt32 - * or TInt64. - * - * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r - * is rank of labels and result) and the dataType is TInt32 - * or TInt64. Each entry in labels must be an index in [0, - * numClasses). Other values will raise an exception when this op is run on CPU, and - * return NaN for corresponding loss and gradient rows on GPU. - * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ..., - * d_{r-1}, numClasses] and dataType of TFloat16, TFloat32, - * or TFloat64. These activation energies are interpreted as unnormalized log - * probabilities. - * @param The data type for the labels - * @param The data type for the logits and loss - * @return the loss - * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank - * of the labels is not equal to the rank of the logits minus one. - */ - - public Operand sparseSoftmaxCrossEntropyWithLogits( - Operand labels, Operand logits) { - return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(scope, labels, logits); - } + /** + * Computes sigmoid cross entropy given {@code logits}. + * + *

Measures the probability error in discrete classification tasks in which each class is + * independent and not mutually exclusive. For instance, one could perform multilabel + * classification where a picture can contain both an elephant and a dog at the same time. + * + *

For brevity, let {@code x = logits}, {@code z = labels}. The logistic loss in pseudo-code is + * + *

+   *  z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
+   *   = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
+   *   = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
+   *   = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)))
+   *   = (1 - z) * x + log(1 + exp(-x))
+   *   = x - x * z + log(1 + exp(-x))
+   *  
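+   * For instance, at {@code x = 0, z = 1} the last form evaluates to
+   * {@code 0 - 0 * 1 + log(1 + exp(0)) = log(2) ≈ 0.693}, the expected loss for a zero
+   * logit on a positive label.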
+ * + *

For {@code x < 0}, to avoid overflow in {@code exp(-x)}, we reformulate the above + * + *

+   *  x - x * z + log(1 + exp(-x))
+   *   = log(exp(x)) - x * z + log(1 + exp(-x))
+   *   = - x * z + log(1 + exp(x))
+   *  
+ * + *

Hence, to ensure stability and avoid overflow, the implementation uses this equivalent + * formulation + * + *

+   *    max(x, 0) - x * z + log(1 + exp(-abs(x)))
+   *  
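+   * Both forms agree wherever they are finite; for instance, at {@code x = -10, z = 0} each
+   * yields {@code log(1 + exp(-10)) ≈ 4.54e-5}, but the stable form never evaluates
+   * {@code exp(10)}, which is what would overflow for large negative {@code x} in
+   * low-precision types.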
+ * + *

{@code logits} and {@code labels} must have the same type and shape. + * + *
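+   * For example, a minimal usage sketch (assuming {@code fops} is a {@code FrameworkOps}
+   * instance created with {@code FrameworkOps.create(tf)}; the values are illustrative only):
+   *
+   *    Operand<TFloat32> labels = tf.constant(new float[] {1.0F, 0.0F, 1.0F});
+   *    Operand<TFloat32> logits = tf.constant(new float[] {2.5F, -1.0F, 0.3F});
+   *    Operand<TFloat32> loss = fops.nn.sigmoidCrossEntropyWithLogits(labels, logits);
+   *    // loss has the same shape as labels and logits: [3]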

+ * + * @param labels the labels + * @param logits the logits of type float32 or float64 + * @param the type of labels and logits + * @return the component-wise logistic losses. + * @throws IllegalArgumentException if logits' and labels' do not have the same shape + */ + public Operand sigmoidCrossEntropyWithLogits( + Operand labels, Operand logits) { + return SigmoidCrossEntropyWithLogits.sigmoidCrossEntropyWithLogits(scope, labels, logits); + } + /** + * Computes softmax cross entropy between {@code logits} and {@code labels}. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of {@code labels} is a valid probability distribution. If they are + * not, the computation of the gradient will be incorrect. + * + *

If using exclusive {@code labels} (wherein one and only one class is true at a time), see + * {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} + * + *

Usage: + * + *

+   *    Operand<TFloat32> logits =
+   *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
+   *    Operand<TFloat32> labels =
+   *        tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
+   *    Operand<TFloat32> output =
+   *        fops.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+   *    // output Shape = [2]
+   *    // dataType = FLOAT (1)
+   *    // values { 0.169846, 0.824745 }
+   *  
+ * + *
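+   * In the snippet above, {@code fops} is assumed to be a framework ops instance obtained
+   * with {@code FrameworkOps fops = FrameworkOps.create(tf);} the constants themselves still
+   * come from the core {@code tf} API.
+   *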

Backpropagation will happen into both {@code logits} and {@code labels}. To disallow + * backpropagation into {@code labels}, pass label tensors through {@code tf.stopGradient} before + * feeding it to this function. + * + * @param labels Each vector along the class dimension should hold a valid probability + * distribution e.g. for the case in which labels are of shape {@code [batch_size, + * num_classes] }, each row of {@code labels[i]} must be a valid probability distribution. + * @param logits Per-label activations, typically a linear output. These activation energies are + * interpreted as unnormalized log probabilities. + * @param axis The class dimension. -1 is the last dimension. + * @param the number type of the operands + * @param the data type for the labels. + * @return the softmax cross entropy loss. Its type is the same as {@code logits} and its shape is + * the same as {@code labels} except that it does not have the last dimension of {@code + * labels}. + * + */ + public Operand softmaxCrossEntropyWithLogits( + Operand labels, Operand logits, int axis) { + return SoftmaxCrossEntropyWithLogits.softmaxCrossEntropyWithLogits(scope, labels, logits, axis); + } + /** + * Computes sparse softmax cross entropy between {@code logits} and {@code labels}. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

For this operation, the probability of a given label is considered exclusive. That is, soft
+   * classes are not allowed, and the {@code labels} vector must provide a single specific index for
+   * the true class for each row of {@code logits} (each minibatch entry). For soft softmax
+   * classification with a probability distribution for each entry, see {@link
+   * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}.
+   *
+   *

WARNING: + * + *

This op expects unscaled logits, since it performs a {@code softmax} on {@code logits } + * internally for efficiency. Do not call this op with the output of {@code softmax}, as it will + * produce incorrect results. + * + *
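+   * For example, a hypothetical sketch of the common case described below (again assuming
+   * {@code fops} comes from {@code FrameworkOps.create(tf)}; labels are class indices, not
+   * one-hot vectors):
+   *
+   *    Operand<TInt32> labels = tf.constant(new int[] {0, 2});  // shape [2]
+   *    Operand<TFloat32> logits =
+   *        tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}});  // shape [2, 3]
+   *    Operand<TFloat32> loss = fops.nn.sparseSoftmaxCrossEntropyWithLogits(labels, logits);
+   *    // loss shape: [2]
+   *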

A common use case is to have logits of shape {@code [batchSize, numClasses]} and have labels + * of shape {@code [batchSize]}, but higher dimensions are supported, in which case the {@code + * dim}-th dimension is assumed to be of size {@code numClasses}. {@code logits} must have the + * {@code dataType} of {@code TFloat16}, {@code TFloat32} , or {@code TFloat64}, and {@code + * labels} must have the dtype of {@code TInt32} or {@code TInt64}. + * + * @param labels {@code Tensor} of shape {@code [d_0, d_1, ..., d_{r-1}]} (where {@code r } is + * rank of {@code labels} and result) and the dataType is {@code TInt32} or {@code TInt64}. + * Each entry in {@code labels} must be an index in {@code [0, numClasses)}. Other values will + * raise an exception when this op is run on CPU, and return {@code NaN} for corresponding + * loss and gradient rows on GPU. + * @param logits Per-label activations (typically a linear output) of shape {@code [d_0, d_1, ..., + * d_{r-1}, numClasses]} and dataType of {@code TFloat16}, {@code TFloat32}, or {@code + * TFloat64}. These activation energies are interpreted as unnormalized log probabilities. + * @param The data type for the labels + * @param The data type for the logits and loss + * @return the loss + * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if the rank + * of the labels is not equal to the rank of the logits minus one. + */ + public Operand sparseSoftmaxCrossEntropyWithLogits( + Operand labels, Operand logits) { + return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits( + scope, labels, logits); + } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java index b55385839d3..fc3f7739363 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java @@ -3,8 +3,6 @@ import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.Select; import org.tensorflow.op.core.ZerosLike; import org.tensorflow.op.dtypes.Cast; @@ -18,17 +16,17 @@ import org.tensorflow.types.TBool; import org.tensorflow.types.family.TNumber; -//@Operator(group = "nn") +// @Operator(group = "nn") public class SigmoidCrossEntropyWithLogits { /** - * Computes sigmoid cross entropy given logits. + * Computes sigmoid cross entropy given {@code logits}. * *

Measures the probability error in discrete classification tasks in which each class is * independent and not mutually exclusive. For instance, one could perform multilabel * classification where a picture can contain both an elephant and a dog at the same time. * - *

For brevity, let x = logits, z = labels. The logistic loss in + *

For brevity, let {@code x = logits}, {@code z = labels}. The logistic loss in * pseudo-code is * *

@@ -40,7 +38,7 @@ public class SigmoidCrossEntropyWithLogits {
    *  = x - x * z + log(1 + exp(-x))
    * 
* - *

For x < 0, to avoid overflow in exp(-x), we reformulate the above + *

For {@code x < 0}, to avoid overflow in {@code exp(-x)}, we reformulate the above * *

    * x - x * z + log(1 + exp(-x))
@@ -55,7 +53,7 @@ public class SigmoidCrossEntropyWithLogits {
    *   max(x, 0) - x * z + log(1 + exp(-abs(x)))
    * 
* - *

logits and labels must have the same type and shape. + *

{@code logits} and {@code labels} must have the same type and shape. * *

* @@ -66,7 +64,7 @@ public class SigmoidCrossEntropyWithLogits { * @return the component-wise logistic losses. * @throws IllegalArgumentException if logits' and labels' do not have the same shape */ - //@Endpoint(name = "sigmoidCrossEntropyWithLogits") + // @Endpoint(name = "sigmoidCrossEntropyWithLogits") public static Operand sigmoidCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits) { if (!isCompatible(labels.shape(), logits.shape())) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java index 0f5b8197f1e..7d59941f27a 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -66,7 +66,8 @@ public class SoftmaxCrossEntropyWithLogits { * @param logits Per-label activations, typically a linear output. These activation energies are * interpreted as unnormalized log probabilities. * @param axis The class dimension. -1 is the last dimension. - * @param the number type of the operands + * @param the data type for the logits and return operand + * @param the data type for the labels * @return the softmax cross entropy loss. Its type is the same as logits and its * shape is the same as labels except that it does not have the last dimension of * labels. diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 75766cf9bfb..0b2d29d6092 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -25,7 +25,7 @@ public class SparseSoftmaxCrossEntropyWithLogits { /** - * Computes sparse softmax cross entropy between logits and labels. + * Computes sparse softmax cross entropy between {@code logits} and {@code labels}. * *

Measures the probability error in discrete classification tasks in which the classes are * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is @@ -34,45 +34,45 @@ public class SparseSoftmaxCrossEntropyWithLogits { *

NOTE: * *

For this operation, the probability of a given label is considered exclusive. That is, soft - * classes are not allowed, and the labels vector must provide a single specific - * index for the true class for each row of logits (each minibatch entry). For soft + * classes are not allowed, and the {@code labels} vector must provide a single specific + * index for the true class for each row of {@code logits} (each minibatch entry). For soft * softmax classification with a probability distribution for each entry, {@link * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. * *

WARNING: * - *

This op expects unscaled logits, since it performs a softmax on logits - * internally for efficiency. Do not call this op with the output of softmax, + *

This op expects unscaled logits, since it performs a {@code softmax} on {@code logits + * } internally for efficiency. Do not call this op with the output of {@code softmax}, * as it will produce incorrect results. * - *

A common use case is to have logits of shape [batchSize, numClasses] and have - * labels of shape [batchSize], but higher dimensions are supported, in which case - * the dim-th dimension is assumed to be of size numClasses. - * logits must have the dataType of TFloat16, TFloat32 - * , or TFloat64, and labels must have the dtype of TInt32 - * or TInt64. + *

A common use case is to have logits of shape {@code [batchSize, numClasses]} and have
+   * labels of shape {@code [batchSize]}, but higher dimensions are supported, in which case
+   * the {@code dim}-th dimension is assumed to be of size {@code numClasses}. {@code
+   * logits} must have the {@code dataType} of {@code TFloat16}, {@code TFloat32}
+   * , or {@code TFloat64}, and {@code labels} must have the dtype of {@code TInt32}
+   * or {@code TInt64}.
    *
    * @param scope current scope
-   * @param labels Tensor of shape [d_0, d_1, ..., d_{r-1}] (where r
-   *     is rank of labels and result) and the dataType is TInt32
-   *     or TInt64. Each entry in labels must be an index in [0,
-   *     numClasses). Other values will raise an exception when this op is run on CPU, and
-   *     return NaN for corresponding loss and gradient rows on GPU.
-   * @param logits Per-label activations (typically a linear output) of shape [d_0, d_1, ...,
-   *     d_{r-1}, numClasses] and dataType of TFloat16, TFloat32,
-   *     or TFloat64. These activation energies are interpreted as unnormalized log
+   * @param labels {@code Tensor} of shape {@code [d_0, d_1, ..., d_{r-1}]} (where {@code r
+   *     } is rank of {@code labels} and result) and the dataType is {@code TInt32}
+   *     or {@code TInt64}. Each entry in {@code labels} must be an index in {@code [0,
+   *     numClasses)}. Other values will raise an exception when this op is run on CPU, and
+   *     return {@code NaN} for corresponding loss and gradient rows on GPU.
+   * @param logits Per-label activations (typically a linear output) of shape {@code [d_0, d_1, ...,
+   *     d_{r-1}, numClasses]} and dataType of {@code TFloat16}, {@code TFloat32},
+   *     or {@code TFloat64}. These activation energies are interpreted as unnormalized log
    *     probabilities.
-   * @param the data type for the labels
-   * @param the data tyoe for the loss and logits.
+   * @param the data type for the labels
+   * @param the data type for the loss and logits.
    * @return the loss
-   * @throws IllegalArgumentException If logits are scalars (need to have rank >= 1) or if the rank
+   * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if the rank
    *     of the labels is not equal to the rank of the logits minus one.
    */
   @SuppressWarnings("unchecked")
   @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits")
   public static
-      Operand sparseSoftmaxCrossEntropyWithLogits(
-          Scope scope, Operand labels, Operand logits) {
+      Operand sparseSoftmaxCrossEntropyWithLogits(
+          Scope scope, Operand labels, Operand logits) {
     scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits");
     Operand preciseLogits;
     if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) {
@@ -119,7 +119,7 @@ Operand sparseSoftmaxCrossEntropyWithLogits(
       return Cast.create(scope, cost, logits.type());
     } else {
       // Unchecked cast already checked with previous if
-      return (Operand) cost;
+      return (Operand) cost;
     }
   }
@@ -160,7 +160,7 @@ Operand sparseSoftmaxCrossEntropyWithLogits(
       return Cast.create(scope, cost, logits.type());
     } else {
       // Unchecked cast already checked with previous if
-      return (Operand) cost;
+      return (Operand) cost;
     }
   }
 }

From 790bf3517c93e975ea28f03b72b9b7a6d0dc2bde Mon Sep 17 00:00:00 2001
From: Jim Clarke
Date: Fri, 16 Apr 2021 18:04:30 -0400
Subject: [PATCH 12/31] Add confusionMatrix() method.
 Add unit test.

---
 .../org/tensorflow/framework/op/MathOps.java  | 301 +++++++++++++
 .../tensorflow/framework/op/MathOpsTest.java  | 413 ++++++++++++++++++
 2 files changed, 714 insertions(+)
 create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java
index 5208cde98f3..36f5b692cab 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java
@@ -15,16 +15,37 @@
 package org.tensorflow.framework.op;

 import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.impl.LossTuple;
+import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Scope;
+import org.tensorflow.op.core.AssertThat;
 import org.tensorflow.op.core.Constant;
+import org.tensorflow.op.core.Identity;
+import org.tensorflow.op.core.OnesLike;
+import org.tensorflow.op.core.Range;
+import org.tensorflow.op.core.Rank;
+import org.tensorflow.op.core.ReduceAll;
+import org.tensorflow.op.core.ReduceMax;
 import org.tensorflow.op.core.ReduceSum;
+import org.tensorflow.op.core.ScatterNd;
+import org.tensorflow.op.core.Squeeze;
+import org.tensorflow.op.core.Stack;
+import org.tensorflow.op.core.Zeros;
 import org.tensorflow.op.dtypes.Cast;
+import org.tensorflow.op.math.Add;
+import org.tensorflow.op.math.GreaterEqual;
+import org.tensorflow.op.math.Less;
 import org.tensorflow.op.math.Maximum;
 import org.tensorflow.op.math.Mul;
 import org.tensorflow.op.math.Rsqrt;
 import org.tensorflow.op.math.Square;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.TInt64;
 import org.tensorflow.types.family.TNumber;

+import java.util.Arrays;
+import java.util.Collections;
+
 public class MathOps {
   private final Scope scope;

@@ -62,4 +83,284 @@ public Operand l2Normalize(Operand x, int[] axis) {
                 scope, squareSum, Cast.create(scope, Constant.scalarOf(scope, 1e-12F), x.type())));
     return Mul.create(scope, x, invNorm);
   }
+
+  /**
+   * Computes the confusion matrix from predictions and labels.
+   *
+   *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid + * labels for a given classification task. Both prediction and labels must be 1-D arrays of the + * same shape in order for this function to work. + * + *

If `num_classes` is `None`, then `num_classes` will be set to one plus the maximum value in + * either predictions or labels. Class labels are expected to start at 0. For example, if + * `num_classes` is 3, then the possible labels would be `[0, 1, 2]`. + * + *

If `weights` is not `None`, then each prediction contributes its corresponding weight to the + * total value of the confusion matrix cell. + * + *

For example: + * + *

+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
+   *         [[0 0 0 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 0 0 0]
+   *          [0 0 0 0 1]]
+   * 
+ * + *

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param labels 1-D Operand of real labels for the classification task. + * @param predictions 1-D Operand of predictions for a given classification. + * @param Data type of the confusion matrix. + * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion + * matrix, where {@code n} is the number of possible labels in the classification task. + * @throws IllegalArgumentException If both predictions and labels are not 1-D vectors and have + * mismatched shapes, or if {@code weights} is not null and its shape doesn't match {@code + * predictions}. + */ + public Operand confusionMatrix(Operand labels, Operand predictions) { + return confusionMatrix(labels, predictions, null, null, labels.type()); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid + * labels for a given classification task. Both prediction and labels must be 1-D arrays of the + * same shape in order for this function to work. + * + *

If `num_classes` is `None`, then `num_classes` will be set to one plus the maximum value in + * either predictions or labels. Class labels are expected to start at 0. For example, if + * `num_classes` is 3, then the possible labels would be `[0, 1, 2]`. + * + *

If `weights` is not `None`, then each prediction contributes its corresponding weight to the + * total value of the confusion matrix cell. + * + *

For example: + * + *

+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
+   *         [[0 0 0 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 0 0 0]
+   *          [0 0 0 0 1]]
+   * 
+ * + *

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param labels 1-D Operand of real labels for the classification task. + * @param predictions 1-D Operand of predictions for a given classification. + * @param weights An optional Operand whose shape matches {@code predictions}. + * @param Data type of the confusion matrix. + * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion + * matrix, where {@code n} is the number of possible labels in the classification task. + * @throws IllegalArgumentException If both predictions and labels are not 1-D vectors and have + * mismatched shapes, or if {@code weights} is not null and its shape doesn't match {@code + * predictions}. + */ + public Operand confusionMatrix( + Operand labels, Operand predictions, Operand weights) { + return confusionMatrix(labels, predictions, weights, null, labels.type()); + } + + /** + * Computes the confusion matrix from predictions and labels. + * + *

The matrix columns represent the prediction labels and the rows represent the real labels. + * The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid + * labels for a given classification task. Both prediction and labels must be 1-D arrays of the + * same shape in order for this function to work. + * + *

If `num_classes` is `None`, then `num_classes` will be set to one plus the maximum value in + * either predictions or labels. Class labels are expected to start at 0. For example, if + * `num_classes` is 3, then the possible labels would be `[0, 1, 2]`. + * + *

If `weights` is not `None`, then each prediction contributes its corresponding weight to the + * total value of the confusion matrix cell. + * + *

For example: + * + *

+   *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
+   *         [[0 0 0 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 1 0 0]
+   *          [0 0 0 0 0]
+   *          [0 0 0 0 1]]
+   * 
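+   * The same matrix can be produced through this overload with every argument spelled out (a
+   * hypothetical sketch; {@code fops} is again assumed to come from {@code FrameworkOps.create(tf)}):
+   *
+   *     fops.math.confusionMatrix(
+   *         tf.constant(new long[] {1, 2, 4}),   // labels
+   *         tf.constant(new long[] {2, 2, 4}),   // predictions
+   *         null,                                // weights: unweighted, each match counts 1
+   *         tf.constant(5L),                     // numClasses
+   *         TInt64.class);                       // data type of the resulting matrix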
+ * + *

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 + * confusion matrix. + * + * @param labels 1-D Operand of real labels for the classification task. + * @param predictions 1-D Operand of predictions for a given classification. + * @param weights An optional Operand whose shape matches {@code predictions}. + * @param numClasses The possible number of labels the classification task can have. If this value + * is null, it will be calculated using both predictions and labels. + * @param type Data type of the confusion matrix. + * @param Data type of the confusion matrix. + * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion + * matrix, where {@code n} is the number of possible labels in the classification task. + * @throws IllegalArgumentException If both predictions and labels are not 1-D vectors and have + * mismatched shapes, or if {@code weights} is not null and its shape doesn't match {@code + * predictions}. + */ + public Operand confusionMatrix( + Operand labels, + Operand predictions, + Operand weights, + Operand numClasses, + Class type) { + Scope lScope = scope.withSubScope("confusionMatrix"); + LossTuple tuple = removeSqueezableDimensions(labels, predictions, 0); + Operand lLabels = Cast.create(lScope, tuple.getLabels(), TInt64.class); + Operand lPredictions = Cast.create(lScope, tuple.getTarget(), TInt64.class); + + Operand zero = Constant.scalarOf(lScope, 0L); + Operand one = Constant.scalarOf(lScope, 1L); + + AssertThat labelsNonNegative = + AssertThat.create( + lScope, + ReduceAll.create(lScope, GreaterEqual.create(lScope, lLabels, zero), allAxes(lLabels)), + Collections.singletonList( + Constant.scalarOf(lScope, "labels contains negative values"))); + lLabels = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(labelsNonNegative)), lLabels); + + AssertThat predictionsNonNegative = + AssertThat.create( + lScope, + ReduceAll.create( + lScope, GreaterEqual.create(lScope, lPredictions, zero), allAxes(lPredictions)), + Collections.singletonList( + Constant.scalarOf(lScope, "predictions contains negative values"))); + lPredictions = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(predictionsNonNegative)), + lPredictions); + + Operand lNumClasses; + if (numClasses == null) { + lNumClasses = + Add.create( + lScope, + Maximum.create( + lScope, + ReduceMax.create(lScope, lPredictions, zero), + ReduceMax.create(lScope, lLabels, zero)), + one); + } else { + lNumClasses = Cast.create(lScope, numClasses, TInt64.class); + AssertThat labelsLess = + AssertThat.create( + lScope, + Less.create(lScope, lLabels, lNumClasses), + Collections.singletonList(Constant.scalarOf(lScope, "labels out of bounds"))); + lLabels = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(labelsLess)), lLabels); + + AssertThat predictionsLess = + AssertThat.create( + lScope, + Less.create(lScope, lPredictions, lNumClasses), + Collections.singletonList(Constant.scalarOf(lScope, "predictions out of bounds"))); + lPredictions = + Identity.create( + lScope.withControlDependencies(Collections.singletonList(predictionsLess)), + lPredictions); + } + + if (weights != null) { + if (!predictions.shape().isCompatibleWith(weights.shape())) { + throw new IllegalArgumentException( + String.format( + "predictions.shape() [%s], is not compatible with weights.shape() [ %s].", + predictions.shape(), weights.shape())); + } + } + + Operand shape = Stack.create(lScope, 
Arrays.asList(lNumClasses, lNumClasses)); + Operand indices = + Stack.create(lScope, Arrays.asList(lLabels, lPredictions), Stack.axis(1L)); + Operand values = weights == null ? OnesLike.create(lScope, predictions) : weights; + Operand zeroMatrix = Zeros.create(lScope, Cast.create(lScope, shape, TInt32.class), type); + + return ScatterNd.create(lScope, indices, values, shape); + } + + /** + * Squeeze last dim if ranks differ from expected by exactly 1. + * + * @param labels Label values, a Operand whose dimensions match predictions + * . + * @param predictions Predicted values, a Tensor of arbitrary dimensions. + * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). + * @param the data type for the labels, predictions and result + * @return labels and predictions, possibly with last dim squeezed. + */ + public LossTuple removeSqueezableDimensions( + Operand labels, Operand predictions, int expectedRankDiff) { + Scope lScope = scope.withSubScope("removeSqueezableDimensions"); + Shape predictionsShape = predictions.shape(); + int predictionsRank = predictionsShape.numDimensions(); + Shape labelsShape = labels.shape(); + int labelsRank = labelsShape.numDimensions(); + + if (predictionsRank != Shape.UNKNOWN_SIZE || labelsRank != Shape.UNKNOWN_SIZE) { + // Use static rank. + int rankDiff = predictionsRank - labelsRank; + if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.size(-1), 1)) { + predictions = Squeeze.create(lScope, predictions); + } else if (rankDiff == expectedRankDiff - 1 && Shape.isCompatible(labelsShape.size(-1), 1)) { + labels = Squeeze.create(lScope, labels); + } + return new LossTuple<>(labels, predictions); + } + // Use dynamic rank. + + // TODO: hold for lazy select feature, + // Operand rankDiff = tf.math.sub(tf.rank(predictions), tf.rank(labels)); + if (predictionsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(predictionsShape.size(-1), 1)) { + /* + * TODO, if we ever get a select that does lazy evaluation, but for now do the tf.squeeze + * predictions = tf.select( tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), + * tf.squeeze(predictions, Squeeze.axis(Arrays.asList(-1L))), predictions ); * + */ + predictions = + Squeeze.create(lScope, predictions, Squeeze.axis(Collections.singletonList(-1L))); + } + if (labelsRank == Shape.UNKNOWN_SIZE && Shape.isCompatible(labelsShape.size(-1), 1)) { + /* + * TODO, if we ever get a select that does lazy evaluation labels = tf.select( + * tf.math.equal(tf.constant(expectedRankDiff+1),rankDiff ), tf.squeeze(labels, + * Squeeze.axis(Arrays.asList(-1L))), predictions ); * + */ + labels = Squeeze.create(lScope, labels, Squeeze.axis(Collections.singletonList(-1L))); + } + return new LossTuple<>(labels, predictions); + } + + public Operand allAxes(Operand op) { + int rank = op.shape().numDimensions(); + if (rank != Shape.UNKNOWN_SIZE) { + int[] axes = new int[rank]; + for (int i = 0; i < rank; i++) { + axes[i] = i; + } + return Constant.vectorOf(scope, axes); + } else { + return Range.create( + scope, Constant.scalarOf(scope, 0), Rank.create(scope, op), Constant.scalarOf(scope, 1)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java new file mode 100644 index 00000000000..326e3cdc2d1 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java @@ -0,0 +1,413 @@ +package org.tensorflow.framework.op; + +import 
org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; + +class MathOpsTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + double[][][] array = + new double[][][] { + { + {4.17021990e-01, 7.20324516e-01, 1.14374816e-04}, + {3.02332580e-01, 1.46755889e-01, 9.23385918e-02}, + {1.86260208e-01, 3.45560730e-01, 3.96767467e-01}, + {5.38816750e-01, 4.19194520e-01, 6.85219526e-01}, + {2.04452246e-01, 8.78117442e-01, 2.73875929e-02}, + {6.70467496e-01, 4.17304814e-01, 5.58689833e-01}, + {1.40386939e-01, 1.98101491e-01, 8.00744593e-01} + }, + { + {9.68261600e-01, 3.13424170e-01, 6.92322612e-01}, + {8.76389146e-01, 8.94606650e-01, 8.50442126e-02}, + {3.90547849e-02, 1.69830427e-01, 8.78142476e-01}, + {9.83468369e-02, 4.21107620e-01, 9.57889557e-01}, + {5.33165276e-01, 6.91877127e-01, 3.15515637e-01}, + {6.86500907e-01, 8.34625661e-01, 1.82882771e-02}, + {7.50144303e-01, 9.88861084e-01, 7.48165667e-01} + }, + { + {2.80443996e-01, 7.89279342e-01, 1.03226006e-01}, + {4.47893530e-01, 9.08595502e-01, 2.93614149e-01}, + {2.87775338e-01, 1.30028576e-01, 1.93669572e-02}, + {6.78835511e-01, 2.11628109e-01, 2.65546650e-01}, + {4.91573155e-01, 5.33625446e-02, 5.74117601e-01}, + {1.46728575e-01, 5.89305520e-01, 6.99758351e-01}, + {1.02334432e-01, 4.14055973e-01, 6.94400132e-01} + }, + { + {4.14179265e-01, 4.99534607e-02, 5.35896420e-01}, + {6.63794637e-01, 5.14889121e-01, 9.44594741e-01}, + {5.86555064e-01, 9.03401911e-01, 1.37474701e-01}, + {1.39276341e-01, 8.07391286e-01, 3.97676826e-01}, + {1.65354192e-01, 9.27508593e-01, 3.47765863e-01}, + {7.50812113e-01, 7.25997984e-01, 8.83306086e-01}, + {6.23672187e-01, 7.50942409e-01, 3.48898351e-01} + }, + { + {2.69927889e-01, 8.95886242e-01, 4.28091198e-01}, + {9.64840055e-01, 6.63441479e-01, 6.21695697e-01}, + {1.14745975e-01, 9.49489236e-01, 4.49912131e-01}, + {5.78389585e-01, 4.08136815e-01, 2.37026975e-01}, + {9.03379500e-01, 5.73679507e-01, 2.87032709e-03}, + {6.17144942e-01, 3.26644897e-01, 5.27058125e-01}, + {8.85942101e-01, 3.57269764e-01, 9.08535123e-01} + }, + { + {6.23360097e-01, 1.58212427e-02, 9.29437220e-01}, + {6.90896928e-01, 9.97322857e-01, 1.72340512e-01}, + {1.37135744e-01, 9.32595491e-01, 6.96818173e-01}, + {6.60001710e-02, 7.55463064e-01, 7.53876209e-01}, + {9.23024535e-01, 7.11524785e-01, 1.24270961e-01}, + {1.98801346e-02, 2.62109861e-02, 2.83064879e-02}, + {2.46211067e-01, 8.60027969e-01, 5.38831055e-01} + }, + { + {5.52821994e-01, 8.42030883e-01, 1.24173313e-01}, + {2.79183686e-01, 5.85759282e-01, 9.69595730e-01}, + {5.61030209e-01, 1.86472889e-02, 8.00632656e-01}, + {2.32974276e-01, 8.07105184e-01, 3.87860656e-01}, + {8.63541842e-01, 7.47121632e-01, 5.56240261e-01}, + {1.36455223e-01, 5.99176884e-02, 1.21343456e-01}, + {4.45518792e-02, 1.07494131e-01, 2.25709334e-01} + }, + { + {7.12988973e-01, 5.59717000e-01, 1.25559801e-02}, + {7.19742775e-02, 9.67276335e-01, 5.68100452e-01}, + {2.03293234e-01, 2.52325743e-01, 7.43825853e-01}, + {1.95429474e-01, 5.81358910e-01, 9.70019996e-01}, + {8.46828818e-01, 2.39847764e-01, 4.93769705e-01}, + {6.19955719e-01, 8.28980923e-01, 1.56791389e-01}, + {1.85762029e-02, 7.00221434e-02, 4.86345112e-01} + }, + { + {6.06329441e-01, 5.68851411e-01, 3.17362398e-01}, + {9.88616168e-01, 5.79745233e-01, 3.80141169e-01}, + {5.50948203e-01, 7.45334446e-01, 
6.69232905e-01}, + {2.64919549e-01, 6.63348362e-02, 3.70084196e-01}, + {6.29717529e-01, 2.10174009e-01, 7.52755582e-01}, + {6.65364787e-02, 2.60315090e-01, 8.04754555e-01}, + {1.93434283e-01, 6.39460862e-01, 5.24670303e-01} + }, + { + {9.24807966e-01, 2.63296783e-01, 6.59610927e-02}, + {7.35065937e-01, 7.72178054e-01, 9.07815874e-01}, + {9.31972086e-01, 1.39515726e-02, 2.34362081e-01}, + {6.16778374e-01, 9.49016333e-01, 9.50176120e-01}, + {5.56653202e-01, 9.15606380e-01, 6.41566217e-01}, + {3.90007704e-01, 4.85990673e-01, 6.04310513e-01}, + {5.49547911e-01, 9.26181436e-01, 9.18733418e-01} + }, + { + {3.94875616e-01, 9.63262558e-01, 1.73955664e-01}, + {1.26329526e-01, 1.35079160e-01, 5.05662143e-01}, + {2.15248056e-02, 9.47970212e-01, 8.27115476e-01}, + {1.50189810e-02, 1.76196262e-01, 3.32063586e-01}, + {1.30996838e-01, 8.09490681e-01, 3.44736665e-01}, + {9.40107465e-01, 5.82014203e-01, 8.78831983e-01}, + {8.44734430e-01, 9.05392289e-01, 4.59880263e-01} + }, + { + {5.46346843e-01, 7.98603594e-01, 2.85718858e-01}, + {4.90253508e-01, 5.99110305e-01, 1.55332759e-02}, + {5.93481421e-01, 4.33676362e-01, 8.07360530e-01}, + {3.15244794e-01, 8.92888725e-01, 5.77857196e-01}, + {1.84010208e-01, 7.87929237e-01, 6.12031162e-01}, + {5.39092720e-02, 4.20193672e-01, 6.79068863e-01}, + {9.18601751e-01, 4.02024889e-04, 9.76759136e-01} + }, + { + {3.76580328e-01, 9.73783553e-01, 6.04716122e-01}, + {8.28845799e-01, 5.74711502e-01, 6.28076196e-01}, + {2.85576284e-01, 5.86833358e-01, 7.50021756e-01}, + {8.58313859e-01, 7.55082190e-01, 6.98057234e-01}, + {8.64479423e-01, 3.22681010e-01, 6.70788765e-01}, + {4.50873941e-01, 3.82102758e-01, 4.10811365e-01}, + {4.01479572e-01, 3.17383945e-01, 6.21919394e-01} + }, + { + {4.30247277e-01, 9.73802090e-01, 6.77800894e-01}, + {1.98569894e-01, 4.26701009e-01, 3.43346238e-01}, + {7.97638834e-01, 8.79998267e-01, 9.03841972e-01}, + {6.62719786e-01, 2.70208269e-01, 2.52366692e-01}, + {8.54897916e-01, 5.27714670e-01, 8.02161098e-01}, + {5.72488546e-01, 7.33142555e-01, 5.19011617e-01}, + {7.70883918e-01, 5.68857968e-01, 4.65709865e-01} + }, + { + {3.42688918e-01, 6.82093501e-02, 3.77924174e-01}, + {7.96260759e-02, 9.82817113e-01, 1.81612849e-01}, + {8.11858714e-01, 8.74961674e-01, 6.88413262e-01}, + {5.69494426e-01, 1.60971433e-01, 4.66880023e-01}, + {3.45172048e-01, 2.25039959e-01, 5.92511892e-01}, + {3.12269837e-01, 9.16305542e-01, 9.09635544e-01}, + {2.57118285e-01, 1.10891297e-01, 1.92962736e-01} + }, + { + {4.99584168e-01, 7.28585660e-01, 2.08194435e-01}, + {2.48033553e-01, 8.51671875e-01, 4.15848732e-01}, + {6.16685092e-01, 2.33666137e-01, 1.01967260e-01}, + {5.15857041e-01, 4.77140993e-01, 1.52671650e-01}, + {6.21806204e-01, 5.44010103e-01, 6.54137373e-01}, + {1.44545540e-01, 7.51527846e-01, 2.22049147e-01}, + {5.19351840e-01, 7.85296023e-01, 2.23304275e-02} + }, + { + {3.24362457e-01, 8.72922361e-01, 8.44709635e-01}, + {5.38440585e-01, 8.66608262e-01, 9.49805975e-01}, + {8.26407015e-01, 8.54115427e-01, 9.87434015e-02}, + {6.51304305e-01, 7.03516960e-01, 6.10240817e-01}, + {7.99615264e-01, 3.45712192e-02, 7.70238757e-01}, + {7.31728613e-01, 2.59698391e-01, 2.57069290e-01}, + {6.32303298e-01, 3.45297456e-01, 7.96588659e-01} + }, + { + {4.46146220e-01, 7.82749414e-01, 9.90471780e-01}, + {3.00248325e-01, 1.43005833e-01, 9.01308417e-01}, + {5.41559398e-01, 9.74740386e-01, 6.36604428e-01}, + {9.93912995e-01, 5.46070814e-01, 5.26425958e-01}, + {1.35427907e-01, 3.55705172e-01, 2.62185670e-02}, + {1.60395175e-01, 7.45637178e-01, 3.03996895e-02}, + {3.66543084e-01, 
8.62346232e-01, 6.92677736e-01} + }, + { + {6.90942168e-01, 1.88636795e-01, 4.41904277e-01}, + {5.81577420e-01, 9.89751697e-01, 2.03906223e-01}, + {2.47732908e-01, 2.62173086e-01, 7.50172436e-01}, + {4.56975341e-01, 5.69294393e-02, 5.08516252e-01}, + {2.11960167e-01, 7.98604250e-01, 2.97331393e-01}, + {2.76060123e-02, 5.93432426e-01, 8.43840420e-01}, + {3.81016135e-01, 7.49858320e-01, 5.11141479e-01} + }, + { + {5.40951788e-01, 9.59434330e-01, 8.03960919e-01}, + {3.23230661e-02, 7.09387243e-01, 4.65001494e-01}, + {9.47548926e-01, 2.21432731e-01, 2.67072022e-01}, + {8.14739615e-02, 4.28618819e-01, 1.09018765e-01}, + {6.33786738e-01, 8.02963257e-01, 6.96800470e-01}, + {7.66211390e-01, 3.42454106e-01, 8.45851481e-01}, + {4.28768784e-01, 8.24009895e-01, 6.26496136e-01} + } + }; + + double[][][] expectedArray = { + { + {3.45350616e-02, 5.96526116e-02, 9.47178160e-06}, + {2.50372272e-02, 1.21533722e-02, 7.64688430e-03}, + {1.54248644e-02, 2.86171008e-02, 3.28577124e-02}, + {4.46213149e-02, 3.47149745e-02, 5.67454435e-02}, + {1.69314109e-02, 7.27199987e-02, 2.26806314e-03}, + {5.55237755e-02, 3.45584825e-02, 4.62670736e-02}, + {1.16259372e-02, 1.64054818e-02, 6.63124844e-02} + }, + { + {8.01851526e-02, 2.59557609e-02, 5.73336743e-02}, + {7.25768730e-02, 7.40855262e-02, 7.04281079e-03}, + {3.23426444e-03, 1.40642561e-02, 7.27220699e-02}, + {8.14444851e-03, 3.48734073e-02, 7.93262124e-02}, + {4.41532955e-02, 5.72967827e-02, 2.61289626e-02}, + {5.68515584e-02, 6.91182911e-02, 1.51451665e-03}, + {6.21220917e-02, 8.18910673e-02, 6.19582348e-02} + }, + { + {2.32245550e-02, 6.53630048e-02, 8.54850933e-03}, + {3.70916426e-02, 7.52439946e-02, 2.43152231e-02}, + {2.38316897e-02, 1.07681248e-02, 1.60384597e-03}, + {5.62167615e-02, 1.75256692e-02, 2.19908543e-02}, + {4.07089069e-02, 4.41914052e-03, 4.75447029e-02}, + {1.21511100e-02, 4.88024652e-02, 5.79494536e-02}, + {8.47467501e-03, 3.42894346e-02, 5.75057231e-02} + }, + { + {3.42996456e-02, 4.13682219e-03, 4.43794727e-02}, + {5.49711734e-02, 4.26397808e-02, 7.82252178e-02}, + {4.85746935e-02, 7.48138949e-02, 1.13847647e-02}, + {1.15339644e-02, 6.68629184e-02, 3.29330191e-02}, + {1.36935636e-02, 7.68102556e-02, 2.87997164e-02}, + {6.21773973e-02, 6.01224527e-02, 7.31496885e-02}, + {5.16484901e-02, 6.21881858e-02, 2.88935024e-02} + }, + { + {2.23536789e-02, 7.41914958e-02, 3.54517400e-02}, + {7.99018070e-02, 5.49419262e-02, 5.14848121e-02}, + {9.50251892e-03, 7.86305517e-02, 3.72588076e-02}, + {4.78984788e-02, 3.37992460e-02, 1.96290389e-02}, + {7.48120397e-02, 4.75084223e-02, 2.37701897e-04}, + {5.11079468e-02, 2.70506144e-02, 4.36475389e-02}, + {7.33679906e-02, 2.95867678e-02, 7.52389953e-02} + }, + { + {5.16226478e-02, 1.31021289e-03, 7.69699737e-02}, + {5.72156087e-02, 8.25918168e-02, 1.42721254e-02}, + {1.13566946e-02, 7.72315189e-02, 5.77059686e-02}, + {5.46570681e-03, 6.25625551e-02, 6.24311455e-02}, + {7.64389113e-02, 5.89238741e-02, 1.02913165e-02}, + {1.64634397e-03, 2.17062421e-03, 2.34416011e-03}, + {2.03896053e-02, 7.12219477e-02, 4.46224995e-02} + }, + { + {4.57811356e-02, 6.97315410e-02, 1.02832299e-02}, + {2.31201854e-02, 4.85087894e-02, 8.02956372e-02}, + {4.64608893e-02, 1.54424773e-03, 6.63032085e-02}, + {1.92934200e-02, 6.68392256e-02, 3.21201086e-02}, + {7.15129450e-02, 6.18717745e-02, 4.60642166e-02}, + {1.13003375e-02, 4.96199494e-03, 1.00488793e-02}, + {3.68949817e-03, 8.90196767e-03, 1.86917856e-02} + }, + { + {5.90451285e-02, 4.63521369e-02, 1.03980501e-03}, + {5.96044352e-03, 8.01035613e-02, 4.70464006e-02}, + {1.68354288e-02, 
2.08959840e-02, 6.15988411e-02}, + {1.61842033e-02, 4.81443815e-02, 8.03307742e-02}, + {7.01288804e-02, 1.98626388e-02, 4.08908091e-02}, + {5.13407178e-02, 6.86508343e-02, 1.29844472e-02}, + {1.53836084e-03, 5.79878036e-03, 4.02759537e-02} + }, + { + {5.02122790e-02, 4.71085906e-02, 2.62818988e-02}, + {8.18707868e-02, 4.80107442e-02, 3.14808302e-02}, + {4.56259623e-02, 6.17237724e-02, 5.54215349e-02}, + {2.19389219e-02, 5.49342157e-03, 3.06479763e-02}, + {5.21491282e-02, 1.74052510e-02, 6.23383410e-02}, + {5.51012019e-03, 2.15576105e-02, 6.66445568e-02}, + {1.60189737e-02, 5.29560074e-02, 4.34497967e-02} + }, + { + {7.65866041e-02, 2.18045339e-02, 5.46247046e-03}, + {6.08734004e-02, 6.39467835e-02, 7.51794279e-02}, + {7.71798939e-02, 1.15537888e-03, 1.94083489e-02}, + {5.10775894e-02, 7.85913840e-02, 7.86874294e-02}, + {4.60984148e-02, 7.58245885e-02, 5.31303585e-02}, + {3.22979130e-02, 4.02465984e-02, 5.00450842e-02}, + {4.55099978e-02, 7.67003447e-02, 7.60835484e-02} + }, + { + {3.27010415e-02, 7.97711685e-02, 1.44058811e-02}, + {1.04617933e-02, 1.11863809e-02, 4.18756641e-02}, + {1.78254500e-03, 7.85047561e-02, 6.84963465e-02}, + {1.24377478e-03, 1.45914331e-02, 2.74993554e-02}, + {1.08483098e-02, 6.70367777e-02, 2.85488572e-02}, + {7.78536126e-02, 4.81986478e-02, 7.27791712e-02}, + {6.99554384e-02, 7.49787241e-02, 3.80843058e-02} + }, + { + {4.52449061e-02, 6.61351755e-02, 2.36613862e-02}, + {4.05996218e-02, 4.96144369e-02, 1.28636532e-03}, + {4.91482876e-02, 3.59142683e-02, 6.68603703e-02}, + {2.61065327e-02, 7.39432648e-02, 4.78543900e-02}, + {1.52385337e-02, 6.52511939e-02, 5.06844558e-02}, + {4.46441676e-03, 3.47977169e-02, 5.62360846e-02}, + {7.60726482e-02, 3.32930977e-05, 8.08888674e-02} + }, + { + {3.11859436e-02, 8.06424469e-02, 5.00786714e-02}, + {6.86396435e-02, 4.75938842e-02, 5.20132035e-02}, + {2.36495789e-02, 4.85977381e-02, 6.21119440e-02}, + {7.10799918e-02, 6.25310168e-02, 5.78085780e-02}, + {7.15905875e-02, 2.67223511e-02, 5.55503815e-02}, + {3.73384580e-02, 3.16432752e-02, 3.40207368e-02}, + {3.32479365e-02, 2.62836833e-02, 5.15033379e-02} + }, + { + {3.56302932e-02, 8.06439817e-02, 5.61310798e-02}, + {1.64442733e-02, 3.53366137e-02, 2.84337122e-02}, + {6.60552830e-02, 7.28757605e-02, 7.48503357e-02}, + {5.48821613e-02, 2.23768987e-02, 2.08993759e-02}, + {7.07971081e-02, 4.37019095e-02, 6.64297864e-02}, + {4.74097952e-02, 6.07141182e-02, 4.29811813e-02}, + {6.38396144e-02, 4.71091345e-02, 3.85670736e-02} + }, + { + {2.83792764e-02, 5.64865675e-03, 3.12972330e-02}, + {6.59411587e-03, 8.13905448e-02, 1.50400000e-02}, + {6.72328845e-02, 7.24586621e-02, 5.70099279e-02}, + {4.71618399e-02, 1.33306114e-02, 3.86639796e-02}, + {2.85849143e-02, 1.86363515e-02, 4.90679964e-02}, + {2.58601662e-02, 7.58824944e-02, 7.53301233e-02}, + {2.12928709e-02, 9.18329880e-03, 1.59799233e-02} + }, + { + {4.13723253e-02, 6.03367463e-02, 1.72413141e-02}, + {2.05405317e-02, 7.05299526e-02, 3.44378985e-02}, + {5.10698669e-02, 1.93507168e-02, 8.44426826e-03}, + {4.27199379e-02, 3.95137258e-02, 1.26432776e-02}, + {5.14939614e-02, 4.50513922e-02, 5.41714206e-02}, + {1.19703254e-02, 6.22366704e-02, 1.83886718e-02}, + {4.30093557e-02, 6.50331303e-02, 1.84926135e-03} + }, + { + {2.68615987e-02, 7.22897798e-02, 6.99533820e-02}, + {4.45901640e-02, 7.17668831e-02, 7.86567777e-02}, + {6.84376806e-02, 7.07323104e-02, 8.17728881e-03}, + {5.39368056e-02, 5.82607202e-02, 5.05361930e-02}, + {6.62189573e-02, 2.86296452e-03, 6.37861863e-02}, + {6.05970249e-02, 2.15065386e-02, 2.12888140e-02}, + 
{5.23632653e-02, 2.85952985e-02, 6.59683123e-02} + }, + { + {3.69469412e-02, 6.48222342e-02, 8.20244551e-02}, + {2.48646215e-02, 1.18428171e-02, 7.46405274e-02}, + {4.48484421e-02, 8.07216838e-02, 5.27194552e-02}, + {8.23094398e-02, 4.52220477e-02, 4.35951874e-02}, + {1.12152621e-02, 2.94571985e-02, 2.17125192e-03}, + {1.32828895e-02, 6.17488436e-02, 2.51750532e-03}, + {3.03547252e-02, 7.14139268e-02, 5.73630854e-02} + }, + { + {5.72193563e-02, 1.56216780e-02, 3.65956500e-02}, + {4.81624752e-02, 8.19648281e-02, 1.68861933e-02}, + {2.05156356e-02, 2.17114780e-02, 6.21244237e-02}, + {3.78437378e-02, 4.71452763e-03, 4.21120226e-02}, + {1.75531674e-02, 6.61352351e-02, 2.46230606e-02}, + {2.28615105e-03, 4.91442308e-02, 6.98814020e-02}, + {3.15532871e-02, 6.20984100e-02, 4.23294269e-02} + }, + { + {4.47981246e-02, 7.94541389e-02, 6.65788352e-02}, + {2.67678709e-03, 5.87468557e-02, 3.85084115e-02}, + {7.84698650e-02, 1.83376241e-02, 2.21171752e-02}, + {6.74714567e-03, 3.54954340e-02, 9.02822800e-03}, + {5.24861142e-02, 6.64962158e-02, 5.77045009e-02}, + {6.34526685e-02, 2.83598304e-02, 7.00479448e-02}, + {3.55078541e-02, 6.82391599e-02, 5.18823527e-02} + } + }; + + @Test + public void testL2Normalize() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + Operand input = tf.constant(array); + Operand result = fops.math.l2Normalize(tf.constant(array), new int[]{ 0,1,2}); + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testConfusionMatrix() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + long[] labels = new long[] {2, 0, 2, 2, 0, 1}; + long[] predictions = new long[] {0, 0, 2, 2, 0, 2}; + Operand result = + fops.math.confusionMatrix(tf.constant(labels), tf.constant(predictions)); + long[][] expected = + new long[][] { + {2, 0, 0}, + {0, 0, 1}, + {1, 0, 2} + }; + session.evaluate(tf.constant(expected), result); + } + } +} From b4ca97a025645a227aae7c306ab397a383f0a6d9 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 2 May 2021 18:37:19 -0400 Subject: [PATCH 13/31] Added linalg methods for matmul --- .../tensorflow/framework/op/LinalgOps.java | 306 ++++++++++++++++++ .../framework/op/LinalgOpsTest.java | 60 ++++ 2 files changed, 366 insertions(+) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java new file mode 100644 index 00000000000..eb069a2db22 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java @@ -0,0 +1,306 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.op;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.SparseTensor;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Endpoint;
+import org.tensorflow.op.dtypes.Cast;
+import org.tensorflow.op.math.Conj;
+import org.tensorflow.op.sparse.SparseMatMul;
+import org.tensorflow.op.train.BatchMatMul;
+import org.tensorflow.types.TBfloat16;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+
+public class LinalgOps {
+  private final Scope scope;
+
+  private final FrameworkOps frameworkOps;
+
+  /**
+   * Creates Framework {@code linalg} Operations
+   *
+   * @param frameworkOps the TensorFlow framework Ops
+   */
+  LinalgOps(FrameworkOps frameworkOps) {
+    this.scope = frameworkOps.scope();
+    this.frameworkOps = frameworkOps;
+  }
+
+  /**
+   * Multiplies matrix a by matrix b, producing a * b.
+   *
+   * <p>The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2
+   * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions
+   * specify matching batch size.
+   *
+   * <p>Both matrices must be of the same type. The supported types are: TFloat16,
+   * TFloat32, TFloat64, TInt32.
+   *
+   * <p>Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by
+   * setting one of the corresponding flags to true. These are false by default.
+   *
+   * <p>A simple 2-D tensor matrix multiplication:
+   *
+   * <pre>{@code
+   * Operand<TFloat64> a = tf.constant(new double[][] {
+   *         {-8.944851},
+   *         {4.1711287},
+   *         {-0.22380222}
+   *     });
+   * Operand<TFloat64> b = tf.constant(new double[][] {
+   *         {-14.276086, -12.433481, -2.2447076, -1.5775859, 1.8588694}
+   *     });
+   * Operand<TFloat64> result = fops.linalg.matmul(a, b);
+   * // result = {
+   * //     {127.69746,  111.21564,  20.078575,  14.111271,  -16.62731},
+   * //     {-59.547394, -51.861652, -9.362965,  -6.580314,    7.753584},
+   * //     {  3.1950197,  2.7826407, 0.50237054, 0.35306725, -0.4160191}
+   * //  }
+   *
+   * }</pre>
+ * + *
+   * <p>Note: This is matrix product, not element-wise product.
+   *
+   * @param a an Operand of type TFloat16, TFloat32, TFloat64,
+   *     TInt32, with a rank > 1
+   * @param b an Operand with same type and rank as a.
+   * @param <T> the data type of the Operands
+   * @return an Operand of the same type as a and b where each inner-most
+   *     matrix is the product of the corresponding matrices in a and b.
+   *     This is the matrix product not an element-wise product.
+   * @throws java.lang.IllegalArgumentException If transposeA and adjointA,
+   *     or transposeB and adjointB are both set to `true`.
+   */
+  @Endpoint(name = "matmul")
+  public <T extends TNumber> Operand<T> matmul(Operand<T> a, Operand<T> b) {
+    return matmul(a, b, false, false, false, false, false, false);
+  }
+
+  /**
+   * Multiplies matrix a by matrix b, producing a * b.
+   *
+   * <p>The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2
+   * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions
+   * specify matching batch size.
+   *
+   * <p>Both matrices must be of the same type. The supported types are: TFloat16,
+   * TFloat32, TFloat64, TInt32.
+   *
+   * <p>Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by
+   * setting one of the corresponding flags to true. These are false by default.
+   *
+   * <p>Note: This is matrix product, not element-wise product.
+   *
+   * <p>A simple 2-D tensor matrix multiplication:
+   *
+   * <pre>{@code
+   * Operand<TFloat64> a = tf.constant(new double[][] {
+   *         {-8.944851},
+   *         {4.1711287},
+   *         {-0.22380222}
+   *     });
+   * Operand<TFloat64> b = tf.constant(new double[][] {
+   *         {-14.276086, -12.433481, -2.2447076, -1.5775859, 1.8588694}
+   *     });
+   * Operand<TFloat64> result = fops.linalg.matmul(a, b);
+   * // result = {
+   * //     {127.69746,  111.21564,  20.078575,  14.111271,  -16.62731},
+   * //     {-59.547394, -51.861652, -9.362965,  -6.580314,    7.753584},
+   * //     {  3.1950197,  2.7826407, 0.50237054, 0.35306725, -0.4160191}
+   * //  }
+   *
+   * }</pre>
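+   *
+   * <p>For illustration, a hypothetical call that multiplies the transpose of {@code x} by
+   * {@code y} (the operand names are assumed, and their shapes must be compatible after the
+   * transposition):
+   *
+   * <pre>{@code
+   * Operand<TFloat64> c = fops.linalg.matmul(x, y, true, false);
+   * }</pre>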
+   *
+   * @param a an Operand of type TFloat16, TFloat32, TFloat64,
+   *     TInt32, with a rank > 1
+   * @param b an Operand with same type and rank as a.
+   * @param transposeA If `true`, a is transposed before multiplication.
+   * @param transposeB If `true`, b is transposed before multiplication.
+   * @param <T> the data type of the Operands
+   * @return an Operand of the same type as a and b where each inner-most
+   *     matrix is the product of the corresponding matrices in a and b.
+   *     This is the matrix product not an element-wise product.
+   * @throws java.lang.IllegalArgumentException If transposeA and adjointA,
+   *     or transposeB and adjointB are both set to `true`.
+   */
+  @Endpoint(name = "matmul")
+  public <T extends TNumber> Operand<T> matmul(
+      Operand<T> a, Operand<T> b, boolean transposeA, boolean transposeB) {
+    return matmul(a, b, transposeA, transposeB, false, false, false, false);
+  }
+
+  /**
+   * Multiplies matrix a by matrix b, producing a * b.
+   *
+   * <p>The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2
+   * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions
+   * specify matching batch size.
+   *
+   * <p>Both matrices must be of the same type. The supported types are: TFloat16,
+   * TFloat32, TFloat64, TInt32.
+   *
+   * <p>Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by
+   * setting one of the corresponding flags to true. These are false by default.
+   *
+   * <p>Note: This is matrix product, not element-wise product.
+   *
+   * <p>A simple 2-D tensor matrix multiplication:
+   *
+   * <pre>{@code
+   * Operand<TFloat64> a = tf.constant(new double[][] {
+   *         {-8.944851},
+   *         {4.1711287},
+   *         {-0.22380222}
+   *     });
+   * Operand<TFloat64> b = tf.constant(new double[][] {
+   *         {-14.276086, -12.433481, -2.2447076, -1.5775859, 1.8588694}
+   *     });
+   * Operand<TFloat64> result = fops.linalg.matmul(a, b);
+   * // result = {
+   * //     {127.69746,  111.21564,  20.078575,  14.111271,  -16.62731},
+   * //     {-59.547394, -51.861652, -9.362965,  -6.580314,    7.753584},
+   * //     {  3.1950197,  2.7826407, 0.50237054, 0.35306725, -0.4160191}
+   * //  }
+   *
+   * }</pre>
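+   *
+   * <p>For illustration, a hypothetical call that multiplies adjoint(x) by y while treating
+   * y as mostly zero (the operand names are assumed; the booleans follow the parameter order
+   * documented below):
+   *
+   * <pre>{@code
+   * Operand<TFloat32> c = fops.linalg.matmul(x, y, false, false, true, false, false, true);
+   * }</pre>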
+   *
+   * @param a an Operand of type TFloat16, TFloat32, TFloat64,
+   *     TInt32, with a rank > 1
+   * @param b an Operand with same type and rank as a.
+   * @param transposeA If true, a is transposed before multiplication.
+   * @param transposeB If true, b is transposed before multiplication.
+   * @param adjointA If true, a is conjugated and transposed before multiplication.
+   * @param adjointB If true, b is conjugated and transposed before multiplication.
+   * @param aIsSparse If true, a is treated as a sparse matrix. Notice, this does
+   *     not support {@link SparseTensor}, it just makes optimizations that assume most values
+   *     in a are zero.
+   * @param bIsSparse If true, b is treated as a sparse matrix. Notice, this does
+   *     not support {@link SparseTensor}, it just makes optimizations that assume most values
+   *     in b are zero.
+   * @param <T> the data type of the Operands
+   * @return an Operand of the same type as a and b where each inner-most
+   *     matrix is the product of the corresponding matrices in a and b.
+   *     This is the matrix product not an element-wise product.
+   * @throws java.lang.IllegalArgumentException If transposeA and adjointA,
+   *     or transposeB and adjointB are both set to `true`.
+   */
+  @SuppressWarnings("unchecked")
+  @Endpoint(name = "matmul")
+  public <T extends TNumber> Operand<T> matmul(
+      Operand<T> a,
+      Operand<T> b,
+      boolean transposeA,
+      boolean transposeB,
+      boolean adjointA,
+      boolean adjointB,
+      boolean aIsSparse,
+      boolean bIsSparse) {
+    Scope lscope = scope.withSubScope("MatMul");
+    if (transposeA && adjointA)
+      throw new IllegalArgumentException("Only one of transposeA and adjointA can be true.");
+    if (transposeB && adjointB)
+      throw new IllegalArgumentException("Only one of transposeB and adjointB can be true.");
+    if (!(TFloating.class.isAssignableFrom(a.type()) || a.type().equals(TInt32.class)))
+      throw new IllegalArgumentException(
+          String.format(
+              "Operand 'a' must be of type 'TBfloat16', 'TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
+              a.type().getSimpleName()));
+    if (!(TFloating.class.isAssignableFrom(b.type()) || b.type().equals(TInt32.class)))
+      throw new IllegalArgumentException(
+          String.format(
+              "Operand 'b' must be of type 'TBfloat16', 'TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
+              b.type().getSimpleName()));
+
+    Shape aShape = a.shape();
+    Shape bShape = b.shape();
+    if (aShape.numDimensions() != bShape.numDimensions())
+      throw new IllegalArgumentException(
+          String.format(
+              "Parameters 'a' and 'b' must have the same rank: found a rank = %d, b rank = %d",
+              aShape.numDimensions(), bShape.numDimensions()));
+    boolean outputMayHaveNonEmptyBatchShape =
+        aShape.numDimensions() == Shape.UNKNOWN_SIZE
+            || aShape.numDimensions() > 2
+            || bShape.numDimensions() == Shape.UNKNOWN_SIZE;
+
+    if ((!aIsSparse && !bIsSparse) && outputMayHaveNonEmptyBatchShape) {
+      // BatchMatmul does not support transpose, so we conjugate the matrix and
+      // use adjoint instead. Conj() is a noop for real matrices.
+      if (transposeA) {
+        a = Conj.create(scope, a);
+        adjointA = true;
+      }
+      if (transposeB) {
+        b = Conj.create(scope, b);
+        adjointB = true;
+      }
+      return BatchMatMul.create(
+          lscope, a, b, BatchMatMul.adjX(adjointA), BatchMatMul.adjY(adjointB));
+    }
+
+    // Neither matmul nor sparse_matmul support adjoint, so we conjugate
+    // the matrix and use transpose instead. Conj() is a noop for real
+    // matrices.
+    if (adjointA) {
+      a = Conj.create(scope, a);
+      transposeA = true;
+    }
+    if (adjointB) {
+      b = Conj.create(scope, b);
+      transposeB = true;
+    }
+
+    boolean useSparseMatmul = false;
+    if (aIsSparse || bIsSparse) {
+      useSparseMatmul =
+          (a.type().equals(TBfloat16.class) || a.type().equals(TFloat32.class))
+              && (b.type().equals(TBfloat16.class) || b.type().equals(TFloat32.class));
+    }
+    if ((a.type().equals(TBfloat16.class) || b.type().equals(TBfloat16.class))
+        && !a.type().equals(b.type())) useSparseMatmul = true;
+
+    if (useSparseMatmul) {
+      Operand<TFloat32> result =
+          SparseMatMul.create(
+              lscope,
+              a,
+              b,
+              SparseMatMul.transposeA(transposeA),
+              SparseMatMul.transposeB(transposeB),
+              SparseMatMul.aIsSparse(aIsSparse),
+              SparseMatMul.bIsSparse(bIsSparse));
+      if (a.type().equals(TFloat32.class)) return (Operand<T>) result;
+      else return Cast.create(scope, result, a.type());
+    }
+
+    return org.tensorflow.op.linalg.MatMul.create(
+        lscope,
+        a,
+        b,
+        org.tensorflow.op.linalg.MatMul.transposeA(transposeA),
+        org.tensorflow.op.linalg.MatMul.transposeB(transposeB));
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java
new file mode 100644
index 00000000000..f2c297ce032
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java
@@ -0,0 +1,60 @@
+package org.tensorflow.framework.op;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+class LinalgOpsTest {
+  private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
+
+  @Test
+  public void test2D() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        FrameworkOps fops = FrameworkOps.create(tf);
+        Operand<TFloat32> a = tf.constant(new float[][] {{3.7213619f}});
+        Operand<TFloat32> b = tf.constant(new float[][] {{8.153921f}});
+
+        Operand<TFloat32> ans = fops.linalg.matmul(a, b);
+        Operand<TFloat32> expected = tf.constant(new float[][] {{30.34369f}});
+        session.evaluate(expected, ans);
+
+        Operand<TFloat64> a64 =
+            tf.constant(new double[][] {{-8.944851}, {4.1711287}, {-0.22380222}});
+        Operand<TFloat64> b64 =
+            tf.constant(
+                new double[][] {{-14.276086, -12.433481, -2.2447076, -1.5775859, 1.8588694}});
+
+        Operand<TFloat64> ans64 = fops.linalg.matmul(a64, b64);
+        Operand<TFloat64> expected64 =
+            tf.constant(
+                new double[][] {
+                  {127.69746, 111.21564, 20.078575, 14.111271, -16.62731},
+                  {-59.547394, -51.861652, -9.362965, -6.580314, 7.753584},
+                  {3.1950197, 2.7826407, 0.50237054, 0.35306725, -0.4160191}
+                });
+        session.evaluate(expected64, ans64);
+
+        a64 =
+            tf.constant(
+                new double[][] {
+                  {-9.189821, -1.588742, -8.684379},
+                  {-10.953391, -8.473055, -6.8909864},
+                  {-11.712155, -6.6350083, -2.4441578},
+                  {1.4037079, -11.279383, 0.9129576},
+                  {0.11368857, 2.3792067, -11.218701},
+                });
+        b64 = tf.constant(new double[][] {{-4.933953}, {-12.692161}, {-10.192119}});
+        ans64 = fops.linalg.matmul(a64, b64);
+        expected64 =
+            tf.constant(
+                new double[][] {{154.01892}, {231.81863}, {166.91096}, {126.92895}, {83.58413}});
+        session.setEpsilon(1e-4f);
+        session.evaluate(expected64, ans64);
+      }
+  }
+}
From e83d26b6cb7e4d44616efb1df249a310cabaebe2 Mon Sep 17 00:00:00 2001
From: Jim Clarke
Date: Sun, 2 May 2021 18:47:53 -0400
Subject: [PATCH 14/31] add nn ops for
sigmoidCrossEntropyWithLogits, softmaxCrossEntropyWithLogits and sparseSoftmaxCrossEntropyWithLogits --- .../annotations/org/tensorflow/op/NnOps.java | 13 ++-- .../op/nn/SoftmaxCrossEntropyWithLogits.java | 57 +++++++++------ .../SparseSoftmaxCrossEntropyWithLogits.java | 62 +++++++++------- .../org/tensorflow/framework/op/NnOps.java | 15 ++-- .../op/nn/SigmoidCrossEntropyWithLogits.java | 3 +- .../SparseSoftmaxCrossEntropyWithLogits.java | 70 +++++++++++-------- .../tensorflow/framework/op/NnOpsTest.java | 68 ++++++++++++++++++ 7 files changed, 192 insertions(+), 96 deletions(-) create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java index 1cf8b910297..2bd4d13145f 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/NnOps.java @@ -1811,14 +1811,14 @@ public Softmax softmax(Operand logits) { /** * Computes softmax cross entropy cost and gradients to backpropagate. - *

* Inputs are the logits, not probabilities. * - * @param data type for {@code loss()} output + * @param data type for {@code loss} output * @param features batch_size x num_classes matrix * @param labels batch_size x num_classes matrix * The caller must ensure that each batch of labels represents a valid * probability distribution. + * @param data type for {@code SoftmaxCrossEntropyWithLogits} output and operands * @return a new instance of SoftmaxCrossEntropyWithLogits */ public SoftmaxCrossEntropyWithLogits softmaxCrossEntropyWithLogits( @@ -2011,18 +2011,17 @@ public SpaceToDepth spaceToDepth(Operand input, Long blo /** * Computes softmax cross entropy cost and gradients to backpropagate. - *

- * Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept + * Unlike {@code SoftmaxCrossEntropyWithLogits}, this operation does not accept * a matrix of label probabilities, but rather a single label per row * of features. This label is considered to have probability 1.0 for the * given row. - *

- * Inputs are the logits, not probabilities. + *

Inputs are the logits, not probabilities. * - * @param data type for {@code loss()} output + * @param data type for {@code loss} output * @param features batch_size x num_classes matrix * @param labels batch_size vector with values in [0, num_classes). * This is the label for the given minibatch entry. + * @param data type for {@code SparseSoftmaxCrossEntropyWithLogits} output and operands * @return a new instance of SparseSoftmaxCrossEntropyWithLogits */ public SparseSoftmaxCrossEntropyWithLogits sparseSoftmaxCrossEntropyWithLogits( diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java index 5d3ab3c1100..d6eed5cbe28 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -29,57 +29,68 @@ /** * Computes softmax cross entropy cost and gradients to backpropagate. - *

* Inputs are the logits, not probabilities. - * - * @param data type for {@code loss()} output + * + * @param data type for {@code loss} output */ -@Operator(group = "nn") +@Operator( + group = "nn" +) public final class SoftmaxCrossEntropyWithLogits extends RawOp { - + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "SoftmaxCrossEntropyWithLogits"; + + private Output loss; + + private Output backprop; + + private SoftmaxCrossEntropyWithLogits(Operation operation) { + super(operation); + int outputIdx = 0; + loss = operation.output(outputIdx++); + backprop = operation.output(outputIdx++); + } + /** * Factory method to create a class wrapping a new SoftmaxCrossEntropyWithLogits operation. - * + * * @param scope current scope * @param features batch_size x num_classes matrix * @param labels batch_size x num_classes matrix * The caller must ensure that each batch of labels represents a valid * probability distribution. + * @param data type for {@code SoftmaxCrossEntropyWithLogits} output and operands * @return a new instance of SoftmaxCrossEntropyWithLogits */ - @Endpoint(describeByClass = true) - public static SoftmaxCrossEntropyWithLogits create(Scope scope, Operand features, Operand labels) { + @Endpoint( + describeByClass = true + ) + public static SoftmaxCrossEntropyWithLogits create(Scope scope, + Operand features, Operand labels) { OperationBuilder opBuilder = scope.env().opBuilder("SoftmaxCrossEntropyWithLogits", scope.makeOpName("SoftmaxCrossEntropyWithLogits")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(labels.asOutput()); opBuilder = scope.apply(opBuilder); return new SoftmaxCrossEntropyWithLogits<>(opBuilder.build()); } - + /** + * Gets loss. * Per example loss (batch_size vector). + * @return loss. */ public Output loss() { return loss; } - + /** + * Gets backprop. * backpropagated gradients (batch_size x num_classes matrix). + * @return backprop. */ public Output backprop() { return backprop; } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "SoftmaxCrossEntropyWithLogits"; - - private Output loss; - private Output backprop; - - private SoftmaxCrossEntropyWithLogits(Operation operation) { - super(operation); - int outputIdx = 0; - loss = operation.output(outputIdx++); - backprop = operation.output(outputIdx); - } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java index 794beab4ded..26498cdce7a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/nn/SparseSoftmaxCrossEntropyWithLogits.java @@ -29,61 +29,71 @@ /** * Computes softmax cross entropy cost and gradients to backpropagate. - *

- * Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept + * Unlike {@code SoftmaxCrossEntropyWithLogits}, this operation does not accept * a matrix of label probabilities, but rather a single label per row * of features. This label is considered to have probability 1.0 for the * given row. - *

- * Inputs are the logits, not probabilities. - * - * @param data type for {@code loss()} output + *

Inputs are the logits, not probabilities. + * + * @param data type for {@code loss} output */ -@Operator(group = "nn") +@Operator( + group = "nn" +) public final class SparseSoftmaxCrossEntropyWithLogits extends RawOp { - + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "SparseSoftmaxCrossEntropyWithLogits"; + + private Output loss; + + private Output backprop; + + private SparseSoftmaxCrossEntropyWithLogits(Operation operation) { + super(operation); + int outputIdx = 0; + loss = operation.output(outputIdx++); + backprop = operation.output(outputIdx++); + } + /** * Factory method to create a class wrapping a new SparseSoftmaxCrossEntropyWithLogits operation. - * + * * @param scope current scope * @param features batch_size x num_classes matrix * @param labels batch_size vector with values in [0, num_classes). * This is the label for the given minibatch entry. + * @param data type for {@code SparseSoftmaxCrossEntropyWithLogits} output and operands * @return a new instance of SparseSoftmaxCrossEntropyWithLogits */ - @Endpoint(describeByClass = true) - public static SparseSoftmaxCrossEntropyWithLogits create(Scope scope, Operand features, Operand labels) { + @Endpoint( + describeByClass = true + ) + public static SparseSoftmaxCrossEntropyWithLogits create(Scope scope, + Operand features, Operand labels) { OperationBuilder opBuilder = scope.env().opBuilder("SparseSoftmaxCrossEntropyWithLogits", scope.makeOpName("SparseSoftmaxCrossEntropyWithLogits")); opBuilder.addInput(features.asOutput()); opBuilder.addInput(labels.asOutput()); opBuilder = scope.apply(opBuilder); return new SparseSoftmaxCrossEntropyWithLogits<>(opBuilder.build()); } - + /** + * Gets loss. * Per example loss (batch_size vector). + * @return loss. */ public Output loss() { return loss; } - + /** + * Gets backprop. * backpropagated gradients (batch_size x num_classes matrix). + * @return backprop. */ public Output backprop() { return backprop; } - - /** The name of this op, as known by TensorFlow core engine */ - public static final String OP_NAME = "SparseSoftmaxCrossEntropyWithLogits"; - - private Output loss; - private Output backprop; - - private SparseSoftmaxCrossEntropyWithLogits(Operation operation) { - super(operation); - int outputIdx = 0; - loss = operation.output(outputIdx++); - backprop = operation.output(outputIdx); - } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java index 0fea3743d95..4f5120a3dbf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java @@ -87,7 +87,7 @@ public class NnOps { * @param logits the logits of type float32 or float64 * @param the type of labels and logits * @return the component-wise logistic losses. - * @throws IllegalArgumentException if logits' and labels' do not have the same shape + * @throws IllegalArgumentException if logits and labels do not have the same shape */ public Operand sigmoidCrossEntropyWithLogits( Operand labels, Operand logits) { @@ -139,7 +139,6 @@ public Operand sigmoidCrossEntropyWithLogits( * @return the softmax cross entropy loss. Its type is the same as {@code logits} and its shape is * the same as {@code labels} except that it does not have the last dimension of {@code * labels}. 
-   *
-   */
+   */
   public Operand softmaxCrossEntropyWithLogits(
       Operand labels, Operand logits, int axis) {
@@ -181,14 +180,14 @@ public Operand softmaxCrossEntropyWith
   * @param logits Per-label activations (typically a linear output) of shape {@code [d_0, d_1, ...,
   * d_{r-1}, numClasses]} and dataType of {@code TFloat16}, {@code TFloat32}, or {@code
   * TFloat64}. These activation energies are interpreted as unnormalized log probabilities.
-   * @param The data type for the labels
-   * @param The data type for the logits and loss
+   * @param the data type for the labels
+   * @param the data type for the loss and logits.
+   * @return the loss
-   * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if the rank
-   * of the labels is not equal to the rank of the logits minus one.
+   * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if
+   * the rank of the labels is not equal to the rank of the logits minus one.
   */
-  public Operand sparseSoftmaxCrossEntropyWithLogits(
-      Operand labels, Operand logits) {
+  public Operand sparseSoftmaxCrossEntropyWithLogits(
+      Operand labels, Operand logits) {
    return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits(
        scope, labels, logits);
  }
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java
index fc3f7739363..432e1b47a3f 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SigmoidCrossEntropyWithLogits.java
@@ -26,8 +26,7 @@ public class SigmoidCrossEntropyWithLogits {
   * independent and not mutually exclusive. For instance, one could perform multilabel
   * classification where a picture can contain both an elephant and a dog at the same time.
   *
-   *

For brevity, let {@code x = logits}, {@code z = labels}. The logistic loss in - * pseudo-code is + *

For brevity, let {@code x = logits}, {@code z = labels}. The logistic loss in pseudo-code is * *

    * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
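The pseudo-code above is not evaluated directly, since exp(-x) overflows for large negative x;
the standard numerically stable rearrangement is max(x, 0) - x * z + log(1 + exp(-|x|)). A
minimal plain-Java sketch of that rearrangement, shown for illustration only (the helper name
is hypothetical, and the framework op composes TensorFlow ops rather than java.lang.Math calls):

    // Stable per-element logistic loss for a logit x and a label z in [0, 1]:
    //   max(x, 0) - x * z + log(1 + exp(-|x|))
    static double sigmoidCrossEntropy(double z, double x) {
      return Math.max(x, 0) - x * z + Math.log1p(Math.exp(-Math.abs(x)));
    }

For x = 2 and z = 1 this evaluates to log(1 + e^-2), approximately 0.126928, which agrees with
the expected values in the NnOpsTest added by this patch.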
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java
index 0b2d29d6092..553adf90aad 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SparseSoftmaxCrossEntropyWithLogits.java
@@ -14,7 +14,11 @@
 import org.tensorflow.types.TBfloat16;
 import org.tensorflow.types.TFloat16;
 import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
 import org.tensorflow.types.TInt32;
+import org.tensorflow.types.TInt64;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TIntegral;
 import org.tensorflow.types.family.TNumber;
 
 import java.util.ArrayList;
@@ -34,39 +38,37 @@ public class SparseSoftmaxCrossEntropyWithLogits {
    * 

NOTE: * *

For this operation, the probability of a given label is considered exclusive. That is, soft - * classes are not allowed, and the {@code labels} vector must provide a single specific - * index for the true class for each row of {@code logits} (each minibatch entry). For soft - * softmax classification with a probability distribution for each entry, {@link + * classes are not allowed, and the {@code labels} vector must provide a single specific index for + * the true class for each row of {@code logits} (each minibatch entry). For soft softmax + * classification with a probability distribution for each entry, {@link * org.tensorflow.op.NnOps#softmaxCrossEntropyWithLogits}. * *

WARNING: * - *

This op expects unscaled logits, since it performs a {@code softmax} on {@code logits - * } internally for efficiency. Do not call this op with the output of {@code softmax}, - * as it will produce incorrect results. + *
+   * <p>This op expects unscaled logits, since it performs a {@code softmax} on {@code logits}
+   * internally for efficiency. Do not call this op with the output of {@code softmax}, as it will
+   * produce incorrect results.
+   *
-   *

A common use case is to have logits of shape {@code [batchSize, numClasses]} and have - * labels of shape {@code [batchSize]}, but higher dimensions are supported, in which case - * the {@code dim}-th dimension is assumed to be of size {@code numClasses}. {@code - * logits} must have the {@code dataType} of {@code TFloat16}, {@code TFloat32} - * , or {@code TFloat64}, and {@code labels} must have the dtype of {@code TInt32} - * or {@code TInt64}. + *
+   * <p>A common use case is to have logits of shape {@code [batchSize, numClasses]} and have labels
+   * of shape {@code [batchSize]}, but higher dimensions are supported, in which case the {@code
+   * dim}-th dimension is assumed to be of size {@code numClasses}. {@code logits} must have the
+   * {@code dataType} of {@code TFloat16}, {@code TFloat32}, or {@code TFloat64}, and {@code
+   * labels} must have the dtype of {@code TInt32} or {@code TInt64}.
+   *
   * @param scope current scope
-   * @param labels {@code Tensor} of shape {@code [d_0, d_1, ..., d_{r-1}]} (where {@code r
-   * } is rank of {@code labels} and result) and the dataType is {@code TInt32}
-   * or {@code TInt64}. Each entry in {@code labels} must be an index in {@code [0,
-   * numClasses)}. Other values will raise an exception when this op is run on CPU, and
-   * return {@code NaN} for corresponding loss and gradient rows on GPU.
+   * @param labels {@code Tensor} of shape {@code [d_0, d_1, ..., d_{r-1}]} (where {@code r} is
+   * rank of {@code labels} and result) and the dataType is {@code TInt32} or {@code TInt64}.
+   * Each entry in {@code labels} must be an index in {@code [0, numClasses)}. Other values will
+   * raise an exception when this op is run on CPU, and return {@code NaN} for corresponding
+   * loss and gradient rows on GPU.
+   * @param logits Per-label activations (typically a linear output) of shape {@code [d_0, d_1, ...,
-   * d_{r-1}, numClasses]} and dataType of {@code TFloat16}, {@code TFloat32},
-   * or {@code TFloat64}. These activation energies are interpreted as unnormalized log
-   * probabilities.
+   * d_{r-1}, numClasses]} and dataType of {@code TFloat16}, {@code TFloat32}, or {@code
+   * TFloat64}. These activation energies are interpreted as unnormalized log probabilities.
   * @param <T> the data type for the labels
   * @param <U> the data type for the loss and logits.
   * @return the loss
-   * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if the rank
-   * of the labels is not equal to the rank of the logits minus one.
+   * @throws IllegalArgumentException If logits are scalars (need to have {@code rank >= 1}) or if
+   * the rank of the labels is not equal to the rank of the logits minus one.
*/ @SuppressWarnings("unchecked") @Endpoint(name = "sparseSoftmaxCrossEntropyWithLogits") @@ -74,15 +76,23 @@ public class SparseSoftmaxCrossEntropyWithLogits { Operand sparseSoftmaxCrossEntropyWithLogits( Scope scope, Operand labels, Operand logits) { scope = scope.withSubScope("SparseSoftmaxCrossEntropyWithLogits"); - Operand preciseLogits; + Operand preciseLogits; if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { preciseLogits = Cast.create(scope, logits, TFloat32.class); + } else if (TFloating.class.isAssignableFrom(logits.type())) { + preciseLogits = (Operand) logits; } else { - preciseLogits = logits; + preciseLogits = Cast.create(scope, logits, TFloat64.class); } - Shape labelsStaticShape = labels.shape(); + Operand iLabels; + if (TIntegral.class.isAssignableFrom(labels.type())) { + iLabels = (Operand) labels; + } else { + iLabels = Cast.create(scope, labels, TInt64.class); + } + Shape labelsStaticShape = iLabels.shape(); org.tensorflow.op.core.Shape labelsShape = - org.tensorflow.op.core.Shape.create(scope, labels); + org.tensorflow.op.core.Shape.create(scope, iLabels); Shape logitsShape = logits.shape(); Shape logitsShortened = logitsShape.take(logitsShape.numDimensions() - 1); @@ -113,7 +123,7 @@ Operand sparseSoftmaxCrossEntropyWithLogits( if (logitsShape.numDimensions() == 2) { org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits smax = org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create( - scope, preciseLogits, labels); + scope, preciseLogits, iLabels); Operand cost = smax.loss(); if (cost.type() != logits.type()) { return Cast.create(scope, cost, logits.type()); @@ -131,7 +141,7 @@ Operand sparseSoftmaxCrossEntropyWithLogits( scope, Equal.create( scope, - org.tensorflow.op.core.Shape.create(scope, labels), + org.tensorflow.op.core.Shape.create(scope, iLabels), Shapes.take( scope, org.tensorflow.op.core.Shape.create(scope, logits), @@ -148,12 +158,12 @@ Operand sparseSoftmaxCrossEntropyWithLogits( long numClassses = logitsShape.size(-1); preciseLogits = Reshape.create(scope, preciseLogits, Constant.arrayOf(scope, -1L, numClassses)); - labels = Reshape.create(scope, labels, Constant.scalarOf(scope, -1)); + iLabels = Reshape.create(scope, iLabels, Constant.scalarOf(scope, -1)); scope.withControlDependencies(shapeChecks); // call raw op org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits smax = org.tensorflow.op.nn.SparseSoftmaxCrossEntropyWithLogits.create( - scope, preciseLogits, labels); + scope, preciseLogits, iLabels); Operand cost = smax.loss(); cost = Reshape.create(scope, cost, labelsShape); if (cost.type() != logits.type()) { diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java new file mode 100644 index 00000000000..0436fdd57cf --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java @@ -0,0 +1,68 @@ +package org.tensorflow.framework.op; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +class NnOpsTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testSigmoidCrossEntropyWithLogits() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { 
+ Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + float[] x = new float[] {-100, -2, -2, 0, 2, 2, 2, 100}; + float[] y = new float[] {0, 0, 1, 0, 0, 1, 0.5f, 1}; + + Operand logits = tf.constant(x); + Operand targets = tf.constant(y); + Operand loss = fops.nn.sigmoidCrossEntropyWithLogits(targets, logits); + Operand expected = + tf.constant( + new float[] { + 0.f, 0.126928f, 2.126928f, 0.6931472f, + 2.126928f, 0.126928f, 1.126928f, 0.f + }); + session.evaluate(expected, loss); + } + } + + @Test + public void testSoftmaxCrossEntropyWithLogits() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + float[] x = new float[] {-100, -2, -2, 0, 2, 2, 2, 100}; + float[] y = new float[] {0, 0, 1, 0, 0, 1, 0.5f, 1}; + + Operand logits = tf.constant(x); + Operand targets = tf.constant(y); + Operand loss = fops.nn.softmaxCrossEntropyWithLogits(targets, logits, 0); + + session.evaluate(249.0f, loss); + } + } + + @Test + public void testSparseSoftmaxCrossEntropyWithLogits() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + float[][] x = new float[][] {{0, 0}}; + int[] y = new int[] {0}; + + Operand logits = tf.constant(x); + Operand labels = tf.constant(y); + Operand loss = fops.nn.sparseSoftmaxCrossEntropyWithLogits(labels, logits); + + session.evaluate(0.69314718f, loss); + } + } +} From e4e65f2a09ffb980cf0c6881f90369f9d87633f5 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 2 May 2021 18:48:53 -0400 Subject: [PATCH 15/31] Moved SetOps to FrameworkOps --- .../org/tensorflow/framework/op/{SetsOps.java => SetOps.java} | 4 ++-- .../tensorflow/framework/{metrics/impl => op}/SetOpsTest.java | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) rename tensorflow-framework/src/main/java/org/tensorflow/framework/op/{SetsOps.java => SetOps.java} (98%) rename tensorflow-framework/src/test/java/org/tensorflow/framework/{metrics/impl => op}/SetOpsTest.java (97%) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetOps.java similarity index 98% rename from tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetOps.java index d7833cdbb06..f76947018b5 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetOps.java @@ -24,7 +24,7 @@ import org.tensorflow.types.family.TNumber; /** Implementation of set operations */ -public class SetsOps { +public class SetOps { private final Scope scope; @@ -35,7 +35,7 @@ public class SetsOps { * * @param frameworkOps the TensorFLow framework Ops */ - SetsOps(FrameworkOps frameworkOps) { + SetOps(FrameworkOps frameworkOps) { this.scope = frameworkOps.scope(); this.frameworkOps = frameworkOps; } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java similarity index 97% rename from tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetOpsTest.java rename to 
tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java index e10f016bd94..7dee866abf2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java @@ -1,9 +1,7 @@ -package org.tensorflow.framework.metrics.impl; +package org.tensorflow.framework.op; import org.junit.jupiter.api.Test; import org.tensorflow.Operand; -import org.tensorflow.framework.op.FrameworkOps; -import org.tensorflow.framework.op.SetsOps; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; From a2ed723aa7fa040e58439e4e17b1799747a68001 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 2 May 2021 18:51:05 -0400 Subject: [PATCH 16/31] Added tensordot and reduceLogSumExp --- .../org/tensorflow/framework/op/MathOps.java | 796 +++++++++++++++++- .../tensorflow/framework/op/MathOpsTest.java | 90 +- 2 files changed, 874 insertions(+), 12 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java index 36f5b692cab..4c2210feb9c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java @@ -14,37 +14,59 @@ =======================================================================*/ package org.tensorflow.framework.op; +import org.tensorflow.Graph; import org.tensorflow.Operand; +import org.tensorflow.Session; import org.tensorflow.framework.losses.impl.LossTuple; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; import org.tensorflow.op.core.AssertThat; +import org.tensorflow.op.core.Concat; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Gather; import org.tensorflow.op.core.Identity; import org.tensorflow.op.core.OnesLike; import org.tensorflow.op.core.Range; import org.tensorflow.op.core.Rank; import org.tensorflow.op.core.ReduceAll; import org.tensorflow.op.core.ReduceMax; +import org.tensorflow.op.core.ReduceProd; import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.op.core.Reshape; import org.tensorflow.op.core.ScatterNd; +import org.tensorflow.op.core.Select; +import org.tensorflow.op.core.SetDiff1d; +import org.tensorflow.op.core.Slice; import org.tensorflow.op.core.Squeeze; import org.tensorflow.op.core.Stack; -import org.tensorflow.op.core.Zeros; +import org.tensorflow.op.core.StopGradient; +import org.tensorflow.op.core.ZerosLike; import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.linalg.Transpose; import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Exp; import org.tensorflow.op.math.GreaterEqual; +import org.tensorflow.op.math.IsFinite; import org.tensorflow.op.math.Less; +import org.tensorflow.op.math.Log; import org.tensorflow.op.math.Maximum; import org.tensorflow.op.math.Mul; import org.tensorflow.op.math.Rsqrt; import org.tensorflow.op.math.Square; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat16; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.List; +import 
java.util.stream.Collectors;
 
 public class MathOps {
   private final Scope scope;
@@ -123,7 +145,7 @@ public Operand l2Normalize(Operand x, int[] axis) {
   * predictions}.
   */
  public Operand confusionMatrix(Operand labels, Operand predictions) {
-    return confusionMatrix(labels, predictions, null, null, labels.type());
+    return confusionMatrix(labels, predictions, null, null);
  }
 
  /**
@@ -167,7 +189,7 @@ public Operand confusionMatrix(Operand labels, Operand
   */
  public Operand confusionMatrix(
      Operand labels, Operand predictions, Operand weights) {
-    return confusionMatrix(labels, predictions, weights, null, labels.type());
+    return confusionMatrix(labels, predictions, weights, null);
  }
 
  /**
@@ -204,7 +226,6 @@ public Operand confusionMatrix(
   * @param weights An optional Operand whose shape matches {@code predictions}.
   * @param numClasses The possible number of labels the classification task can have. If this value
   * is null, it will be calculated using both predictions and labels.
-   * @param type Data type of the confusion matrix.
   * @param <T> Data type of the confusion matrix.
   * @return An Operand of type {@code type} with shape {@code [n, n]} representing the confusion
   * matrix, where {@code n} is the number of possible labels in the classification task.
@@ -213,11 +234,7 @@ public Operand confusionMatrix(
   * predictions}.
   */
  public Operand confusionMatrix(
-      Operand labels,
-      Operand predictions,
-      Operand weights,
-      Operand numClasses,
-      Class type) {
+      Operand labels, Operand predictions, Operand weights, Operand numClasses) {
    Scope lScope = scope.withSubScope("confusionMatrix");
    LossTuple tuple = removeSqueezableDimensions(labels, predictions, 0);
    Operand lLabels = Cast.create(lScope, tuple.getLabels(), TInt64.class);
@@ -293,7 +310,8 @@ public Operand confusionMatrix(
    Operand indices = Stack.create(lScope, Arrays.asList(lLabels, lPredictions), Stack.axis(1L));
    Operand values = weights == null ? OnesLike.create(lScope, predictions) : weights;
-    Operand zeroMatrix = Zeros.create(lScope, Cast.create(lScope, shape, TInt32.class), type);
+    /// Operand zeroMatrix = Zeros.create(lScope, Cast.create(lScope, shape, TInt32.class),
+    // type);
 
    return ScatterNd.create(lScope, indices, values, shape);
  }
@@ -317,7 +335,7 @@ public LossTuple removeSqueezableDimensions(
    int labelsRank = labelsShape.numDimensions();
 
    if (predictionsRank != Shape.UNKNOWN_SIZE || labelsRank != Shape.UNKNOWN_SIZE) {
-      // Use static rank.
+      // Use rank.
      int rankDiff = predictionsRank - labelsRank;
      if (rankDiff == expectedRankDiff + 1 && Shape.isCompatible(predictionsShape.size(-1), 1)) {
        predictions = Squeeze.create(lScope, predictions);
@@ -350,6 +368,13 @@ public LossTuple removeSqueezableDimensions(
    return new LossTuple<>(labels, predictions);
  }
 
+  /**
+   * Creates an Operand that has all axes contained in the Operand's shape.
+   *
+   * @param op the Operand
+   * @param <T> the data type for the Operand
+   * @return an Operand that has all axes contained in the Operand's shape.
+   */
  public Operand allAxes(Operand op) {
    int rank = op.shape().numDimensions();
    if (rank != Shape.UNKNOWN_SIZE) {
@@ -363,4 +388,753 @@ public Operand allAxes(Operand op) {
          scope, Constant.scalarOf(scope, 0), Rank.create(scope, op), Constant.scalarOf(scope, 1));
    }
  }
+
+  /**
+   * Transpose and reshape the input for contraction op.
+   *
+   * <p>This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul` using
+   * `array_ops.transpose` and `array_ops.reshape`. The method takes a tensor and performs the
+   * correct transpose and reshape operation for a given set of indices. It returns the reshaped
+   * tensor as well as a list of indices necessary to reshape the tensor again after matrix
+   * multiplication.
+   *
+   * @param <T> the type of Operand
+   * @param a the Tensor
+   * @param axis unique indices specifying valid axes of `a`.
+   * @param flipped whether to flip the dimensions or not
+   * @return A tuple (reshapedA, freeDims, freeDimsStatic) where reshapedA is `a` reshaped to allow
+   *     contraction via matmul, `freeDims` is a TInt32 Operand, depending on whether the shape of a
+   *     is fully specified, and freeDimsStatic is either a list of integers and null values, or
+   *     null, representing the inferred shape of the free dimensions
+   */
+  private Object[] tensordotReshape(
+      Operand a, Operand axis, boolean flipped) {
+    Shape aShape = a.shape();
+
+    if (!aShape.hasUnknownDimension()) { // calculate using values
+      long[] aShapeDims = aShape.asArray();
+      if (aShapeDims == null) aShapeDims = new long[0];
+      long[] aDimsIndex = new long[aShapeDims.length];
+      for (int i = 0; i < aDimsIndex.length; i++) aDimsIndex[i] = i;
+
+      // get int array from axis Operand
+      int[] iAxes = getIntArray(axis);
+      // Convert negative axes to positive
+      for (int i = 0; i < iAxes.length; i++)
+        iAxes[i] = iAxes[i] >= 0 ? iAxes[i] : Math.floorMod(iAxes[i], iAxes.length);
+
+      // convert integer axis to long axis
+      long[] lAxes = Arrays.stream(iAxes).mapToLong(i -> i).toArray();
+
+      // create list of the axes, dims, and free axes
+      List axesList = Arrays.stream(lAxes).boxed().collect(Collectors.toList());
+      List freeList = Arrays.stream(aDimsIndex).boxed().collect(Collectors.toList());
+      freeList.removeAll(axesList);
+
+      // create array of free dims
+      long[] free = freeList.stream().mapToLong(i -> i).toArray();
+      long[] freeDims = new long[free.length];
+      for (int i = 0; i < free.length; i++) freeDims[i] = aShapeDims[(int) free[i]];
+
+      // Calculate the free dim by doing a reduce prod
+      long prodFree = 1;
+      for (long i : freeDims) {
+        prodFree *= i;
+      }
+
+      // calculate the used dims by doing a reduce prod
+      long prodAxis = 1;
+      for (long i : lAxes) {
+        prodAxis *= aShapeDims[(int) i];
+      }
+
+      // setup the permutations array for the transpose
+      long[] perm = new long[freeDims.length + lAxes.length];
+      Shape newShape;
+      if (flipped) {
+        System.arraycopy(lAxes, 0, perm, 0, lAxes.length);
+        System.arraycopy(free, 0, perm, lAxes.length, free.length);
+        newShape = Shape.of(prodAxis, prodFree);
+      } else {
+        System.arraycopy(free, 0, perm, 0, free.length);
+        System.arraycopy(lAxes, 0, perm, freeDims.length, lAxes.length);
+        newShape = Shape.of(prodFree, prodAxis);
+      }
+
+      Operand aTrans;
+      long[] arrange = new long[lAxes.length];
+      for (int i = 0; i < arrange.length; i++) arrange[i] = i;
+
+      // if the permutation is not equal to the natural order of the dims, then do a transpose
+      if (!Arrays.equals(perm, arrange)) {
+        aTrans = Transpose.create(scope, a, Constant.vectorOf(scope, perm));
+      } else {
+        aTrans = a;
+      }
+
+      // reshape the final result to the new Shape, if necessary
+      Operand aReshaped =
+          aTrans.asOutput().shape().equals(newShape)
+              ?
aTrans + : Reshape.create(scope, aTrans, Constant.vectorOf(scope, newShape.asArray())); + // return a tuple for the reshaped Operand, and Operand for the free dimensions, and a long + // array for the free dimensions + return new Object[] {aReshaped, Constant.vectorOf(scope, freeDims), freeDims}; + + } else { // calculate dynamically + + long[] freeDimsStatic = null; + Operand one = Constant.scalarOf(scope, 1); + Operand minusOne = Constant.scalarOf(scope, -1); + Operand zero = Constant.scalarOf(scope, 0); + org.tensorflow.op.core.Shape tShape = org.tensorflow.op.core.Shape.create(scope, a); + Operand axesT; + Operand freeT; + if (aShape.numDimensions() + != Shape.UNKNOWN_SIZE) { // we know the rank, but there are unknown dimensions + long[] aShapeDims = aShape.asArray(); + if (aShapeDims == null) aShapeDims = new long[0]; + + // get int array from axis Operand + int[] iAxes = getIntArray(axis); + // Convert negative axes to positive + for (int i = 0; i < iAxes.length; i++) + iAxes[i] = iAxes[i] >= 0 ? iAxes[i] : Math.floorMod(iAxes[i], iAxes.length); + + // convert integer axis to long axis + long[] lAxes = Arrays.stream(iAxes).mapToLong(i -> i).toArray(); + + // create list of the axes, dims, and free axes + List axesList = Arrays.stream(lAxes).boxed().collect(Collectors.toList()); + List dimsList = Arrays.stream(aShapeDims).boxed().collect(Collectors.toList()); + List freeList = new ArrayList<>(axesList); + freeList.removeAll(dimsList); + + // create array of free dims + long[] freeDims = freeList.stream().mapToLong(i -> i).toArray(); + freeDimsStatic = freeDims; + + axesT = Constant.vectorOf(scope, iAxes); + freeT = Cast.create(scope, Constant.vectorOf(scope, freeDims), TInt32.class); + + } else { // we don't know the rank yet + Rank rank = Rank.create(scope, a); + + // convert axis to positive + axesT = + Select.create( + scope, + GreaterEqual.create(scope, axis, Constant.scalarOf(scope, 0)), + axis, + Add.create(scope, axis, rank)); + + SetDiff1d diff = + SetDiff1d.create( + scope, Range.create(scope, Constant.scalarOf(scope, 0), rank, one), axesT); + freeT = diff.out(); + } + Operand freeDims = Gather.create(scope, tShape, freeT, zero); + Operand axesDims = Gather.create(scope, tShape, axesT, zero); + Operand prodFreeDims = ReduceProd.create(scope, freeDims, minusOne); + Operand prodAxesDims = ReduceProd.create(scope, axesDims, minusOne); + Operand perm; + Operand newShape; + if (flipped) { + perm = Concat.create(scope, Arrays.asList(axesT, freeT), zero); + newShape = Stack.create(scope, Arrays.asList(prodAxesDims, prodFreeDims)); + } else { + perm = Concat.create(scope, Arrays.asList(freeT, axesT), zero); + newShape = Stack.create(scope, Arrays.asList(prodFreeDims, prodAxesDims)); + } + Operand aReshaped = Reshape.create(scope, Transpose.create(scope, a, perm), newShape); + return new Object[] {aReshaped, freeDims, freeDimsStatic}; + } + } + + /** + * Gets an int array from an Operand<TInt32> operand. 
+ * + * @param axes the Operand to fetch the values + * @return the int array from an Operand<TInt32> + */ + private int[] getIntArray(Operand axes) { + List result = new ArrayList<>(); + if (scope.env().isEager()) { + axes.asTensor().scalars().forEach(s -> result.add(s.getInt())); + } else { + try (Session session = new Session((Graph) scope.env()); + TInt32 tensor = (TInt32) session.runner().fetch(axes).run().get(0)) { + tensor.scalars().forEach(s -> result.add(s.getInt())); + } + } + return result.stream().mapToInt(i -> i).toArray(); + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. + * + * @param a the Operand to analyze + * @param axis the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings("unchecked") + private Operand[] tensordotAxes(Operand a, int axis) { + Shape aShape = a.asOutput().shape(); + if (axis < 0) { + throw new IllegalArgumentException("'axis' must be at least 0."); + } + int rank = aShape.numDimensions(); + Operand[] result = new Operand[2]; + if (rank != Shape.UNKNOWN_SIZE) { + if (axis > rank) { + throw new IllegalArgumentException( + String.format( + "'axis' must not be larger than the number of dimensions of tensor %s.", rank)); + } + int min = rank - axis; + int postRange = rank - min; + int[] postAxis = new int[postRange]; + for (int i = 0; i < postRange; i++) postAxis[i] = i + min; + + int[] preAxis = new int[axis]; + for (int i = 0; i < axis; i++) preAxis[i] = i; + + result[0] = Constant.vectorOf(scope, postAxis); + result[1] = Constant.vectorOf(scope, preAxis); + } else { + Rank rankT = Rank.create(scope, a); + Constant axisT = Constant.scalarOf(scope, axis); + Constant one = Constant.scalarOf(scope, 1); + Constant zero = Constant.scalarOf(scope, 0); + AssertThat assertion = + AssertThat.create( + scope, + Less.create(scope, axisT, rankT), + Arrays.asList( + Constant.scalarOf( + scope, "'axes' must not be larger than the number of dimensions of tensor "), + rankT)); + Scope scope1 = scope.withControlDependencies(Collections.singletonList(assertion)); + result[0] = Range.create(scope1, Sub.create(scope, rankT, axisT), rankT, one); + result[1] = Range.create(scope1, zero, axisT, one); + } + return result; + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. + * + * @param a the Operand to analyze + * @param axes the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings({"unchecked", "unused"}) + private Operand[] tensordotAxes(Operand a, int[] axes) { + if (axes.length != 2) + throw new IllegalArgumentException( + "'axes' must have length 1 or 2, provided with " + axes.length); + int[] aAxis = new int[] {axes[0]}; + int[] bAxis = new int[] {axes[1]}; + Operand[] result = new Operand[2]; + result[0] = Constant.vectorOf(scope, aAxis); + result[1] = Constant.vectorOf(scope, bAxis); + + return result; + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. 
+ * + * @param a the Operand to analyze + * @param axes the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings({"unchecked", "unused"}) + private Operand[] tensordotAxes(Operand a, int[][] axes) { + if (axes.length != 2) + throw new IllegalArgumentException( + "'axes' must have length 1 or 2, provided with " + axes.length); + int[] aAxis = axes[0]; + int[] bAxis = axes[1]; + if (aAxis.length != bAxis.length) + throw new IllegalArgumentException( + String.format( + "Different number of contraction axes 'a' and 'b', %d != %d", + aAxis.length, bAxis.length)); + Operand[] result = new Operand[2]; + result[0] = Constant.vectorOf(scope, aAxis); + result[1] = Constant.vectorOf(scope, bAxis); + return result; + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. + * + * @param a the Operand to analyze + * @param axes the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings({"unchecked", "unused"}) + private Operand[] tensordotAxes(Operand a, Operand axes) { + + Constant one = Constant.scalarOf(scope, 1); + Constant zero = Constant.scalarOf(scope, 0); + Operand[] result = new Operand[2]; + result[0] = + Slice.create( + scope, + axes, + Cast.create(scope, zero, TInt32.class), + Cast.create(scope, one, TInt32.class)); + result[1] = + Slice.create( + scope, + axes, + Cast.create(scope, one, TInt32.class), + Cast.create(scope, one, TInt32.class)); + return result; + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *

+ * Tensordot (also known as tensor contraction) sums the product of elements
+ * from a and b over the indices specified by
+ * a_axes and b_axes. The lists
+ * a_axes and b_axes specify those pairs of axes
+ * along which to contract the tensors. The axis a_axes[i] of
+ * a must have the same dimension as axis
+ * b_axes[i] of b for all i in
+ * range(0, len(a_axes)). The lists
+ * a_axes and b_axes must have identical length
+ * and consist of unique integers that specify valid axes for each of the
+ * tensors. Additionally, the outer product is supported by passing
+ * axes=0.
+ *

+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *

+ * Example 2: When a and b are matrices (order 2),
+ * the case axes = [[1], [0]] is equivalent to matrix multiplication.
+ *

+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_ijk and b_lmn
+ * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor
+ * c_jklm whose entry corresponding to the indices
+ * (j,k,l,m) is given by:
+ *

+ * c_jklm = Σ_i a_ijk b_lmi.
+ *

+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *
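As a quick aside, a minimal usage sketch of this tensordot endpoint follows (a sketch only: it assumes an eager {@code Ops} environment and the {@code FrameworkOps} facade introduced in this patch series; the class name and shapes are illustrative, not part of the patch):

```java
import org.tensorflow.Operand;
import org.tensorflow.framework.op.FrameworkOps;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;

public class TensordotExample {
  public static void main(String[] args) {
    Ops tf = Ops.create(); // eager environment
    FrameworkOps fops = FrameworkOps.create(tf);

    // Contract a's axes [1, 2] against b's axes [0, 1]; only the free axes
    // survive, so order(c) = 3 + 3 - 2*2 = 2 and the result shape is (3, 6).
    Operand<TFloat32> a = tf.ones(tf.constant(Shape.of(3, 4, 5)), TFloat32.class);
    Operand<TFloat32> b = tf.ones(tf.constant(Shape.of(4, 5, 6)), TFloat32.class);
    Operand<TFloat32> c = fops.math.tensordot(a, b, new int[][] {{1, 2}, {0, 1}});
    System.out.println(c.shape()); // (3, 6)

    // On two matrices, axis = 1 is ordinary matrix multiplication.
    Operand<TFloat32> m1 = tf.ones(tf.constant(Shape.of(2, 3)), TFloat32.class);
    Operand<TFloat32> m2 = tf.ones(tf.constant(Shape.of(3, 4)), TFloat32.class);
    System.out.println(fops.math.tensordot(m1, m2, 1).shape()); // (2, 4)
  }
}
```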

+ *
+ * @param a `Operand` of type `float32` or `float64`.
+ * @param b `Operand` with the same type as `a`.
+ * @param axis sum over the last N axes of a and the
+ * first N axes of b in order. If `axis=0`, computes the outer
+ * product between `a` and `b`.
+ * @param <T> the datatype of the Operands, must be either TFloat32 or
+ * TFloat64
+ * @return An `Operand` with the same type as `a`.
+ * @throws IllegalArgumentException if a is not a float32 or float64 data type, or if a and b are not the same data type
+ */
+ @Endpoint(name = "tensordot")
+ public Operand tensordot(Operand a, Operand b, int axis) {
+
+ Operand[] abAxis = tensordotAxes(a, axis);
+ Operand aAxis = abAxis[0];
+ Operand bAxis = abAxis[1];
+ return tensordot(a, b, aAxis, bAxis);
+ }
+
+ /**
+ * Tensor contraction of a and b along specified axes and outer product.
+ *

+ * Tensordot (also known as tensor contraction) sums the product of elements
+ * from a and b over the indices specified by
+ * a_axes and b_axes. The lists
+ * a_axes and b_axes specify those pairs of axes
+ * along which to contract the tensors. The axis a_axes[i] of
+ * a must have the same dimension as axis
+ * b_axes[i] of b for all i in
+ * range(0, len(a_axes)). The lists
+ * a_axes and b_axes must have identical length
+ * and consist of unique integers that specify valid axes for each of the
+ * tensors. Additionally, the outer product is supported by passing
+ * axes=0.
+ *

+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *

+ * Example 2: When a and b are matrices (order 2),
+ * the case axes = [[1], [0]] is equivalent to matrix multiplication.
+ *

+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_ijk and b_lmn
+ * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor
+ * c_jklm whose entry corresponding to the indices
+ * (j,k,l,m) is given by:
+ *

+ * c_jklm = Σ_i a_ijk b_lmi.
+ *

+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *

+ *
+ * @param a `Operand` of type `float32` or `float64`.
+ * @param b `Operand` with the same type as `a`.
+ * @param axes If axes is a scalar, sum over the last N axes of a and the
+ * first N axes of b in order. If axes is a list, the first and second rows
+ * contain the set of unique integers specifying axes along which the
+ * contraction is computed, for `a` and `b`, respectively. The number of
+ * axes for `a` and `b` must be equal. If `axes=0`, computes the outer
+ * product between `a` and `b`.
+ * @param <T> the datatype of the Operands, must be either TFloat32 or
+ * TFloat64
+ * @return An `Operand` with the same type as `a`.
+ * @throws IllegalArgumentException if a is not a float32 or float64 data type, or if a and b are not the same data type
+ */
+ @Endpoint(name = "tensordot")
+ public Operand tensordot(
+ Operand a, Operand b, Operand axes) {
+
+ Operand[] abAxis = tensordotAxes(a, axes);
+ Operand aAxis = abAxis[0];
+ Operand bAxis = abAxis[1];
+
+ return tensordot(a, b, aAxis, bAxis);
+ }
+
+ /**
+ * Tensor contraction of a and b along specified axes and outer product.
+ *

+ * Tensordot (also known as tensor contraction) sums the product of elements
+ * from a and b over the indices specified by
+ * a_axes and b_axes. The lists
+ * a_axes and b_axes specify those pairs of axes
+ * along which to contract the tensors. The axis a_axes[i] of
+ * a must have the same dimension as axis
+ * b_axes[i] of b for all i in
+ * range(0, len(a_axes)). The lists
+ * a_axes and b_axes must have identical length
+ * and consist of unique integers that specify valid axes for each of the
+ * tensors. Additionally, the outer product is supported by passing
+ * axes=0.
+ *

+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *

+ * Example 2: When a and b are matrices (order 2),
+ * the case axes = [[1], [0]] is equivalent to matrix multiplication.
+ *

+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_ijk and b_lmn
+ * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor
+ * c_jklm whose entry corresponding to the indices
+ * (j,k,l,m) is given by:
+ *

+ * c_jklm = Σ_i a_ijk b_lmi.
+ *

+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *

+ *
+ * @param a `Operand` of type `float32` or `float64`.
+ * @param b `Operand` with the same type as `a`.
+ * @param axes the first and second elements specify the axes along which the
+ * contraction is computed, for `a` and `b`, respectively. The number of
+ * axes for `a` and `b` must be equal.
+ * @param <T> the datatype of the Operands, must be either TFloat32 or
+ * TFloat64
+ * @return An `Operand` with the same type as `a`.
+ * @throws IllegalArgumentException if a is not a float32 or float64 data type, or if a and b are not the same data type
+ */
+ @Endpoint(name = "tensordot")
+ public Operand tensordot(Operand a, Operand b, int[] axes) {
+
+ Operand[] abAxis = tensordotAxes(a, axes);
+ Operand aAxis = abAxis[0];
+ Operand bAxis = abAxis[1];
+
+ return tensordot(a, b, aAxis, bAxis);
+ }
+
+ /**
+ * Tensor contraction of a and b along specified axes and outer product.
+ *

+ * Tensordot (also known as tensor contraction) sums the product of elements
+ * from a and b over the indices specified by
+ * a_axes and b_axes. The lists
+ * a_axes and b_axes specify those pairs of axes
+ * along which to contract the tensors. The axis a_axes[i] of
+ * a must have the same dimension as axis
+ * b_axes[i] of b for all i in
+ * range(0, len(a_axes)). The lists
+ * a_axes and b_axes must have identical length
+ * and consist of unique integers that specify valid axes for each of the
+ * tensors. Additionally, the outer product is supported by passing
+ * axes=0.
+ *

+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *

+ * Example 2: When a and b are matrices (order 2),
+ * the case axes = [[1], [0]] is equivalent to matrix multiplication.
+ *

+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_ijk and b_lmn
+ * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor
+ * c_jklm whose entry corresponding to the indices
+ * (j,k,l,m) is given by:
+ *

+ * c_jklm = Σ_i a_ijk b_lmi.
+ *

+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *

+ *
+ * @param a `Operand` of type `float32` or `float64`.
+ * @param b `Operand` with the same type as `a`.
+ * @param axes the first and second rows
+ * contain the set of unique integers specifying axes along which the
+ * contraction is computed, for `a` and `b`, respectively. The number of
+ * axes for `a` and `b` must be equal.
+ * @param <T> the datatype of the Operands, must be either TFloat32 or
+ * TFloat64
+ * @return An `Operand` with the same type as `a`.
+ * @throws IllegalArgumentException if a is not a float32 or float64 data type, or if a and b are not the same data type
+ */
+ @Endpoint(name = "tensordot")
+ public Operand tensordot(Operand a, Operand b, int[][] axes) {
+
+ Operand[] abAxis = tensordotAxes(a, axes);
+ Operand aAxis = abAxis[0];
+ Operand bAxis = abAxis[1];
+
+ return tensordot(a, b, aAxis, bAxis);
+ }
+
+ /**
+ * Tensor contraction of a and b along specified axes and outer product.
+ *

+ * Tensordot (also known as tensor contraction) sums the product of elements
+ * from a and b over the indices specified by
+ * a_axes and b_axes. The lists
+ * a_axes and b_axes specify those pairs of axes
+ * along which to contract the tensors. The axis a_axes[i] of
+ * a must have the same dimension as axis
+ * b_axes[i] of b for all i in
+ * range(0, len(a_axes)). The lists
+ * a_axes and b_axes must have identical length
+ * and consist of unique integers that specify valid axes for each of the
+ * tensors. Additionally, the outer product is supported by passing
+ * axes=0.
+ *

+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *

+ * Example 2: When a and b are matrices (order 2),
+ * the case axes = [[1], [0]] is equivalent to matrix multiplication.
+ *

+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_ijk and b_lmn
+ * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor
+ * c_jklm whose entry corresponding to the indices
+ * (j,k,l,m) is given by:
+ *

+ * c_jklm = Σ_i a_ijk b_lmi.
+ *

+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *

+ *
+ * @param a `Operand` of type `float32` or `float64`.
+ * @param b `Operand` with the same type as `a`.
+ * @param aAxis axes for the a Operand
+ * @param bAxis axes for the b Operand
+ * @param <T> the datatype of the Operands, must be either TFloat32 or
+ * TFloat64
+ * @return An `Operand` with the same type as `a`.
+ * @throws IllegalArgumentException if a is not a float32 or float64 data type, or if a and b are not the same data type
+ */
+ @SuppressWarnings({"unchecked", "unused"})
+ @Endpoint(name = "tensordot")
+ public Operand tensordot(
+ Operand a, Operand b, Operand aAxis, Operand bAxis) {
+
+ if (a.type().equals(TBfloat16.class) || a.type().equals(TFloat16.class)) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Operand 'a' must be either TFloat32 or TFloat64 DataType, 'a' is a %s DataType",
+ a.type().getSimpleName()));
+ }
+ if (!a.type().equals(b.type())) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Operands a and b must be the same data type, a is %s DataType, b is %s DataType",
+ a.type().getSimpleName(), b.type().getSimpleName()));
+ }
+
+ // aResult: [0] the reshaped Operand a, [1] an Operand of the free dimensions (unused here),
+ // [2] a long[] of the statically known free dimensions.
+ Object[] aResult = tensordotReshape(a, aAxis, false);
+ Operand reshapedA = (Operand) aResult[0];
+ Operand aFreeDims = (Operand) aResult[1];
+ long[] aFreeDimsStatic = (long[]) aResult[2];
+
+ // bResult: [0] the reshaped Operand b, [1] an Operand of the free dimensions (unused here),
+ // [2] a long[] of the statically known free dimensions.
+ Object[] bResult = tensordotReshape(b, bAxis, true);
+ Operand reshapedB = (Operand) bResult[0];
+ Operand bFreeDims = (Operand) bResult[1];
+ long[] bFreeDimsStatic = (long[]) bResult[2];
+
+ Operand abMatmul = frameworkOps.linalg.matmul(reshapedA, reshapedB);
+ long[] abDimsStatic = new long[aFreeDimsStatic.length + bFreeDimsStatic.length];
+ System.arraycopy(aFreeDimsStatic, 0, abDimsStatic, 0, aFreeDimsStatic.length);
+ System.arraycopy(
+ bFreeDimsStatic, 0, abDimsStatic, aFreeDimsStatic.length, bFreeDimsStatic.length);
+ if (!abMatmul.shape().hasUnknownDimension()
+ && abMatmul.shape().equals(Shape.of(abDimsStatic))) {
+ return abMatmul;
+ } else {
+ return Reshape.create(scope, abMatmul, Constant.vectorOf(scope, abDimsStatic));
+ }
+ }
+
+ /**
+ * Computes log(sum(exp(elements across dimensions of a tensor))). Reduces {@code input}
+ * along the dimensions given in {@code axes}.
+ *

Reduces {@code input} along the dimensions given in {@code axes}. Unless {@code keepDims}
+ * is true, the rank of the tensor is reduced by 1 for each of the entries in {@code axes}, which
+ * must be unique. If {@code keepDims} is true, the reduced dimensions are retained with length 1.
+ * If {@code axes} is null, all dimensions are reduced, and a tensor with a single element
+ * is returned. This function is more numerically stable than {@code log(sum(exp(input)))}: it
+ * avoids overflows caused by taking the exp of large inputs and underflows caused by taking the
+ * log of small inputs, by using the identity
+ * {@code log(sum(exp(x))) = max(x) + log(sum(exp(x - max(x))))}, which keeps every exponent
+ * non-positive.
+ *
+ * @param input The tensor to reduce.
+ * @param axes The dimensions to reduce. If null, reduces all dimensions. Must be in the range
+ * {@code [-rank(input), rank(input)]}.
+ * @param keepDims If true, retains reduced dimensions with length 1.
+ * @param <T> the data type for the input and the result
+ * @return The reduced tensor.
+ */
+ @Endpoint(name = "reduceLogSumExp")
+ public Operand reduceLogSumExp(
+ Operand input, int[] axes, boolean keepDims) {
+ Operand reduceDims = reductionDims(input, axes);
+ Operand rawMax = reduceMaxWithDims(input, axes, keepDims, reduceDims);
+ // Replace non-finite maxima with zero so the subtraction below stays well-defined.
+ Operand myMax =
+ StopGradient.create(
+ scope,
+ Select.create(
+ scope, IsFinite.create(scope, rawMax), rawMax, ZerosLike.create(scope, rawMax)));
+
+ Operand result =
+ Log.create(
+ scope,
+ reduceSumWithDims(
+ Exp.create(scope, Sub.create(scope, input, myMax)), axes, keepDims, reduceDims));
+
+ if (!keepDims) {
+ myMax = Reshape.create(scope, myMax, org.tensorflow.op.core.Shape.create(scope, result));
+ }
+ result = Add.create(scope, result, myMax);
+ return mayReduceToScalar(keepDims, axes, result);
+ }
+
+ private Operand reduceSumWithDims(
+ Operand input, int[] axes, boolean keepDims, Operand dims) {
+ return mayReduceToScalar(
+ keepDims, axes, ReduceSum.create(scope, input, dims, ReduceSum.keepDims(keepDims)));
+ }
+
+ private Operand reduceMaxWithDims(
+ Operand input, int[] axes, boolean keepDims, Operand dims) {
+ return mayReduceToScalar(
+ keepDims, axes, ReduceMax.create(scope, input, dims, ReduceMax.keepDims(keepDims)));
+ }
+
+ /**
+ * Sets a reduction's output shape to be a scalar if possible.
+ *
+ * @return the operand, possibly reduced to a scalar.
+ */
+ private Operand mayReduceToScalar(
+ boolean keepDims, int[] axes, Operand output) {
+
+ if ((output.shape().numDimensions() == Shape.UNKNOWN_SIZE
+ || output.shape().hasUnknownDimension())
+ && !keepDims
+ && axes == null) {
+ return Reshape.create(scope, output, Constant.tensorOf(scope, Shape.scalar()));
+ } else {
+ return output;
+ }
+ }
+
+ /**
+ * Computes the dimensions to reduce, based on the axes.
+ *
+ * @param input the input
+ * @param axes the dimensions to reduce, may be null
+ * @return the dimensions to be reduced.
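As a quick aside, the subtract-the-max-then-add-it-back rewrite implemented by reduceLogSumExp above is the standard log-sum-exp shift; a minimal sketch of the identity in LaTeX (editorial, not part of the patch):

```latex
% With m = max_i x_i, every exponent x_i - m is <= 0, so exp cannot
% overflow, and at least one term equals 1, so the log stays finite:
\log \sum_{i} e^{x_i} \;=\; m \;+\; \log \sum_{i} e^{x_i - m},
\qquad m = \max_{i} x_i
```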
+ */ + private Operand reductionDims(Operand input, int[] axes) { + if (axes != null) { + return Constant.vectorOf(scope, axes); + } + long rank = input.shape().numDimensions(); + if (rank != Shape.UNKNOWN_SIZE) { + int[] dims = new int[(int) rank]; + for (int i = 0; i < rank; i++) { + dims[i] = i; + } + return Constant.vectorOf(scope, dims); + + } else { + return Range.create( + scope, + Constant.scalarOf(scope, 0), + Rank.create(scope, input), + Constant.scalarOf(scope, 1)); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java index 326e3cdc2d1..dda5a7c6eaa 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/MathOpsTest.java @@ -5,9 +5,12 @@ import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt64; +import static org.junit.jupiter.api.Assertions.assertThrows; + class MathOpsTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -386,7 +389,7 @@ public void testL2Normalize() { Ops tf = session.getTF(); FrameworkOps fops = FrameworkOps.create(tf); Operand input = tf.constant(array); - Operand result = fops.math.l2Normalize(tf.constant(array), new int[]{ 0,1,2}); + Operand result = fops.math.l2Normalize(tf.constant(array), new int[] {0, 1, 2}); session.evaluate(tf.constant(expectedArray), result); } } @@ -410,4 +413,89 @@ public void testConfusionMatrix() { session.evaluate(tf.constant(expected), result); } } + + @Test + public void testTensorDotValid() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + int[] axes1 = new int[] {1, 2}; + int[][] axes2 = new int[][] {{1}, {2}}; + int[][] axes3 = new int[2][0]; + int axes4 = 0; + + Operand a = tf.ones(tf.constant(Shape.of(3, 3)), TFloat32.class); + Operand b = tf.constant(new float[][][] {{{2, 3, 1}}}); + + Operand ans = fops.math.tensordot(a, b, axes1); + Operand expected = tf.constant(new float[][][] {{{6}}, {{6}}, {{6}}}); + session.evaluate(expected, ans); + + ans = fops.math.tensordot(a, b, axes2); + expected = tf.constant(new float[][][] {{{6}}, {{6}}, {{6}}}); + session.evaluate(expected, ans); + + ans = fops.math.tensordot(a, b, axes3); + + float[][][][][] expectedArray = + new float[][][][][] { + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}} + }; + ans = fops.math.tensordot(a, b, axes3); + expected = tf.constant(expectedArray); + session.evaluate(expected, ans); + + ans = fops.math.tensordot(a, b, axes4); + expected = tf.constant(expectedArray); + session.evaluate(expected, ans); + } + } + + @Test + public void testTensorDotInValidAxis() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + Operand a = tf.constant(new float[][] {{1, 2}, {3, 4}}); + Operand b = tf.constant(new float[][] {{1, 2}, {3, 4}}); + assertThrows(IllegalArgumentException.class, () -> fops.math.tensordot(a, b, -1)); + 
assertThrows(IllegalArgumentException.class, () -> fops.math.tensordot(a, b, 3)); + assertThrows( + IllegalArgumentException.class, () -> fops.math.tensordot(a, b, new int[] {1})); + assertThrows( + IllegalArgumentException.class, () -> fops.math.tensordot(a, b, new int[][] {{1}})); + assertThrows( + IllegalArgumentException.class, + () -> fops.math.tensordot(a, b, new int[][] {{1}, {0, 1}})); + + assertThrows( + ArrayIndexOutOfBoundsException.class, + () -> fops.math.tensordot(a, b, new int[][] {{0}, {7}})); + } + } + + @Test + public void testReduceLogSumExp() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + Operand x = + tf.constant( + new float[][] { + {0.43346116f, 0.8569728f, 0.57155997f, 0.0743812f, 0.63846475f}, + {0.8165283f, 0.26554802f, 0.37025765f, 0.8255019f, 0.45682374f}, + {0.93511814f, 0.52291054f, 0.80983895f, 0.11580781f, 0.8111686f}, + {0.49967498f, 0.27537802f, 0.48554695f, 0.28238368f, 0.7989301f}, + {0.8958915f, 0.84870094f, 0.56874424f, 0.08818512f, 0.13915819f} + }); + + Operand result = fops.math.reduceLogSumExp(x, new int[] {0, 1}, false); + session.evaluate(3.7911222f, result); + } + } } From be1fe6678c7e8d44fdb81d8bc47bba250c909b1f Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 2 May 2021 18:52:43 -0400 Subject: [PATCH 17/31] Added frameworkOps for nn and linalg --- .../java/org/tensorflow/framework/op/FrameworkOps.java | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java index c8b234f2c51..d9e3eec4b21 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java @@ -31,8 +31,9 @@ public class FrameworkOps { public final Ops coreOps; public final NnOps nn; - public final SetsOps sets; + public final SetOps sets; public final MathOps math; + public final LinalgOps linalg; private final Scope scope; /** @@ -44,8 +45,9 @@ private FrameworkOps(Scope scope) { this.coreOps = Ops.create(scope.env()); this.scope = scope; nn = new NnOps(this); - sets = new SetsOps(this); + sets = new SetOps(this); math = new MathOps(this); + linalg = new LinalgOps(this); } /** @@ -57,8 +59,9 @@ private FrameworkOps(Ops coreOps) { this.coreOps = coreOps; this.scope = coreOps.scope(); nn = new NnOps(this); - sets = new SetsOps(this); + sets = new SetOps(this); math = new MathOps(this); + linalg = new LinalgOps(this); } /** From 7b51e7fea5ae7e738c7360d4da6c8f7994266a96 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 2 May 2021 18:53:15 -0400 Subject: [PATCH 18/31] Modified to use FrameworkOps --- .../tensorflow/framework/losses/Losses.java | 1 + .../tensorflow/framework/metrics/MeanIoU.java | 11 +- .../framework/metrics/impl/MetricsHelper.java | 561 ++++++++++++++++-- .../metrics/impl/WeightsBroadcastOps.java | 4 +- 4 files changed, 532 insertions(+), 45 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 398588cee67..6700f2569f0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java 
@@ -27,6 +27,7 @@ import org.tensorflow.op.math.Softplus;
 import org.tensorflow.types.TBool;
 import org.tensorflow.types.TInt64;
+import org.tensorflow.types.family.TFloating;
 import org.tensorflow.types.family.TNumber;

 import static org.tensorflow.framework.utils.CastHelper.cast;
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java
index 22baab3d6cb..70cd826f625 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanIoU.java
@@ -16,7 +16,7 @@
 import org.tensorflow.Operand;
 import org.tensorflow.framework.initializers.Zeros;
-import org.tensorflow.framework.metrics.impl.MetricsHelper;
+import org.tensorflow.framework.op.FrameworkOps;
 import org.tensorflow.ndarray.Shape;
 import org.tensorflow.op.Op;
 import org.tensorflow.op.Ops;
@@ -124,8 +124,8 @@ public Assign getInitializer() {
    * @param sampleWeights Optional weighting of each example. Defaults to 1, if null. Rank is either
    *     0, or the same rank as labels, and must be broadcastable to labels.
    * @return the Operands that updates totalConfusionMatrix variable
-   * @throws IllegalArgumentException if the weights rank is not 0, and weights rank @{code !=} labels rank,
-   *     and if the predictions size is not equal to the labels size
+   * @throws IllegalArgumentException if the weights rank is not 0 and the weights rank {@code !=}
+   *     the labels rank, or if the predictions size is not equal to the labels size
    */
   @Override
   public List updateStateList(
@@ -167,10 +167,11 @@ public List updateStateList(
       tSampleWeights = getTF().shape.flatten(tSampleWeights);
     }

+    FrameworkOps fops = FrameworkOps.create(getTF());
     // Accumulate the prediction to current confusion matrix.
Operand currentCM = - MetricsHelper.confusionMatrix( - getTF(), tLabels, tPredictions, getTF().constant(numClasses), tSampleWeights, type); + fops.math.confusionMatrix( + tLabels, tPredictions, tSampleWeights, getTF().constant(numClasses)); return Collections.singletonList(getTF().assignAdd(totalConfusionMatrix, currentCM)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java index a82e1760d1f..a4e19d58bcb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -15,21 +15,35 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.OneHot; +import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.op.core.Variable; import org.tensorflow.op.math.Mean; +import org.tensorflow.op.nn.TopK; import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -44,8 +58,8 @@ public class MetricsHelper { "weights can not be broadcast to values."; /** - * Asserts that the sampleWeights can be broadcast to the same shape as values - * + * Asserts that the {@code sampleWeights} can be broadcast to the same shape as {@code values + * } * *

In losses and metrics, limited weight broadcasting is supported. Weights must be either * scalar, or the same rank as the target values, with each dimension either 1, or the same as the @@ -54,11 +68,11 @@ public class MetricsHelper { * @param tf the TensorFlow Ops * @param sampleWeights the sample weights. * @param values the values to which weights are applied. - * @return Operation with control dependencies to ensure sampleWeight - * can be broadcast to values + * @return {@code Operation} with control dependencies to ensure {@code sampleWeight} + * can be broadcast to {@code values} * @param the type of Operand - * @throws NotBroadcastableException If static checks determine sampleWeights has an - * incorrect shape that prohibit broadcasting to values + * @throws NotBroadcastableException If static checks determine {@code sampleWeights} has an + * incorrect shape that prohibit broadcasting to {@code values} */ @SuppressWarnings("unchecked") public static Op assertBroadcastable( @@ -79,7 +93,7 @@ public static Op assertBroadcastable( && !valuesShapeStatic.hasUnknownDimension()) { if (weightsRankStatic == 0) { return tf.withSubScope("staticScalarCheckSuccess") - .withControlDependencies(Collections.EMPTY_LIST) + .withControlDependencies(Collections.emptyList()) .noOp(); } if (weightsRankStatic != valuesRankStatic) { @@ -89,8 +103,8 @@ public static Op assertBroadcastable( ASSERT_BROADCAST_ERROR_PREFIX, valuesRankStatic, weightsRankStatic, - valuesShapeStatic.toString(), - weightsShapeStatic.toString())); + valuesShapeStatic, + weightsShapeStatic)); } for (int i = 0; i < valuesRankStatic; i++) { @@ -101,8 +115,8 @@ public static Op assertBroadcastable( "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", ASSERT_BROADCAST_ERROR_PREFIX, i, - valuesShapeStatic.toString(), - weightsShapeStatic.toString())); + valuesShapeStatic, + weightsShapeStatic)); } } return tf.withSubScope("staticDimsCheckSuccess") @@ -187,13 +201,13 @@ private static Operand canBroadcastDims( } /** - * Broadcast weights to the same shape as values. + * Broadcast {@code weights} to the same shape as {@code values}. * * @param tf the TensorFlow ops - * @param weights Operand whose shape is broadcastable to values. + * @param weights Operand whose shape is broadcastable to {@code values}. * @param values Operand of any shape * @param the type of Operands - * @return weights broadcast to values shape + * @return {@code weights} broadcast to {@code values} shape */ public static Operand broadcastWeights( Ops tf, Operand weights, Operand values) { @@ -214,11 +228,473 @@ public static Operand broadcastWeights( return ctf.math.mul(weights, tf.onesLike(values)); } - // aliases for mean + /** + * Checks that all the Symbolic Shapes are consistent. + * + * @param tf the TensorFlow Ops + * @param symbols the list of Symbolic Shapes + * @param message the error message if the shapes are not consistent. + * @return a list of Operands to check the consistency of the symbolic shapes ready to add to a + * control dependency. + */ + public static List assertShapes( + Ops tf, List> symbols, String message) { + List updateOperations = new ArrayList<>(); + // check that the symbolic shape rank matches the operands rank. 
+ symbols.forEach( + symbol -> { + Operand operand = symbol.getOperand(); + int rank = symbol.rank(); + Rank tfRank = tf.rank(operand); + Op assertion = + tf.withSubScope("assertShapes-1") + .assertThat( + tf.math.equal(tfRank, tf.constant(rank)), + Collections.singletonList(tf.constant(message))); + updateOperations.add(assertion); + }); + + Map> dict = new HashMap<>(); + + // check that each operand's dimension size equals the corresponding symbolic shape's dimensions + // size + symbols.forEach( + symbol -> { + AtomicLong ll = new AtomicLong(); + symbol + .getSymbols() + .forEach( + s -> { + Operand size = dict.get(s); + if (size == null) { + // save size for later checks + size = + tf.shape.size(symbol.getOperand(), tf.constant(ll.get()), TInt64.class); + dict.put(s, size); + } + Op assertion = + tf.withSubScope("assertShapes-2") + .assertThat( + tf.math.equal( + tf.shape.size( + symbol.getOperand(), + tf.constant(ll.getAndIncrement()), + TInt64.class), + size), + Collections.singletonList(tf.constant(message))); + updateOperations.add(assertion); + }); + }); + + return updateOperations; + } + + /** + * Returns an op to update the given confusion matrix variables. + * + *

For every pair of values in {@code labels} and {@code predictions}: + * + *

+   * TRUE_POSITIVES:  {@code labels} == true and {@code predictions} > thresholds
+   * FALSE_POSITIVES: {@code labels} == true and {@code predictions} <= thresholds
+   * TRUE_NEGATIVES:  {@code labels} == false and {@code predictions} <= thresholds
+   * FALSE_NEGATIVE:  {@code labels} == false and {@code predictions} > thresholds
+   * 
+ * + *

The results will be weighted and added together. When multiple thresholds are provided, we + * will repeat the same for every threshold. + * + *
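As a quick aside, a scalar sketch of how a single (label, prediction) pair is binned for one threshold, mirroring the table above (a sketch only; the class and method names are illustrative, not part of the patch):

```java
import java.util.Arrays;

public class ConfusionBinExample {
  // counts[0..3] = TRUE_POSITIVES, FALSE_NEGATIVES, TRUE_NEGATIVES, FALSE_POSITIVES
  static void bin(boolean label, float prediction, float threshold, long[] counts) {
    boolean predIsPos = prediction > threshold;
    if (label && predIsPos) counts[0]++;   // label true,  prediction >  threshold
    else if (label) counts[1]++;           // label true,  prediction <= threshold
    else if (!predIsPos) counts[2]++;      // label false, prediction <= threshold
    else counts[3]++;                      // label false, prediction >  threshold
  }

  public static void main(String[] args) {
    long[] counts = new long[4];
    bin(true, 0.9f, 0.5f, counts);   // true positive
    bin(true, 0.2f, 0.5f, counts);   // false negative
    bin(false, 0.1f, 0.5f, counts);  // true negative
    bin(false, 0.8f, 0.5f, counts);  // false positive
    System.out.println(Arrays.toString(counts)); // [1, 1, 1, 1]
  }
}
```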

For estimation of these metrics over a stream of data, the function creates update operations
+ * that update the given variables.
+ *

{@code labels}, {@code predictions}, and {@code sampleWeight} tensors are
+ * aligned by {@link LossesHelper#removeSqueezableDimensions(Ops, Operand, Operand)}. {@code
+ * sampleWeight} is then broadcast to the shape of {@code predictions}.
+ *
+ * @param tf the TensorFlow Ops
+ * @param variablesToUpdate map with {@link ConfusionMatrixEnum} values as valid keys and
+ * corresponding variables to update as values. If {@code multiLabel}, then the variable
+ * shapes are (T, D), where T is the number of thresholds and D is the number of classes
+ * (after slicing by {@code classIndex}, if provided); otherwise the variable shapes are (T).
+ * @param varInitializers map with {@link ConfusionMatrixEnum} values as valid keys and
+ * corresponding initializer Operands for {@code variablesToUpdate}.
+ * @param labels the labels. Will be cast to {@link TBool}. Shape (N, Cx, L1?), where N is the
+ * number of examples, Cx is zero or more class dimensions, and L1 is a potential extra
+ * dimension of size 1 that would be squeezed.
+ * @param predictions the predictions, shape (N, Cx, P1?)
+ * @param thresholds thresholds in the range {@code [0, 1]}, or {@link #NEG_INF} when
+ * topK is set
+ * @param topK optional, indicates that only the top k predictions should be considered. Applied
+ * before possibly slicing by {@code classIndex}.
+ * @param classIndex optional, limits the prediction and labels to the specified class. This is an
+ * integer index into the first dimension of Cx.
+ * @param sampleWeight optional {@code Tensor} that is aligned with labels and predictions as
+ * explained above. Use weights of 0 to mask values.
+ * @param multiLabel indicates whether multidimensional prediction/labels should be treated as
+ * multilabel responses, or flattened into a single label. When true, the values of {@code
+ * variablesToUpdate} must have a second dimension equal to the number of labels and
+ * predictions per example, and those tensors must not be RaggedTensors.
+ * @param labelWeights tensor of non-negative weights for multilabel data. The weights are applied
+ * when calculating TRUE_POSITIVES, FALSE_POSITIVES, TRUE_NEGATIVES, and FALSE_NEGATIVES
+ * without explicit multilabel handling (i.e. when the data is to be flattened). Must have
+ * shape (Dx), which is the same as (Cx) referenced above, except that if {@code classIndex
+ * } is provided, then the final dimension of Dx is 1. These weights will be broadcast
+ * across the 0th dimension (the examples dimension) of {@code predictions}. May be null.
+ * Must be null if {@code multiLabel}.
+ * @param <T> the data type for the variables
+ * @throws IllegalArgumentException If {@code predictions} and {@code labels} have
+ * mismatched shapes, or if {@code sampleWeight} is not null and its shape
+ * doesn't match {@code predictions}, or if {@code multiLabel && labelWeights != null}.
+ * @return an op to update the given confusion matrix variables.
+ */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static List updateConfusionMatrixVariables( + Ops tf, + Map> variablesToUpdate, + Map> varInitializers, + Operand labels, + Operand predictions, + Operand thresholds, + Integer topK, + Integer classIndex, + Operand sampleWeight, + boolean multiLabel, + Operand labelWeights) { + if (multiLabel && labelWeights != null) + throw new IllegalArgumentException( + "labelWeights for multilabel data should be handled outside of updateConfusionMatrixVariables when multiLabel is true."); + + if (variablesToUpdate == null || variablesToUpdate.isEmpty()) { + return Collections.EMPTY_LIST; + } + + Operand tLabels = labels; + Operand tPredictions = predictions; + Operand tSampleWeight = sampleWeight; + + // We will tile data for threshold comparisons. We want a cross product of thresholds and + // predictions/labels: + // In the multilabel case, we want a data shape of (T, N, D). + // else (T, ND). + // where + // T is numThresholds (the size of the 0th dimension of thresholds) + // N is the number of examples (the 0th dimension of labels and predictions) + // Dx == Cx except that if classIndex != null, + // then the last dimension of Dx is size 1 + // D is the product of all Dx + // ND is N * D + + // size of the 0th dimension of thresholds + // reshape to scalar for operations later. + Operand numThresholds = + tf.reshape(tf.shape.size(thresholds, tf.constant(0)), tf.constant(Shape.scalar())); + + // if multilabel, then (rank(thresholds) == 1) + // else true + Operand oneThresh; + if (multiLabel) { + oneThresh = tf.math.equal(tf.constant(1), tf.rank(thresholds)); + } else { + // TODO handle Ragged Tensors???? + // [y_pred, + // y_true], _ = ragged_assert_compatible_and_get_flat_values([y_pred, y_true], + // sampleWeights) + oneThresh = tf.constant(true); + } + + List controlOps = new ArrayList<>(); + Operand axes = allAxes(tf, tPredictions); + controlOps.add( + tf.withSubScope("updateConfusionMatrixVariables-1") + .assertThat( + tf.reduceAll( + tf.math.greaterEqual( + tPredictions, cast(tf, tf.constant(0), tPredictions.type())), + axes), + Collections.singletonList(tf.constant("predictions must be >= 0")))); + controlOps.add( + tf.withSubScope("updateConfusionMatrixVariables-2") + .assertThat( + tf.reduceAll( + tf.math.lessEqual(tPredictions, cast(tf, tf.constant(1), tPredictions.type())), + axes), + Collections.singletonList(tf.constant("predictions must be <= 1")))); + + LossTuple result = + LossesHelper.squeezeOrExpandDimensions(tf, tLabels, tPredictions, tSampleWeight); + tPredictions = result.getTarget(); // shape (N, Cx) + tLabels = result.getLabels(); // shape (N, Cx) + tSampleWeight = result.getSampleWeights(); // broadcastable to (N, Dx) + + if (!tPredictions.shape().isCompatibleWith(tLabels.shape())) + throw new IllegalArgumentException( + String.format( + "Shapes %s and %s are incompatible)", + tPredictions.shape().toString(), tLabels.shape().toString())); + + if (topK != null) { + tPredictions = filterTopK(tf, tPredictions, topK); + } + + if (classIndex != null) { + // Slice to new shapes (N, Dx) + tLabels = tf.squeeze(tf.gather(tLabels, + tf.constant(new int[] {classIndex}), tf.constant(-1)), + Squeeze.axis(Collections.singletonList(1L))); + tPredictions = tf.squeeze(tf.gather(tPredictions, + tf.constant(new int[] {classIndex}), tf.constant(-1)), + Squeeze.axis(Collections.singletonList(1L))); + } + org.tensorflow.op.core.Shape predShape = tf.shape(tPredictions); + + Operand numExamples = + tf.reshape(tf.shape.size(tPredictions, 
tf.constant(0)), tf.constant(Shape.scalar())); + + // number of labels (and predictions) per example (after possibly slicing by classIndex) + // In the notation we are using for comments, this is D. + Operand numLabels = + tf.select( + tf.math.equal(tf.shape.numDimensions(predShape), tf.constant(1)), + tf.constant(1), + tf.reduceProd( + // take all but the first dimension + tf.shape.takeLast( + predShape, tf.math.sub(tf.shape.numDimensions(predShape), tf.constant(1))), + tf.constant(0))); + + // threshLabelTile == numLabels except in one case: + // if multilabel and rank(thresholds) != 1, then threshLabelTile is 1 + Operand threshLabelTile = tf.select(oneThresh, numLabels, tf.constant(1)); + + // if multilabel, then shape (1, N, Dx) + // else shape (1, ND), + Operand predictionsExtraDim; + Operand labelsExtraDim; + + if (multiLabel) { + predictionsExtraDim = tf.expandDims(tPredictions, tf.constant(0)); + labelsExtraDim = tf.expandDims(cast(tf, tLabels, TBool.class), tf.constant(0)); + } else { + predictionsExtraDim = tf.reshape(tPredictions, tf.constant(Shape.of(1, -1))); + labelsExtraDim = tf.reshape(cast(tf, tLabels, TBool.class), tf.constant(Shape.of(1, -1))); + } + + // the shape of each thresholds tile + // if multilabel, then [T, 1, -1] + // else [T, -1] + List> threshPretileShape; + + // the tiling multiples for thresholds + // We want to repeat the thresholds for each data position. + // if multilabel, then [1, N, threshLabelTile]. (threshLabelTile is typically numLabels) + // else [1, ND] + List> threshTiles; + + // tiling multiples for predictionsExtraDim and labelsExtraDim + // We want to repeat the predictions and labels for each threshold. + // If multilabel, then [T, 1, 1] + // else [T, 1] + List> dataTiles; + + if (multiLabel) { + threshPretileShape = Arrays.asList(numThresholds, tf.constant(1), tf.constant(-1)); + threshTiles = Arrays.asList(tf.constant(1), numExamples, threshLabelTile); + dataTiles = Arrays.asList(numThresholds, tf.constant(1), tf.constant(1)); + } else { + threshPretileShape = + Arrays.asList(tf.reshape(numThresholds, tf.constant(Shape.scalar())), tf.constant(-1)); + Operand mul = tf.math.mul(numExamples, numLabels); + threshTiles = Arrays.asList(tf.constant(1), mul); + dataTiles = Arrays.asList(numThresholds, tf.constant(1)); + } + + // if multilabel, then shape (T, 1, T*) + // else shape (T, T*) + // where T* is the product of all threshold dimension sizes beyond 0 + Operand thresholdsReshaped = + tf.reshape(cast(tf, thresholds, predictions.type()), tf.stack(threshPretileShape)); + + Operand threshTilesShape = tf.stack(threshTiles); + + // if multilabel, then + // if thresholds has rank > 1, then shape (T, N, T*) + // else shape (T, N, D) + // else shape (T, ND) + Operand threshTiled = tf.tile(thresholdsReshaped, threshTilesShape); + + Operand dataTilesShape = tf.stack(dataTiles); + + // if multilabel, then shape (T, N, D) + // else (T, ND) + Operand predsTiled = tf.tile(predictionsExtraDim, dataTilesShape); + + // Compare predictions and threshold. 
+ Operand predIsPos = tf.math.greater(predsTiled, threshTiled); + // Tile labels by number of thresholds + Operand labelIsPos = tf.tile(labelsExtraDim, tf.stack(dataTiles)); + Operand weightsTiled; + if (tSampleWeight != null) { + tSampleWeight = tf.broadcastTo(tSampleWeight, tf.shape(tPredictions)); + // if multilabel, then + // reshape tSampleWeight to (1, N, threshLabelTile) + // tile the result into shape (T, N, threshLabelTile) + // where threshLabelTile is typically D + // else + // reshape tSampleWeight to (1, ND) + // tile the result into shape (T, ND) + weightsTiled = tf.tile(tf.reshape(tSampleWeight, threshTilesShape), dataTilesShape); + } else { + weightsTiled = null; + } + + if (labelWeights != null) { + // Change shape to (1, Dx). + Operand lLabelWeights = tf.expandDims(tf.identity(labelWeights), tf.constant(0)); + + // Broadcast to shape (N, Dx). + lLabelWeights = tf.broadcastTo(lLabelWeights, tPredictions); + + // If multilabel: shape (T, N, D) + // else: shape (T, ND) + Operand labelWeightsTiled = + tf.tile(tf.reshape(lLabelWeights, tf.stack(threshTiles)), tf.stack(dataTiles)); + + if (weightsTiled == null) { + weightsTiled = labelWeightsTiled; + } else { + weightsTiled = tf.math.mul(weightsTiled, labelWeightsTiled); + } + } + + Map loopVars = new HashMap<>(); + loopVars.put(ConfusionMatrixEnum.TRUE_POSITIVES, new Operand[] {labelIsPos, predIsPos}); + Variable updateTN = variablesToUpdate.get(ConfusionMatrixEnum.TRUE_NEGATIVES); + Variable updateFP = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_POSITIVES); + Variable updateFN = variablesToUpdate.get(ConfusionMatrixEnum.FALSE_NEGATIVES); + + Operand predIsNeg = null; + Operand labelIsNeg; + if (updateFN != null || updateTN != null) { + predIsNeg = tf.math.logicalNot(predIsPos); + loopVars.put(ConfusionMatrixEnum.FALSE_NEGATIVES, new Operand[] {labelIsPos, predIsNeg}); + } + + if (updateFP != null || updateTN != null) { + labelIsNeg = tf.math.logicalNot(labelIsPos); + loopVars.put(ConfusionMatrixEnum.FALSE_POSITIVES, new Operand[] {labelIsNeg, predIsPos}); + if (updateTN != null) { + loopVars.put(ConfusionMatrixEnum.TRUE_NEGATIVES, new Operand[] {labelIsNeg, predIsNeg}); + } + } + + final Operand weightsTiledF = weightsTiled; + loopVars + .keySet() + .forEach( + (c) -> { + if (variablesToUpdate.containsKey(c)) { + Operand[] op = loopVars.get(c); + // op[0] = label, op[1] == prediction + controlOps.add( + weightedAssignAdd( + tf, + op[0], + op[1], + weightsTiledF, + variablesToUpdate.get(c), + varInitializers.get(c))); + } + }); + + return controlOps; + } /** - * Calculate the mean of the operand, along all axes and keepDims is false - * + * Creates an Operand that adds the values by taking the logical and of labels and predictions to + * the specified confusion matrix variable. + * + * @param tf The TensorFlow Ops + * @param labels the labels + * @param predictions the predictions + * @param weights the weights applied to the logical and result, may be null + * @param variable the variable to update + * @param initializer the variable initializer to be applied to the variable, may be null. + * @param the data type for the variable. + * @return an Operand that updates the variable. 
+ */ + private static Operand weightedAssignAdd( + Ops tf, + Operand labels, + Operand predictions, + Operand weights, + Variable variable, + Assign initializer) { + Class type = variable.type(); + Operand labelAndPred = cast(tf, tf.math.logicalAnd(labels, predictions), type); + + if (weights != null) { + labelAndPred = tf.math.mul(labelAndPred, weights); + } + // if multilabel: + // sum across examples, leaving shape (T, D) + // else: + // sum across ND, leaving shape (T) + Operand valueSum = tf.reduceSum(labelAndPred, tf.constant(1)); + Operand assignAdd; + if (initializer != null) { + Ops tfc = + tf.withSubScope("weightedAssignAdd") + .withControlDependencies(Collections.singletonList(initializer)); + assignAdd = tfc.assignAdd(variable, valueSum); + } else { + assignAdd = tf.assignAdd(variable, valueSum); + } + return assignAdd; + } + + /** + * Filters top-k values in the last dim of x and set the rest to NEG_INF. + * + *

Used for computing top-k prediction values in dense labels (which has the same shape as + * predictions) for recall and precision top-k metrics. + * + * @param tf The TensorFlow Ops + * @param x the tensor with any dimensions to filter + * @param topK the number of values to keep. + * @param the data type for x and the return value. + * @return the topK prediction values. + */ + private static Operand filterTopK(Ops tf, Operand x, int topK) { + Class type = x.type(); + Shape xShape = x.shape(); + // top has the same rank as x; the last dimension becomes indices of the topK features. + TopK top = tf.nn.topK(x, tf.constant(topK), TopK.sorted(false)); + // oneHot has an additional dimension: the one-hot representation of each topK index. + OneHot oneHot = + tf.oneHot( + top.indices(), + cast(tf, tf.constant(xShape.size(xShape.numDimensions() - 1)), TInt32.class), + tf.constant(1), + tf.constant(0), + OneHot.axis(-1L)); + // Sum the one-hot representations along the last dimension of x. + Operand topKMask = cast(tf, tf.reduceSum(oneHot, tf.constant(-2)), type); + + // x * top_k_mask + NEG_INF * (1 - top_k_mask) + Operand add1 = tf.math.mul(x, topKMask); + Operand add2 = + tf.math.mul( + cast(tf, tf.constant(NEG_INF), type), + tf.math.sub(cast(tf, tf.constant(1), type), topKMask)); + return tf.math.add(add1, add2); + } + + // alias for mean + + /** + * Calculate the mean of the operand, along all axes and {@code keepDims} is {@code false + * } * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -230,8 +706,8 @@ public static Operand mean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is + * {@code false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -249,12 +725,12 @@ public static Operand mean( * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. * @param the type of the operand - * @return the mean of elements of x. + * @return the mean of elements of {@code x}. */ public static Operand mean(Ops tf, Operand x, boolean keepDims) { return mean(tf, x, null, keepDims); @@ -266,12 +742,12 @@ public static Operand mean(Ops tf, Operand x, boolean * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. * @param the data type of the Operand - * @return the mean of elements of x. 
+ * @return the mean of elements of {@code x}. */ public static Operand mean( Ops tf, Operand x, Operand axes, boolean keepDims) { @@ -281,9 +757,16 @@ public static Operand mean( return tf.math.mean(x, axes, Mean.keepDims(keepDims)); } + public static + LossTuple raggedAssertCompatibleAndGetFlatValues( + Ops tf, Operand labels, Operand predictions) { + // TODO handle ragged Tensors + Operand tLabels = cast(tf, labels, predictions.type()); + return new LossTuple<>(tLabels, predictions); + } + /** - * Calculate the mean of the operand, along all axes and keepDims is false - * + * Calculate the mean of the operand, along all axes and {@code keepDims} is false * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -294,8 +777,8 @@ public static Operand booleanMean(Ops tf, Operand x) { } /** - * Calculate the mean of the operand, alongside the specified axis with keepDims is - * false + * Calculate the mean of the operand, alongside the specified axis with {@code keepDims} is + * {@code false} * * @param tf the TensorFlow Ops * @param x the Operand used to calculate the mean @@ -312,11 +795,11 @@ public static Operand booleanMean( * * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. - * @return the mean of elements of x containing floating point numbers + * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean(Ops tf, Operand x, boolean keepDims) { return booleanMean(tf, x, null, keepDims); @@ -328,11 +811,11 @@ public static Operand booleanMean(Ops tf, Operand x, boolean ke * @param tf the TensorFlow Ops * @param x the boolean Operand used to calculate the mean * @param axes Axes to compute the mean. - * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is - * false, the rank of the tensor is reduced by 1 for each entry in axes - * . If keepdims is true, the reduced dimensions are retained + * @param keepDims Indicates whether to keep the dimensions or not. If {@code keepdims} is + * {@code false}, the rank of the tensor is reduced by 1 for each entry in {@code axes + * }. If {@code keepdims} is {@code true}, the reduced dimensions are retained * with length 1. 
- * @return the mean of elements of x containing floating point numbers + * @return the mean of elements of {@code x} containing floating point numbers */ public static Operand booleanMean( Ops tf, Operand x, Operand axes, boolean keepDims) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java index 6583465da2e..47d7f8ab737 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/WeightsBroadcastOps.java @@ -15,6 +15,7 @@ package org.tensorflow.framework.metrics.impl; import org.tensorflow.Operand; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; @@ -150,12 +151,13 @@ private static Operand hasValidNonscalarShape( private static Operand hasValidDims( Ops tf, Operand weightsShape, Operand valuesShape) { tf = tf.withSubScope("hasInvalidDims"); + FrameworkOps fops = FrameworkOps.create(tf); Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); Operand validDims = tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); Operand weightsShape2d = tf.expandDims(weightsShape, tf.constant(-1)); - Operand invalidDims = SetsOps.difference(tf, weightsShape2d, validDims); + Operand invalidDims = fops.sets.difference(weightsShape2d, validDims); Operand numInvalidDims = tf.size(invalidDims, TInt32.class); return tf.math.equal(tf.constant(0), numInvalidDims); } From f1c63c049dcad74d7926f333f381f98a47416276 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 2 May 2021 18:54:36 -0400 Subject: [PATCH 19/31] move nn.raw classes to nn in core, remove nn.raw --- .../src/gen/resources/ops.pb | Bin 1462296 -> 1462288 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb b/tensorflow-core/tensorflow-core-api/src/gen/resources/ops.pb index 5472f5f883985bf6f4f266c9cc29668e59833462..fbcecceb5bd35e3681296ef57e58ae324733949a 100644 GIT binary patch delta 142 zcmbQSAacTj$c7fi7N!>F7M2#)7Pc1l7LFFq7OocV7M>Q~7QPn#7J)5-@?wmA(-p)7 zbD0*fOji^Wl;;#+lv3hS$jj54K2c0icKX?3A&Yhaal!2Z;zH-68Rt!ZP$V>sshw^5 n!y+M7exPCvE;XRy;DW@W;#9DfcAyp@76xLG?PrTcz3%`3&nztB delta 138 zcmbQRAacfn$c7fi7N!>F7M2#)7Pc1l7LFFq7OocV7M>Q~7QPn#7J)5-@?wmWrYndE z<}xi|nXV`%D9LKf{j;)2_G#D&gBGcKBbr$}fT kQ#af6yG25(JU|UPTv`wf?F);AfLIuaMYb<27R|o{05Q}nM*si- From 043654b8448fe97a626cbddf678c7befbba3756d Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sun, 2 May 2021 19:13:00 -0400 Subject: [PATCH 20/31] Update FrameworkOps.java --- .../src/main/java/org/tensorflow/framework/op/FrameworkOps.java | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java index d9e3eec4b21..f182d9d7b80 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/FrameworkOps.java @@ -62,6 +62,7 @@ private FrameworkOps(Ops coreOps) { sets = new SetOps(this); math = new MathOps(this); linalg = new LinalgOps(this); + } /** From 06c28df060a961f83a1ace310c5831010d5cd918 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 3 May 2021 10:05:28 
-0400 Subject: [PATCH 21/31] Fix unusual regression error in confusion matrix. Needed to reduceAll on the AssertThats. This change is unrelated to this PR, but the bug showed up here. --- .../org/tensorflow/framework/op/MathOps.java | 300 +++++++++--------- 1 file changed, 151 insertions(+), 149 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java index 4c2210feb9c..8fda58806ca 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/MathOps.java @@ -56,11 +56,13 @@ import org.tensorflow.op.math.Square; import org.tensorflow.op.math.Sub; import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat16; import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; import java.util.ArrayList; import java.util.Arrays; @@ -110,27 +112,27 @@ public Operand l2Normalize(Operand x, int[] axis) { * Computes the confusion matrix from predictions and labels. * *

The matrix columns represent the prediction labels and the rows represent the real labels. - * The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid + * The confusion matrix is always a 2-D array of shape {@code [n,n]}, where {@code n} is the number of valid * labels for a given classification task. Both prediction and labels must be 1-D arrays of the * same shape in order for this function to work. * - *

If `num_classes` is `None`, then `num_classes` will be set to one plus the maximum value in + *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum value in * either predictions or labels. Class labels are expected to start at 0. For example, if - * `num_classes` is 3, then the possible labels would be `[0, 1, 2]`. + * {@code numClasses} is 3, then the possible labels would be {@code [0, 1, 2]}. * - *

If `weights` is not `None`, then each prediction contributes its corresponding weight to the + *

If {@code weights} is not null, then each prediction contributes its corresponding weight to the * total value of the confusion matrix cell. * *

For example: *
-   * <pre>
+   * <pre>{@code
    *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
    *         [[0 0 0 0 0]
    *          [0 0 1 0 0]
    *          [0 0 1 0 0]
    *          [0 0 0 0 0]
    *          [0 0 0 0 1]]
-   * </pre>
+   * }</pre>
* *

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 * confusion matrix. @@ -152,27 +154,27 @@ public Operand confusionMatrix(Operand labels, Operand * Computes the confusion matrix from predictions and labels. * *

The matrix columns represent the prediction labels and the rows represent the real labels. - * The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid + * The confusion matrix is always a 2-D array of shape {@code [n,n]}, where {@code n} is the number of valid * labels for a given classification task. Both prediction and labels must be 1-D arrays of the * same shape in order for this function to work. * - *

If `num_classes` is `None`, then `num_classes` will be set to one plus the maximum value in + *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum value in * either predictions or labels. Class labels are expected to start at 0. For example, if - * `num_classes` is 3, then the possible labels would be `[0, 1, 2]`. + * {@code numClasses} is 3, then the possible labels would be {@code [0, 1, 2]}. * - *

If `weights` is not `None`, then each prediction contributes its corresponding weight to the + *

If {@code weights} is not null, then each prediction contributes its corresponding weight to the * total value of the confusion matrix cell. * *

For example: *
-   * <pre>
+   * <pre>{@code
    *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
    *         [[0 0 0 0 0]
    *          [0 0 1 0 0]
    *          [0 0 1 0 0]
    *          [0 0 0 0 0]
    *          [0 0 0 0 1]]
-   * </pre>
+   * }</pre>
* *

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 * confusion matrix. @@ -196,27 +198,27 @@ public Operand confusionMatrix( * Computes the confusion matrix from predictions and labels. * *

The matrix columns represent the prediction labels and the rows represent the real labels. - * The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid + * The confusion matrix is always a 2-D array of shape {@code [n,n]}, where {@code n} is the number of valid * labels for a given classification task. Both prediction and labels must be 1-D arrays of the * same shape in order for this function to work. * - *

If `num_classes` is `None`, then `num_classes` will be set to one plus the maximum value in + *

If {@code numClasses} is null, then {@code numClasses} will be set to one plus the maximum value in * either predictions or labels. Class labels are expected to start at 0. For example, if - * `num_classes` is 3, then the possible labels would be `[0, 1, 2]`. + * {@code numClasses} is 3, then the possible labels would be {@code [0, 1, 2]}. * - *

If `weights` is not `None`, then each prediction contributes its corresponding weight to the + *

If {@code weights} is not null, then each prediction contributes its corresponding weight to the * total value of the confusion matrix cell. * *

For example: *
-   * <pre>
+   * <pre>{@code
    *     fops.math.confusionMatrix(tf.constant(new int[] {1, 2, 4}), tf.constant(new int[] {2, 2, 4})) ==>
    *         [[0 0 0 0 0]
    *          [0 0 1 0 0]
    *          [0 0 1 0 0]
    *          [0 0 0 0 0]
    *          [0 0 0 0 1]]
-   * </pre>
+   * }</pre>
* *
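As a hedged illustration of the weighted behavior described in this javadoc (the weighted overload and the exact operand types are assumed from the signatures in this hunk; the values are made up):

    Ops tf = Ops.create();
    FrameworkOps fops = FrameworkOps.create(tf);
    Operand<TInt64> labels = tf.constant(new long[] {1, 2, 4});
    Operand<TInt64> predictions = tf.constant(new long[] {2, 2, 4});
    Operand<TInt64> weights = tf.constant(new long[] {3, 1, 2});
    // Each prediction now contributes its weight instead of 1:
    // cell [1,2] gets 3, cell [2,2] gets 1, and cell [4,4] gets 2.
    Operand<TInt64> cm = fops.math.confusionMatrix(labels, predictions, weights);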

Note that the possible labels are assumed to be {@code [0, 1, 2, 3, 4]}, resulting in a 5x5 * confusion matrix. @@ -277,19 +279,21 @@ public Operand confusionMatrix( one); } else { lNumClasses = Cast.create(lScope, numClasses, TInt64.class); + Operand less = Less.create(lScope, lLabels, lNumClasses); AssertThat labelsLess = AssertThat.create( lScope, - Less.create(lScope, lLabels, lNumClasses), + ReduceAll.create(scope, less, allAxes(less), ReduceAll.keepDims(false)), Collections.singletonList(Constant.scalarOf(lScope, "labels out of bounds"))); lLabels = Identity.create( lScope.withControlDependencies(Collections.singletonList(labelsLess)), lLabels); + less = Less.create(lScope, lPredictions, lNumClasses); AssertThat predictionsLess = AssertThat.create( lScope, - Less.create(lScope, lPredictions, lNumClasses), + ReduceAll.create(scope, less, allAxes(less), ReduceAll.keepDims(false)), Collections.singletonList(Constant.scalarOf(lScope, "predictions out of bounds"))); lPredictions = Identity.create( @@ -319,12 +323,12 @@ public Operand confusionMatrix( /** * Squeeze last dim if ranks differ from expected by exactly 1. * - * @param labels Label values, a Operand whose dimensions match predictions - * . - * @param predictions Predicted values, a Tensor of arbitrary dimensions. - * @param expectedRankDiff Expected result of rank(predictions) - rank(labels). + * @param labels Label values, a {@code Operand} whose dimensions match {@code predictions + * }. + * @param predictions Predicted values, a {@code Tensor} of arbitrary dimensions. + * @param expectedRankDiff Expected result of {@code rank(predictions) - rank(labels)}. * @param the data type for the labels, predictions and result - * @return labels and predictions, possibly with last dim squeezed. + * @return {@code labels} and {@code predictions}, possibly with last dim squeezed. */ public LossTuple removeSqueezableDimensions( Operand labels, Operand predictions, int expectedRankDiff) { @@ -372,10 +376,9 @@ public LossTuple removeSqueezableDimensions( * Creates an Operand that has all axes contained in the Operand's shape. * * @param op the Operand - * @param THe Data type for the Operand * @return an Operand that has all axes contained in the Operand's shape.. */ - public Operand allAxes(Operand op) { + public Operand allAxes(Operand op) { int rank = op.shape().numDimensions(); if (rank != Shape.UNKNOWN_SIZE) { int[] axes = new int[rank]; @@ -392,18 +395,18 @@ public Operand allAxes(Operand op) { /** * Transpose and reshape the input for contraction op. * - *

This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul` using - * `array_ops.transpose` and `array_ops.reshape`. The method takes a tensor and performs the + *

This method is helpful in reducing {@code math.tensordot} to {@code math_ops.matmul} using + * {@code array_ops.transpose} and {@code array_ops.reshape}. The method takes a tensor and performs the * correct transpose and reshape operation for a given set of indices. It returns the reshaped * tensor as well as a list of indices necessary to reshape the tensor again after matrix * multiplication. * * @param the type of Operand * @param a the Tensor - * @param axis unique indices specifying valid axes of `a`. + * @param axis unique indices specifying valid axes of {@code a}. * @param flipped whether to flip the dimensions or not * @return A tuple (reshapedA, freeDims, freeDimsStatic) where reshapedA is a reshaped to allow - * contraction via matmul, freeDims` is a TInt32 Operand, depending on whether the shape of a + * contraction via matmul, freeDims is a TInt32 Operand, depending on whether the shape of a * is fully specified, and freeDimsStatic is either a list of integers and null values, or * None, representing the inferred shape of the free dimensions */ @@ -703,50 +706,48 @@ private Operand[] tensordotAxes(Operand a, Operan * Tensor contraction of a and b along specified axes and outer product. *

* Tensordot (also known as tensor contraction) sums the product of elements - * from a and b` over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length * and consist of unique integers that specify valid axes for each of the * tensors. Additionally outer product is supported by passing - * axes=0. + * {@code axes=0}. *

- * This operation corresponds to numpy.tensordot(a, b, axes). + * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. *

- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. + * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. *

- * Example 2: When a and`b are matrices (order 2), + * Example 2: When {@code a} and {@code b} are matrices (order 2), * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order + * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order * 4. *

* Example 4: Suppose that aijk and blmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor * cjklm whose entry corresponding to the indices * (j,k,l,m) is given by: - *

* cjklm = Σi aijk * blmi . *
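A minimal usage sketch for this contraction (assuming the fops.math.tensordot endpoint generated from this class; with a scalar axis of 1 the result matches plain matrix multiplication):

    Ops tf = Ops.create();
    FrameworkOps fops = FrameworkOps.create(tf);
    Operand<TFloat32> a = tf.constant(new float[][] {{1f, 2f}, {3f, 4f}});
    Operand<TFloat32> b = tf.constant(new float[][] {{5f, 6f}, {7f, 8f}});
    // Contracts the last axis of a with the first axis of b,
    // giving [[19, 22], [43, 50]] -- the same result as matmul(a, b).
    Operand<TFloat32> c = fops.math.tensordot(a, b, 1);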

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). - *

+ * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. * - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. * @param axis sum over the last N axes of a and the - * first N axes of b in order. If `axes=0`, computes the outer - * product between `a` and `b`. + * first N axes of b in order. If {@code axis=0}, computes the outer + * product between {@code a} and {@code b}. * @param the datatype of the Operands, must be either TFloat32 or * TFloat64 - * @return A `Operand` with the same type as `a`. + * @return A {@code Operand} with the same type as {@code a}. * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type */ @Endpoint(name = "tensordot") @@ -762,53 +763,53 @@ public Operand tensordot(Operand a, Operand b, in * Tensor contraction of a and b along specified axes and outer product. *

* Tensordot (also known as tensor contraction) sums the product of elements - * from a and b` over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length * and consist of unique integers that specify valid axes for each of the * tensors. Additionally outer product is supported by passing - * axes=0. + * {@code axes=0}. *

- * This operation corresponds to numpy.tensordot(a, b, axes). + * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. *

- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. + * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. *

- * Example 2: When a and`b are matrices (order 2), + * Example 2: When {@code a} and {@code b} are matrices (order 2), * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order + * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order * 4. *

* Example 4: Suppose that aijk and blmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor * cjklm whose entry corresponding to the indices * (j,k,l,m) is given by: *

* cjklm = Σi aijk * blmi . *

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. *

* - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. * @param axes If axes is a scalar, sum over the last N axes of a and the * first N axes of b in order. If axes is a list, the first and second row * contain the set of unique integers specifying axes along which the - * contraction is computed, for `a` and `b`, respectively. The number of - * axes for `a` and `b` must be equal. If `axes=0`, computes the outer - * product between `a` and `b`. + * contraction is computed, for {@code a} and {@code b}, respectively. The number of + * axes for {@code a} and {@code b} must be equal. If {@code axis=0}, computes the outer + * product between {@code a} and {@code b}. * @param the datatype of the Operands, must be either TFloat32 or * TFloat64 - * @return A `Operand` with the same type as `a`. + * @return A {@code Operand} with the same type as {@code a}. * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type */ @Endpoint(name = "tensordot") @@ -826,51 +827,51 @@ public Operand tensordot( * Tensor contraction of a and b along specified axes and outer product. *

* Tensordot (also known as tensor contraction) sums the product of elements - * from a and b` over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length * and consist of unique integers that specify valid axes for each of the * tensors. Additionally outer product is supported by passing - * axes=0. + * {@code axes=0}. *

- * This operation corresponds to numpy.tensordot(a, b, axes). + * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. *

- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. + * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. *

- * Example 2: When a and`b are matrices (order 2), + * Example 2: When {@code a} and {@code b} are matrices (order 2), * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order + * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order * 4. *

* Example 4: Suppose that aijk and blmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor * cjklm whose entry corresponding to the indices * (j,k,l,m) is given by: *

* cjklm = Σi aijk * blmi . *

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. *

* - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. * @param axes the first and second row * contain the set of unique integers specifying axes along which the - * contraction is computed, for `a` and `b`, respectively. The number of - * axes for `a` and `b` must be equal. I + * contraction is computed, for {@code a} and {@code b}, respectively. The number of + * axes for {@code a} and {@code b} must be equal. I * @param the datatype of the Operands, must be either TFloat32 or * TFloat64 - * @return A `Operand` with the same type as `a`. + * @return A {@code Operand} with the same type as {@code a}. * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type */ @Endpoint(name = "tensordot") @@ -887,51 +888,51 @@ public Operand tensordot(Operand a, Operand b, in * Tensor contraction of a and b along specified axes and outer product. *

* Tensordot (also known as tensor contraction) sums the product of elements - * from a and b` over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length * and consist of unique integers that specify valid axes for each of the * tensors. Additionally outer product is supported by passing - * axes=0. + * {@code axes=0}. *

- * This operation corresponds to numpy.tensordot(a, b, axes). + * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. *

- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. + * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. *

- * Example 2: When a and`b are matrices (order 2), + * Example 2: When {@code a} and {@code b} are matrices (order 2), * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order + * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order * 4. *

* Example 4: Suppose that aijk and blmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor * cjklm whose entry corresponding to the indices * (j,k,l,m) is given by: *

* cjklm = Σi aijk * blmi . *

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. *

* - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. * @param axes the first and second row * contain the set of unique integers specifying axes along which the - * contraction is computed, for `a` and `b`, respectively. The number of - * axes for `a` and `b` must be equal. I + * contraction is computed, for {@code a} and {@code b}, respectively. The number of + * axes for {@code a} and {@code b} must be equal. I * @param the datatype of the Operands, must be either TFloat32 or * TFloat64 - * @return A `Operand` with the same type as `a`. + * @return A {@code Operand} with the same type as {@code a}. * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type */ @Endpoint(name = "tensordot") @@ -948,49 +949,49 @@ public Operand tensordot(Operand a, Operand b, in * Tensor contraction of a and b along specified axes and outer product. *

* Tensordot (also known as tensor contraction) sums the product of elements - * from a and b` over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length + * from {@code a} and {@code b} over the indices specified by + * {@code a_axes} and {@code b_axes}. The lists + * {@code a_axes} and {@code b_axes} specify those pairs of axes + * along which to contract the tensors. The axis {@code a_axes[i]} of + * {@code a} must have the same dimension as axis + * {@code b_axes[i]} of {@code b} for all {@code i} in + * {@code range(0, len(a_axes))}. The lists + * {@code a_axes} and {@code b_axes} must have identical length * and consist of unique integers that specify valid axes for each of the * tensors. Additionally outer product is supported by passing - * axes=0. + * {@code axes=0}. *

- * This operation corresponds to numpy.tensordot(a, b, axes). + * This operation corresponds to {@code numpy.tensordot(a, b, axes)}. *

- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. + * Example 1: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes = 1} is equivalent to matrix multiplication. *

- * Example 2: When a and`b are matrices (order 2), + * Example 2: When {@code a} and {@code b} are matrices (order 2), * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. + * {@code axes = [[1], [0]]} is equivalent to matrix multiplication. *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order + * Example 3: When {@code a} and {@code b} are matrices (order 2), + * the case {@code axes=0} gives the outer product, a tensor of order * 4. *

* Example 4: Suppose that aijk and blmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * represent two tensors of order 3. Then, {@code contract(a, b, [[0], [2]])} is the order 4 tensor * cjklm whose entry corresponding to the indices * (j,k,l,m) is given by: *

* cjklm = Σi aijk * blmi . *

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + * In general, {@code order(c) = order(a) + order(b) - 2*len(axes[0])}. *

* - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. + * @param a {@code Operand} of type {@code TFloat32} or {@code TFloat64}. + * @param b {@code Operand} with the same type as {@code a}. * @param aAxis axes for the a Operand * @param bAxis axes for the b Operand * @param the datatype of the Operands, must be either TFloat32 or * TFloat64 - * @return A `Operand` with the same type as `a`. + * @return A {@code Operand} with the same type as {@code a}. * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type */ @SuppressWarnings({"unchecked", "unused"}) @@ -1042,7 +1043,7 @@ public Operand tensordot( * Computes log(sum(exp(elements across dimensions of a tensor))). Reduces {@code input_tensor} * along the dimensions given in {@code axes}. * - *

Reduces `{@code input} along the dimensions given in {@code axes}. Unless {@code keepdims} + *

Reduces {@code input} along the dimensions given in {@code axes}. Unless {@code keepdims} * is true, the rank of the tensor is reduced by 1 for each of the entries in {@code axes}, which * must be unique. If {@code keepdims} is true, the reduced dimensions are retained with length 1. * If {@code axes} has no entries, all dimensions are reduced, and a tensor with a single element @@ -1052,8 +1053,9 @@ public Operand tensordot( * * @param input The tensor to reduce. * @param axes The dimensions to reduce. If null, reduces all dimensions. Must be in the range - * {@link [-rank(input_tensor), rank(input_tensor)]}. + * {@code [-rank(input_tensor), rank(input_tensor)]}. * @param keepDims If true, retains reduced dimensions with length 1. + * @param the data type for the input and the result * @return The reduced tensor. */ @Endpoint(name = "reduceLogSumExp") From 8f33d21c2a79fa554138cface5d771905ff597e8 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 3 May 2021 10:05:53 -0400 Subject: [PATCH 22/31] javadoc fixes --- .../tensorflow/framework/op/LinalgOps.java | 94 +++++++++---------- 1 file changed, 46 insertions(+), 48 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java index eb069a2db22..931f7f851c2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/LinalgOps.java @@ -45,15 +45,15 @@ public class LinalgOps { } /** - * Multiplies matrix a by matrix b, producing a * b - * . + * Multiplies matrix {@code a} by matrix {@code b}, producing {@code a} * {@code b + * }. * - *

The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2 + *

The inputs must, following any transpositions, be tensors of {@code rank >= 2} where the inner 2 * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions * specify matching batch size. * - *

Both matrices must be of the same type. The supported types are: TFloat16, - * TFloat32, TFloat64, TInt32. + *

Both matrices must be of the same type. The supported types are: {@code TFloat16}, + * {@code TFloat32}, {@code TFloat64}, {@code TInt32}. * *

Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by * setting one of the corresponding flag to true. These are false by default. @@ -80,15 +80,15 @@ public class LinalgOps { * *
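A minimal sketch of the plain (no transpose, no adjoint) form documented here, assuming the fops.linalg.matmul endpoint that this class generates:

    Ops tf = Ops.create();
    FrameworkOps fops = FrameworkOps.create(tf);
    Operand<TFloat32> a = tf.constant(new float[][] {{1f, 2f}, {3f, 4f}});
    Operand<TFloat32> b = tf.constant(new float[][] {{5f, 6f}, {7f, 8f}});
    // Inner-most matrix product: [[1*5+2*7, 1*6+2*8], [3*5+4*7, 3*6+4*8]] = [[19, 22], [43, 50]]
    Operand<TFloat32> product = fops.linalg.matmul(a, b);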

Note: This is matrix product, not element-wise product. * - * @param a an Operand of of type TFloat16, TFloat32, TFloat64 - * , TInt32. with a rank > 1 - * @param b an Operand with same type and rank as a. + * @param a an Operand of type {@code TFloat16}, {@code TFloat32}, {@code TFloat64 + * }, {@code TInt32}. with a {@code rank > 1} + * @param b an Operand with same type and rank as {@code a}. * @param the data type of the Operands - * @return A Operand of the same type as a and b where each inner-most - * matrix is the product of the corresponding matrices in a and b. + * @return An Operand of the same type as {@code a} and {@code b} where each inner-most + * matrix is the product of the corresponding matrices in {@code a} and {@code b}. * This is the matrix product not an element-wise product. - * @throws java.lang.IllegalArgumentException If transposeA and adjointA - * , or transposeB and adjointB are both set to `true`. + * @throws java.lang.IllegalArgumentException If {@code transposeA} and {@code adjointA} + * , or {@code transposeB} and {@code adjointB} are both set to {@code true}. */ @Endpoint(name = "matmul") public Operand matmul(Operand a, Operand b) { @@ -96,21 +96,19 @@ public Operand matmul(Operand a, Operand b) { /** - * Multiplies matrix a by matrix b, producing a * b - * . + * Multiplies matrix {@code a} by matrix {@code b}, producing {@code a} * {@code b + * }. * - *

The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2 + *

The inputs must, following any transpositions, be tensors of {@code rank >= 2} where the inner 2 * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions * specify matching batch size. * - *

Both matrices must be of the same type. The supported types are: TFloat16, - * TFloat32, TFloat64, TInt32. + *

Both matrices must be of the same type. The supported types are: {@code TFloat16}, + * {@code TFloat32}, {@code TFloat64}, {@code TInt32}. * *

Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by * setting one of the corresponding flag to true. These are false by default. * - *

- * *

Note: This is matrix product, not element-wise product. * *

A simple 2-D tensor matrix multiplication: @@ -133,17 +131,17 @@ public Operand matmul(Operand a, Operand b) { * * }

* - * @param a an Operand of of type TFloat16, TFloat32, TFloat64 - * , TInt32. with a rank > 1 - * @param b an Operand with same type and rank as a. - * @param transposeA If `true`, a is transposed before multiplication. - * @param transposeB If `True`, b is transposed before multiplication + * @param a an Operand of of type {@code TFloat16}, {@code TFloat32}, {@code TFloat64 + * }, {@code TInt32}. with a {@code rank > 1} + * @param b an Operand with same type and rank as {@code a}. + * @param transposeA If true, {@code a} is transposed before multiplication. + * @param transposeB If true, {@code b} is transposed before multiplication * @param the data type of the Operands - * @return A Operand of the same type as a and b where each inner-most - * matrix is the product of the corresponding matrices in a and b. + * @return A Operand of the same type as {@code a} and {@code b} where each inner-most + * matrix is the product of the corresponding matrices in {@code a} and {@code b}. * This is the matrix product not an element-wise product. - * @throws java.lang.IllegalArgumentException If transposeA and adjointA - * , or transposeB and adjointB are both set to `true`. + * @throws java.lang.IllegalArgumentException If {@code transposeA} and {@code adjointA} + * , or {@code transposeB} and {@code adjointB} are both set to `true`. */ @Endpoint(name = "matmul") public Operand matmul( @@ -152,15 +150,15 @@ public Operand matmul( } /** - * Multiplies matrix a by matrix b, producing a * b - * . + * Multiplies matrix {@code a} by matrix {@code b}, producing {@code a} * {@code b + * }. * - *

The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2 + *

The inputs must, following any transpositions, be tensors of {@code rank >= 2} where the inner 2 * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions * specify matching batch size. * - *

Both matrices must be of the same type. The supported types are: TFloat16, - * TFloat32, TFloat64, TInt32. + *

Both matrices must be of the same type. The supported types are: {@code TFloat16}, + * {@code TFloat32}, {@code TFloat64}, {@code TInt32}. * *

Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by * setting one of the corresponding flag to true. These are false by default. @@ -187,25 +185,25 @@ public Operand matmul( * * } * - * @param a an Operand of of type TFloat16, TFloat32, TFloat64 - * , TInt32. with a rank > 1 - * @param b an Operand with same type and rank as a. - * @param transposeA If true, a is transposed before multiplication. - * @param transposeB If True, b is transposed before multiplication - * @param adjointA If true, a is conjugated and transposed before multiplication. - * @param adjointB If true, b is conjugated and transposed before multiplication. - * @param aIsSparse If true, a is treated as a sparse matrix. Notice, this does + * @param a an Operand of of type {@code TFloat16}, {@code TFloat32}, {@code TFloat64 + * }, {@code TInt32}. with a {@code rank > 1} + * @param b an Operand with same type and rank as {@code a}. + * @param transposeA If true, {@code a} is transposed before multiplication. + * @param transposeB If True, {@code b} is transposed before multiplication + * @param adjointA If true, {@code a} is conjugated and transposed before multiplication. + * @param adjointB If true, {@code b} is conjugated and transposed before multiplication. + * @param aIsSparse If true, {@code a} is treated as a sparse matrix. Notice, this does * not support {@link SparseTensor}, it just makes optimizations that assume most values - * in a are zero. - * @param bIsSparse If true, b is treated as a sparse matrix. Notice, this does + * in {@code a} are zero. + * @param bIsSparse If true, {@code b} is treated as a sparse matrix. Notice, this does * not support {@link SparseTensor}, it just makes optimizations that assume most values - * in b are zero. + * in {@code b} are zero. * @param the data type of the Operands - * @return A Operand of the same type as a and b where each inner-most - * matrix is the product of the corresponding matrices in a and b. + * @return A Operand of the same type as {@code a} and {@code b} where each inner-most + * matrix is the product of the corresponding matrices in {@code a} and {@code b}. * This is the matrix product not an element-wise product. - * @throws java.lang.IllegalArgumentException If transposeA and adjointA - * , or transposeB and adjointB are both set to `true`. + * @throws java.lang.IllegalArgumentException If {@code transposeA} and {@code adjointA} + * , or {@code transposeB} and {@code adjointB} are both set to `true`. 
*/ @SuppressWarnings("unchecked") @Endpoint(name = "matmul") From 198ea27f4e0a6e1c5b37f68a5c275adf4f4b5e6c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 26 Apr 2021 16:44:33 -0400 Subject: [PATCH 23/31] Initial checkin --- .../framework/layers/Activation.java | 122 +++ .../org/tensorflow/framework/layers/Add.java | 91 ++ .../framework/layers/AlphaDropout.java | 175 ++++ .../tensorflow/framework/layers/Average.java | 90 ++ .../framework/layers/Concatenate.java | 388 ++++++++ .../tensorflow/framework/layers/Dense.java | 360 +++++++ .../org/tensorflow/framework/layers/Dot.java | 555 +++++++++++ .../tensorflow/framework/layers/Dropout.java | 256 +++++ .../org/tensorflow/framework/layers/ELU.java | 109 ++ .../tensorflow/framework/layers/Flatten.java | 197 ++++ .../framework/layers/GaussianDropout.java | 173 ++++ .../framework/layers/GaussianNoise.java | 170 ++++ .../tensorflow/framework/layers/Input.java | 345 +++++++ .../tensorflow/framework/layers/Lambda.java | 203 ++++ .../tensorflow/framework/layers/Layer.java | 940 ++++++++++++++++++ .../framework/layers/LeakyReLU.java | 108 ++ .../tensorflow/framework/layers/Maximum.java | 90 ++ .../tensorflow/framework/layers/Minimum.java | 92 ++ .../tensorflow/framework/layers/Multiply.java | 90 ++ .../org/tensorflow/framework/layers/ReLU.java | 233 +++++ .../framework/layers/RepeatVector.java | 123 +++ .../tensorflow/framework/layers/Reshape.java | 110 ++ .../tensorflow/framework/layers/Softmax.java | 118 +++ .../tensorflow/framework/layers/Subtract.java | 104 ++ .../framework/layers/ThresholdedReLU.java | 116 +++ .../framework/layers/impl/InputSpec.java | 473 +++++++++ .../framework/layers/impl/Merge.java | 382 +++++++ .../framework/layers/impl/TensorFormat.java | 23 + .../framework/layers/impl/VariableDef.java | 213 ++++ .../framework/layers/ActivationTest.java | 48 + .../tensorflow/framework/layers/AddTest.java | 236 +++++ .../framework/layers/AlphaDropoutTest.java | 53 + .../framework/layers/AverageTest.java | 107 ++ .../framework/layers/ConcatenateTest.java | 209 ++++ .../framework/layers/DenseTest.java | 625 ++++++++++++ .../tensorflow/framework/layers/DotTest.java | 112 +++ .../framework/layers/DropoutTest.java | 89 ++ .../tensorflow/framework/layers/ELUTest.java | 119 +++ .../framework/layers/FlattenTest.java | 74 ++ .../framework/layers/GaussianDropoutTest.java | 63 ++ .../framework/layers/GaussianNoiseTest.java | 65 ++ .../framework/layers/InputTest.java | 99 ++ .../framework/layers/LambdaTest.java | 43 + .../framework/layers/LeakyReLUTest.java | 119 +++ .../framework/layers/MaximumTest.java | 106 ++ .../framework/layers/MinimumTest.java | 106 ++ .../framework/layers/MultiplyTest.java | 134 +++ .../tensorflow/framework/layers/ReLUTest.java | 128 +++ .../framework/layers/RepeatVectorTest.java | 47 + .../framework/layers/ReshapeTest.java | 121 +++ .../layers/SequentialLayersTest.java | 67 ++ .../framework/layers/SubtractTest.java | 186 ++++ .../framework/layers/ThresholdedReLUTest.java | 56 ++ .../framework/layers/impl/InputSpecTest.java | 72 ++ .../framework/layers/impl/TensorDotTest.java | 186 ++++ 55 files changed, 9719 insertions(+) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java create 
mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Reshape.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/TensorFormat.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ActivationTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AlphaDropoutTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DropoutTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/FlattenTest.java create mode 
100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SequentialLayersTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ThresholdedReLUTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/TensorDotTest.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java new file mode 100644 index 00000000000..5698e4766a2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java @@ -0,0 +1,122 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that applies an activation function to an output. + * + * @param the data type for the layer's weights and computation. + */ +public class Activation extends Layer { + private final org.tensorflow.framework.activations.Activation activation; + + /** + * Creates an Activation layer using {@link Class#getSimpleName()} for the name. + * + * @param tf the TensorFlow Ops. 
+ * @param activation the activation to apply + * @param type the data type for the weights and computation + */ + public Activation( + Ops tf, + org.tensorflow.framework.activations.Activation activation, + Class type) { + this(tf, null, activation, type, null); + } + + /** + * Creates an Activation layer using {@link Class#getSimpleName()} for the name. + * + * @param tf the TensorFlow Ops. + * @param activation the activation to apply + * @param type the data type for the weights and computation + * @param options the layer's options, may be null + */ + public Activation( + Ops tf, + org.tensorflow.framework.activations.Activation activation, + Class type, + Options options) { + this(tf, null, activation, type, options); + } + + /** + * Creates an Activation layer + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for + * the name. + * @param activation the activation to apply + * @param type the data type for the weights and computation + */ + public Activation( + Ops tf, + String name, + org.tensorflow.framework.activations.Activation activation, + Class type) { + this(tf, name, activation, type, null); + } + /** + * Creates an Activation layer + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for + * the name. + * @param activation the activation to apply + * @param type the data type for the weights and computation + * @param options the layer's options, may be null + */ + public Activation( + Ops tf, + String name, + org.tensorflow.framework.activations.Activation activation, + Class type, + Options options) { + super(tf, name, true, type, options); + this.activation = activation; + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + Ops tf = getTF(); + List> results = new ArrayList<>(); + inputs.forEach( + input -> results.add(cast(tf, activation.call(cast(tf, input, getType())), resultType))); + return callPostProcess(results, training); + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + return inputShapes; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java new file mode 100644 index 00000000000..5a7c0ce65e3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that adds a list of inputs element-wise. + * + *
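+ * <p>A minimal usage sketch (the operand values and the {@code Ops tf} instance are
+ * illustrative assumptions, and the masks argument of {@code call} is assumed to accept null):
+ *
+ * <pre>{@code
+ * Add<TFloat32> add = new Add<>(tf, TFloat32.class);
+ * Operand<TFloat32> a = tf.constant(new float[][] {{1, 2}, {3, 4}});
+ * Operand<TFloat32> b = tf.constant(new float[][] {{5, 6}, {7, 8}});
+ * List<Operand<? extends TType>> inputs = Arrays.asList(a, b);
+ * // element-wise sum, same shape as the inputs: {{6, 8}, {10, 12}}
+ * Operand<TFloat32> sum = add.call(inputs, null, false, TFloat32.class).get(0);
+ * }</pre>
+ *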
<p>
It takes as input a list of tensors, all of the same shape, and returns a single tensor (also + * of the same shape). + * + * @param the data type for the layer's weights and computation. + */ +public class Add extends Merge { + + + /** + * Creates an Add Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + */ + public Add(Ops tf, Class type) { + this(tf, null, type, null); + } + + /** + * Creates an Add Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + */ + public Add(Ops tf, Class type, Options options) { + this(tf, null, type, options); + } + + /** + * Creates an Add Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Add(Ops tf, String name, Class type) { + this(tf, name, type, null); + } + /** + * Creates an Add Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Add(Ops tf, String name, Class type, Options options) { + + super(tf, name, type, options); + } + + /** {@inheritDoc} */ + @Override + protected Operand mergeFunction(List> inputs) { + Ops tf = getTF(); + Operand output = cast(tf, tf.identity(inputs.get(0)), getType()); + for (int i = 1; i < inputs.size(); i++) { + output = tf.math.add(output, cast(tf, inputs.get(i), getType())); + } + return output; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java new file mode 100644 index 00000000000..3c3f723ecf7 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java @@ -0,0 +1,175 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Applies Alpha Dropout to the input. + * + *
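+ * <p>A minimal usage sketch (hypothetical values; assumes an {@code Ops tf} instance, that the
+ * options argument may be null, and that the masks argument of {@code call} may be null):
+ *
+ * <pre>{@code
+ * AlphaDropout<TFloat32> dropout = new AlphaDropout<>(tf, 0.2f, 1001L, TFloat32.class, null);
+ * Operand<TFloat32> x = tf.constant(new float[][] {{1, 2}, {3, 4}});
+ * List<Operand<? extends TType>> inputs = Collections.singletonList(x);
+ * // training = true applies the noise; training = false passes the input through unchanged
+ * Operand<TFloat32> result = dropout.call(inputs, null, true, TFloat32.class).get(0);
+ * }</pre>
+ *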
<p>
Alpha Dropout is a Dropout that keeps the mean and variance of its inputs at their
+ * original values, in order to ensure the self-normalizing property even after this dropout. Alpha
+ * Dropout fits well with Scaled Exponential Linear Units by randomly setting activations to the
+ * negative saturation value.
+ *
+ * @param <T> the data type for the layer's weights and computation.
+ */
+public class AlphaDropout extends Layer {
+ private static final long DEFAULT_GRAPH_SEED = 87654321;
+ private final float rate;
+ private final Shape noiseShape;
+ private final long seed;
+
+ /**
+ * Creates an AlphaDropout layer with no noiseShape, using a unique name generated from {@link
+ * Class#getSimpleName()}.
+ *
+ * @param tf the TensorFlow Ops, may be null but must be set before the {@link #call} method is
+ * called.
+ * @param rate A number between 0 and 1. Drop probability (as with {@link Dropout}). The
+ * multiplicative noise will have standard deviation sqrt(rate / (1 - rate)).
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the data type for the weights and computation
+ * @param options the layer's options.
+ */
+ public AlphaDropout(Ops tf, float rate, long seed, Class type, Options options) {
+
+ this(tf, null, rate, null, seed, type, options);
+ }
+
+ /**
+ * Creates an AlphaDropout layer, using a unique name generated from {@link
+ * Class#getSimpleName()}.
+ *
+ * @param tf the TensorFlow Ops, may be null but must be set before the {@link #call} method is
+ * called.
+ * @param rate A number between 0 and 1. Drop probability (as with {@link Dropout}). The
+ * multiplicative noise will have standard deviation sqrt(rate / (1 - rate)).
+ * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout mask
+ * that will be multiplied with the input. For instance, if your inputs have shape
+ * (batch_size, timesteps, features) and you want the dropout mask to be the same for all
+ * timesteps, you can use noise_shape=(batch_size, 1, features). May be null.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the data type for the weights and computation
+ * @param options the layer's options.
+ */
+ public AlphaDropout(
+ Ops tf, float rate, Shape noiseShape, long seed, Class type, Options options) {
+ this(tf, null, rate, noiseShape, seed, type, options);
+ }
+
+ /**
+ * Creates an AlphaDropout layer
+ *
+ * @param tf the TensorFlow Ops, may be null but must be set before the {@link #call} method is
+ * called.
+ * @param name the unique name for this layer. If null, a unique name will be generated based
+ * on {@link Class#getSimpleName()}.
+ * @param rate A number between 0 and 1. Drop probability (as with {@link Dropout}). The
+ * multiplicative noise will have standard deviation sqrt(rate / (1 - rate)).
+ * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout mask
+ * that will be multiplied with the input. For instance, if your inputs have shape
+ * (batch_size, timesteps, features) and you want the dropout mask to be the same for all
+ * timesteps, you can use noise_shape=(batch_size, 1, features). May be null.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the data type for the weights and computation
+ * @param options the layer's options.
+ */
+ public AlphaDropout(
+ Ops tf,
+ String name,
+ float rate,
+ Shape noiseShape,
+ long seed,
+ Class type,
+ Options options) {
+ super(tf, name, true, type, options);
+ if (rate < 0 || rate >= 1)
+ throw new IllegalArgumentException("The rate must be >= 0 and < 1.");
+ this.rate = rate;
+ this.noiseShape = noiseShape;
+ this.seed = seed;
+ setSupportsMasking(true);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public List> call(
+ List> inputs,
+ List> masks,
+ boolean training,
+ Class resultType) {
+ Ops tf = getTF();
+
+ // a rate of 0 means nothing is dropped, so the inputs pass through unchanged
+ if (!training || rate == 0) {
+ return convertList(inputs, resultType);
+ }
+
+ // training = true
+ List> outputs = new ArrayList<>();
+ Operand rateT = cast(tf, tf.constant(rate), getType());
+ Operand alpha = cast(tf, tf.constant(1.6732632423543772848170429916717), getType());
+ Operand scale = cast(tf, tf.constant(1.0507009873554804934193349852946), getType());
+ // alpha_p = -alpha * scale
+ Operand alpha_p = tf.math.mul(tf.math.neg(alpha), scale);
+ Operand one = cast(tf, tf.constant(1), getType());
+ Operand minusPoint5 = cast(tf, tf.constant(-0.5), getType());
+ // a = ((1 - rate) * (1 + rate * alpha_p**2))**-0.5
+ Operand a =
+ tf.math.pow(
+ tf.math.mul(
+ tf.math.sub(one, rateT),
+ tf.math.add(one, tf.math.mul(rateT, tf.math.mul(alpha_p, alpha_p)))),
+ minusPoint5);
+ // b = -a * alpha_p * rate
+ Operand b = tf.math.mul(tf.math.neg(a), tf.math.mul(alpha_p, rateT));
+
+ for (Operand input : inputs) {
+ Operand tInput = cast(tf, input, getType());
+ Operand noise =
+ noiseShape == null ? tf.shape(input, TInt64.class) : tf.constant(noiseShape);
+ Operand randomTensor =
+ tf.random.randomUniform(
+ noise, getType(), RandomUniform.seed(DEFAULT_GRAPH_SEED), RandomUniform.seed2(seed));
+ Operand keptIdx = cast(tf, tf.math.greaterEqual(randomTensor, rateT), getType());
+ Operand x =
+ tf.math.add(
+ tf.math.mul(tInput, keptIdx), tf.math.mul(alpha_p, tf.math.sub(one, keptIdx)));
+ // result = a*x + b
+ //noinspection SuspiciousNameCombination
+ Operand result = tf.math.add(tf.math.mul(a, x), b);
+ outputs.add(result);
+ }
+
+ return callPostProcess(convertTo(outputs, resultType), training);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public List computeOutputShape(List inputShapes) {
+ return inputShapes;
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java
new file mode 100644
index 00000000000..5dd116aa38e
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java
@@ -0,0 +1,90 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that averages a list of inputs element-wise. + * + *
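+ * <p>A minimal usage sketch (illustrative values; assumes an {@code Ops tf} instance and that
+ * the masks argument of {@code call} may be null):
+ *
+ * <pre>{@code
+ * Average<TFloat32> average = new Average<>(tf, TFloat32.class);
+ * Operand<TFloat32> a = tf.constant(new float[][] {{1, 2}, {3, 4}});
+ * Operand<TFloat32> b = tf.constant(new float[][] {{5, 6}, {7, 8}});
+ * List<Operand<? extends TType>> inputs = Arrays.asList(a, b);
+ * // element-wise mean of the inputs: {{3, 4}, {5, 6}}
+ * Operand<TFloat32> mean = average.call(inputs, null, false, TFloat32.class).get(0);
+ * }</pre>
+ *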
<p>
It takes as input a list of tensors, all of the same shape, and returns a single tensor (also
+ * of the same shape).
+ *
+ * @param <T> the data type for the layer's weights and computation.
+ */
+public class Average extends Merge {
+
+ /**
+ * Creates an Average layer using {@link Class#getSimpleName()} as the layer name.
+ *
+ * @param tf the TensorFlow Ops
+ * @param type the data type for the weights and computation
+ */
+ public Average(Ops tf, Class type) {
+ this(tf, null, type, null);
+ }
+
+ /**
+ * Creates an Average layer using {@link Class#getSimpleName()} as the layer name.
+ *
+ * @param tf the TensorFlow Ops
+ * @param type the data type for the weights and computation
+ * @param options the layer's options
+ */
+ public Average(Ops tf, Class type, Options options) {
+ this(tf, null, type, options);
+ }
+
+ /**
+ * Creates an Average layer
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+ * @param type the data type for the weights and computation
+ */
+ public Average(Ops tf, String name, Class type) {
+ this(tf, name, type, null);
+ }
+
+ /**
+ * Creates an Average layer
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+ * @param type the data type for the weights and computation
+ * @param options the layer's options
+ */
+ public Average(Ops tf, String name, Class type, Options options) {
+ super(tf, name, type, options);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected Operand mergeFunction(List> inputs) {
+ Ops tf = getTF();
+ Operand output = cast(tf, tf.identity(inputs.get(0)), getType());
+ for (int i = 1; i < inputs.size(); i++) {
+ output = tf.math.add(output, cast(tf, inputs.get(i), getType()));
+ }
+ return tf.math.div(output, cast(tf, tf.constant(inputs.size()), getType()));
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java
new file mode 100644
index 00000000000..3243f8f5f17
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java
@@ -0,0 +1,388 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.layers.impl.Merge;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Layer that concatenates a list of inputs element-wise.
+ *
+ *
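+ * <p>A minimal usage sketch (illustrative values; assumes an {@code Ops tf} instance and that
+ * the masks argument of {@code call} may be null):
+ *
+ * <pre>{@code
+ * Concatenate<TFloat32> concat = new Concatenate<>(tf, -1, TFloat32.class);
+ * Operand<TFloat32> a = tf.constant(new float[][] {{1, 2}, {3, 4}}); // shape (2, 2)
+ * Operand<TFloat32> b = tf.constant(new float[][] {{5, 6}, {7, 8}}); // shape (2, 2)
+ * List<Operand<? extends TType>> inputs = Arrays.asList(a, b);
+ * // concatenated along the last axis: {{1, 2, 5, 6}, {3, 4, 7, 8}}, shape (2, 4)
+ * Operand<TFloat32> joined = concat.call(inputs, null, false, TFloat32.class).get(0);
+ * }</pre>
+ *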
<p>
It takes as input a list of tensors, all of the same shape except for the concatenation axis, + * and returns a single tensor that is the concatenation of all inputs. + * + * @param the data type for the layer's weights and computation. + */ +public class Concatenate extends Merge { + public static final int DEFAULT_AXIS = -1; + private int axis; + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name , and using + * {@link #DEFAULT_AXIS} for the axis along which to concatenate. + * + * @param type the data type for the weights and computation + */ + public Concatenate(Class type) { + this(null, null, DEFAULT_AXIS, type, null); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name , and using + * {@link #DEFAULT_AXIS} for the axis along which to concatenate. + * + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(Class type, Options options) { + this(null, null, DEFAULT_AXIS, type, options); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + */ + public Concatenate(int axis, Class type) { + this(null, null, axis, type, null); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(int axis, Class type, Options options) { + this(null, null, axis, type, options); + } + + /** + * Creates a Concatenate Layer using {@link #DEFAULT_AXIS} for the axis along which to + * concatenate. + * + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Concatenate(String name, Class type) { + this(null, name, DEFAULT_AXIS, type, null); + } + + /** + * Creates a Concatenate Layer using {@link #DEFAULT_AXIS} for the axis along which to + * concatenate. + * + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(String name, Class type, Options options) { + this(null, name, DEFAULT_AXIS, type, options); + } + + /** + * Creates a Concatenate Layer + * + * @param axis Axis along which to concatenate. + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Concatenate(String name, int axis, Class type) { + this(null, name, axis, type, null); + } + + /** + * Creates a Concatenate Layer + * + * @param axis Axis along which to concatenate. + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(String name, int axis, Class type, Options options) { + this(null, name, axis, type, options); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name, and using + * {@link #DEFAULT_AXIS} for the axis along which to concatenate. 
+ * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + */ + public Concatenate(Ops tf, Class type) { + this(tf, null, DEFAULT_AXIS, type, null); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name, and using + * {@link #DEFAULT_AXIS} for the axis along which to concatenate. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(Ops tf, Class type, Options options) { + this(tf, null, DEFAULT_AXIS, type, options); + } + + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + */ + public Concatenate(Ops tf, int axis, Class type) { + this(tf, null, axis, type, null); + } + + /** + * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(Ops tf, int axis, Class type, Options options) { + this(tf, null, axis, type, options); + } + + /** + * Creates a Concatenate Layer using {@link #DEFAULT_AXIS} for the axis along which to + * concatenate. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Concatenate(Ops tf, String name, Class type) { + this(tf, name, DEFAULT_AXIS, type, null); + } + + /** + * Creates a Concatenate Layer using {@link #DEFAULT_AXIS} for the axis along which to + * concatenate. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Concatenate(Ops tf, String name, Class type, Options options) { + this(tf, name, DEFAULT_AXIS, type, options); + } + + /** + * Creates a Concatenate Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axis Axis along which to concatenate. + * @param type the data type for the weights and computation + */ + public Concatenate(Ops tf, String name, int axis, Class type) { + this(tf, name, axis, type, null); + } + /** + * Creates a Concatenate Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axis Axis along which to concatenate. 
+ * @param type the data type for the weights and computation
+ * @param options the layer's options
+ */
+ public Concatenate(Ops tf, String name, int axis, Class type, Options options) {
+ super(tf, name, type, options);
+ this.axis = axis;
+ this.setSupportsMasking(true);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public List> computeMask(
+ List> inputs, List> masks) {
+ if (masks == null || masks.isEmpty()) {
+ return null;
+ }
+ if (inputs.size() != masks.size()) {
+ throw new IllegalArgumentException("The lists inputs and masks should have the same length.");
+ }
+ boolean allNull = true;
+ for (Operand m : masks) {
+ if (m != null) {
+ allNull = false;
+ break;
+ }
+ }
+ if (allNull) {
+ return null;
+ }
+
+ final Ops tf = getTF();
+
+ // null masks must survive the cast so they can be replaced with ones below
+ List> rMasks =
+ masks.stream()
+ .map(m -> m == null ? null : cast(tf, m, TBool.class))
+ .collect(Collectors.toList());
+
+ List> newMasks = new ArrayList<>();
+ for (int i = 0; i < inputs.size(); i++) {
+ Operand input = inputs.get(i);
+ Operand mask = rMasks.get(i);
+ if (mask == null) {
+ newMasks.add(cast(tf, tf.onesLike(input), TBool.class));
+ } else if (mask.rank() < input.rank()) {
+ newMasks.add(tf.expandDims(mask, tf.constant(-1)));
+ } else {
+ newMasks.add(mask);
+ }
+ }
+ Operand concat = tf.concat(newMasks, tf.constant(axis));
+ return Collections.singletonList(tf.reduceAll(concat, tf.constant(-1)));
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void build(List inputShapes) {
+
+ // Used purely for shape validation.
+ if (inputShapes.size() < 2) {
+ throw new IllegalArgumentException("A Concatenate layer must have at least 2 inputs.");
+ }
+ boolean allShapesUnknown = true;
+ for (Shape shape : inputShapes) {
+ if (!shape.isUnknown()) {
+ allShapesUnknown = false;
+ break;
+ }
+ }
+ if (allShapesUnknown) {
+ this.setBuilt(true);
+ return;
+ }
+ Integer rank = null;
+ long[][] shapesArray = new long[inputShapes.size()][];
+ for (int i = 0; i < inputShapes.size(); i++) {
+
+ Shape shape = inputShapes.get(i);
+ // resolve a negative axis against this shape's rank before skipping it
+ int nAxis = axis < 0 ? Math.floorMod(axis, shape.numDimensions()) : axis;
+ long[] dims = new long[shape.numDimensions() - 1];
+ for (int j = 0, k = 0; j < dims.length; k++) {
+ if (k == nAxis) continue;
+ dims[j++] = shape.size(k);
+ }
+
+ if (rank == null || rank == Shape.UNKNOWN_SIZE) {
+ rank = shape.numDimensions();
+ } else if (rank != shape.numDimensions()) {
+ throw new IllegalArgumentException(
+ String.format(
+ "A Concatenate layer requires inputs with matching shapes %s",
+ shapesToString(inputShapes)));
+ }
+ shapesArray[i] = dims;
+ }
+
+ if (axis < 0) {
+ axis = Math.floorMod(axis, rank);
+ }
+ long[] firstShape = shapesArray[0];
+ for (int i = 1; i < shapesArray.length; i++) {
+ for (int j = 0; j < shapesArray[i].length; j++) {
+ if (shapesArray[i][j] != firstShape[j]
+ && shapesArray[i][j] != Shape.UNKNOWN_SIZE
+ && firstShape[j] != Shape.UNKNOWN_SIZE) {
+ throw new IllegalArgumentException(
+ String.format(
+ "A Concatenate layer requires inputs with matching shapes %s",
+ shapesToString(inputShapes)));
+ }
+ }
+ }
+
+ this.setBuilt(true);
+ }
+
+ /**
+ * Converts a list of shapes to a String
+ *
+ * @param shapes the list of shapes.
+ * @return list of shapes as a String
+ */
+ private String shapesToString(List shapes) {
+ StringBuilder sb = new StringBuilder("[ ");
+ boolean first = true;
+ for (Shape shape : shapes) {
+ if (!first) {
+ sb.append(", ");
+ } else {
+ first = false;
+ }
+ sb.append(shape);
+ }
+ sb.append(" ]");
+ return sb.toString();
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected Operand mergeFunction(List> inputs) {
+ Ops tf = getTF();
+ if (inputs.size() < 2) {
+ throw new IllegalArgumentException("A Concatenate layer must have at least 2 inputs.");
+ }
+ Class inputType = inputs.get(0).type();
+ List> tList =
+ inputs.stream().map(item -> cast(tf, item, getType())).collect(Collectors.toList());
+ return cast(tf, tf.concat(tList, tf.constant(axis)), inputType);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public List computeOutputShape(List inputShapes) {
+ build(inputShapes);
+ Shape outputShape = inputShapes.get(0);
+ long[] dims = outputShape.asArray();
+ if (dims == null) {
+ dims = new long[] {Shape.UNKNOWN_SIZE};
+ }
+
+ for (int i = 1; i < inputShapes.size(); i++) {
+ // each subsequent input contributes its own size along the concatenation axis
+ Shape shape = inputShapes.get(i);
+ if (outputShape.size(axis) == Shape.UNKNOWN_SIZE || shape.size(axis) == Shape.UNKNOWN_SIZE) {
+ dims[axis] = Shape.UNKNOWN_SIZE;
+ break;
+ }
+ dims[axis] += shape.size(axis);
+ }
+
+ Shape result = Shape.of(dims);
+ return Collections.singletonList(result);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java
new file mode 100644
index 00000000000..d433c7bcb86
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java
@@ -0,0 +1,360 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.activations.Activation;
+import org.tensorflow.framework.constraints.Constraint;
+import org.tensorflow.framework.initializers.Glorot;
+import org.tensorflow.framework.initializers.Initializer;
+import org.tensorflow.framework.initializers.VarianceScaling;
+import org.tensorflow.framework.initializers.Zeros;
+import org.tensorflow.framework.layers.impl.InputSpec;
+import org.tensorflow.framework.op.math.TensorDot;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Variable;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TType;
+
+import java.util.Collections;
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * A regular densely-connected NN layer.
+ *
+ *
<p>
Dense implements the operation:
+ * {@code output = activation(dot(input, kernel) + bias)} where {@code activation} is the
+ * element-wise activation function passed as the {@code activation} argument, {@code kernel}
+ * is a weights matrix created by the layer, and {@code bias} is a bias vector created
+ * by the layer (only applicable if {@code useBias} is true).
+ *
+ *
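+ * <p>A minimal usage sketch (illustrative values; assumes an {@code Ops tf} instance and that
+ * the masks argument of {@code call} may be null):
+ *
+ * <pre>{@code
+ * Dense<TFloat32> dense = new Dense<>(tf, 4, 1001L, TFloat32.class);
+ * Operand<TFloat32> input = tf.constant(new float[][] {{1, 2, 3}, {4, 5, 6}}); // shape (2, 3)
+ * List<Operand<? extends TType>> inputs = Collections.singletonList(input);
+ * // output = activation(dot(input, kernel) + bias), shape (2, 4)
+ * Operand<TFloat32> output = dense.call(inputs, null, false, TFloat32.class).get(0);
+ * }</pre>
+ *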
<p>
Note: If the input to the layer has a rank greater than 2, then Dense + * computes the dot product between the inputs and the kernel along the + * last axis of the inputs and axis 1 of the kernel (using + * tf.tensordot). For example, if input has dimensions (batch_size, d0, + * d1), then we create a kernel with shape (d1, units), and the + * kernel operates along axis 2 of the input, on every sub-tensor of shape + * (1, 1, d1) (there are batch_size * d0 such sub-tensors). The output in + * this case will have shape (batch_size, d0, units). + * + * @param the data type for the layer's weights and computation. + */ +public class Dense extends Layer { + + private final Integer units; + private final Activation activation; + private final boolean useBias; + private final long seed; + private final Constraint kernelConstraint; + private final Constraint biasConstraint; + private Initializer kernelInitializer; + private Initializer biasInitializer; + private Variable kernel; + private Variable bias; + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param units Positive integer, dimensionality of the output space. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public Dense(Ops tf, Integer units, long seed, Class type) { + this(tf, null, units, null, true, null, null, null, null, seed, type, null); + } + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param units Positive integer, dimensionality of the output space. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. + */ + public Dense(Ops tf, Integer units, long seed, Class type, Options options) { + this(tf, null, units, null, true, null, null, null, null, seed, type, options); + } + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param name name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param units Positive integer, dimensionality of the output space. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + public Dense(Ops tf, String name, Integer units, long seed, Class type) { + this(tf, name, units, null, true, null, null, null, null, seed, type, null); + } + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param name name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param units Positive integer, dimensionality of the output space. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. 
+ */ + public Dense(Ops tf, String name, Integer units, long seed, Class type, Options options) { + this(tf, name, units, null, true, null, null, null, null, seed, type, options); + } + + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param name name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param units Positive integer, dimensionality of the output space. + * @param activation Activation function to use. If you don't specify anything, no activation is + * applied (ie. "linear" activation: a(x) = x). + * @param useBias whether the layer uses a bias vector. + * @param kernelInitializer Initializer for the kernel weights matrix. + * @param biasInitializer Initializer for the bias vector. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + */ + @SuppressWarnings("unchecked") + public Dense( + Ops tf, + String name, + Integer units, + Activation activation, + boolean useBias, + Initializer kernelInitializer, + Initializer biasInitializer, + Constraint kernelConstraint, + Constraint biasConstraint, + long seed, + Class type) { + this(tf, name, units, activation, useBias, kernelInitializer, biasInitializer, kernelConstraint, biasConstraint, seed, type, null); + } + /** + * Creates a Dense layer. + * + * @param tf the TensorFlow Ops. + * @param name name the unique name for this layer. If null, a unique name will be generated based + * on {@link Class#getSimpleName()}. + * @param units Positive integer, dimensionality of the output space. + * @param activation Activation function to use. If you don't specify anything, no activation is + * applied (ie. "linear" activation: a(x) = x). + * @param useBias whether the layer uses a bias vector. + * @param kernelInitializer Initializer for the kernel weights matrix. + * @param biasInitializer Initializer for the bias vector. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the data type for the weights and computation + * @param options the layer's options. + */ + @SuppressWarnings("unchecked") + public Dense( + Ops tf, + String name, + Integer units, + Activation activation, + boolean useBias, + Initializer kernelInitializer, + Initializer biasInitializer, + Constraint kernelConstraint, + Constraint biasConstraint, + long seed, + Class type, + Options options) { + super(tf, name, true, type, options); + this.units = units; + this.activation = activation; + this.useBias = useBias; + + this.kernelInitializer = + kernelInitializer != null + ? kernelInitializer + : (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); + this.biasInitializer = biasInitializer != null ? biasInitializer : new Zeros<>(tf); + this.kernelConstraint = kernelConstraint; + this.biasConstraint = biasConstraint; + this.seed = seed; + addInputSpec(new InputSpec(InputSpec.Options.create().minRank(2))); + setSupportsMasking(true); + } + + /** + * Implements the operation: {@code output = activation(dot(input, kernel) + bias)} + * + * @param inputs the input Operands, an N-D tensor with shape: {@code (batch_size, ..., + * input_dim)}. The most common situation would be a 2D input with shape @code (batch_size, + * input_dim)}. 
+ * @param masks a list of masks, one for each input, to apply to the result, may be null
+ * @param training whether the call is in inference mode or training mode
+ * @param resultType the data type for the result
+ * @param <U> the data type for the result
+ * @return the output with shape {@code (batch_size, ..., units)}. For instance, for a 2D input
+ * with shape {@code (batch_size, input_dim)}, the output would have shape {@code (batch_size, units)}.
+ */
+ @Override
+ public List> call(
+ List> inputs,
+ List> masks,
+ boolean training,
+ Class resultType) {
+ if (inputs == null || inputs.size() != 1)
+ throw new IllegalArgumentException("Dense only supports 1 input.");
+ Operand singleInput = inputs.get(0);
+ Operand input = cast(getTF(), singleInput, getType());
+ if (!isBuilt()) build(input.shape());
+ Shape inputShape = input.shape();
+ int rank = inputShape.numDimensions();
+ Operand tOutput;
+ if (rank == 2 || rank == Shape.UNKNOWN_SIZE) {
+ tOutput = getTF().linalg.matMul(input, getKernel());
+ } else {
+ tOutput = TensorDot.tensordot(getTF().scope(), input, getKernel(), new int[] {rank - 1, 0});
+ // Reshape the output back to the original number of dimensions of the input.
+ Shape newShape = inputShape.take(rank - 1).append(getUnits());
+ tOutput = getTF().reshape(tOutput, getTF().constant(newShape));
+ }
+ if (isUseBias()) tOutput = getTF().nn.biasAdd(tOutput, getBias());
+ // apply the configured activation, if any: output = activation(dot(input, kernel) + bias)
+ if (activation != null) tOutput = activation.call(tOutput);
+
+ return callPostProcess(Collections.singletonList(cast(getTF(), tOutput, resultType)), training);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public void build(List inputShapes) {
+ super.build(inputShapes);
+ if (inputShapes == null || inputShapes.size() != 1) {
+ throw new IllegalArgumentException("Dense only supports 1 input.");
+ }
+ if (!TFloating.class.isAssignableFrom(getType()))
+ throw new IllegalArgumentException(
+ String.format(
+ "Unable to build Dense layer with non-floating point type: %s",
+ getType().toString()));
+
+ if (kernelInitializer == null) {
+ // Cast is required because Glorot is TFloating.
+ kernelInitializer = new Glorot<>(getTF(), VarianceScaling.Distribution.UNIFORM, getSeed());
+ }
+ if (biasInitializer == null) {
+ biasInitializer = new Zeros<>(getTF());
+ }
+
+ Shape inputShape = inputShapes.get(0);
+ if (inputShape.size(-1) == Shape.UNKNOWN_SIZE) {
+ throw new IllegalArgumentException(
+ "The last dimension of the inputs to `Dense` should be defined. 
Found `UNKNOWN`."); + } + long lastDim = inputShape.size(-1); + addInputSpec(new InputSpec(InputSpec.Options.create().minRank(2).axesMap(-1, lastDim))); + + kernel = + addWeight( + getName() + "_kernel", + Shape.of(lastDim, this.getUnits()), + kernelInitializer, + kernelConstraint, + true, + getSeed()); + if (isUseBias()) + bias = + addWeight( + getName() + "_bias", + Shape.of(this.getUnits()), + biasInitializer, + biasConstraint, + true, + getSeed()); + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + if (inputShapes == null || inputShapes.size() != 1) + throw new IllegalArgumentException("Dense layer: there must be one input shape"); + if (!isBuilt()) build(inputShapes); + Shape singleShape = inputShapes.get(0); + if (singleShape.size(-1) == Shape.UNKNOWN_SIZE) + throw new IllegalArgumentException( + String.format( + "Dense layer: The innermost dimension of input_shape must be defined, but saw: %s", + singleShape)); + Shape headShape = singleShape.take(singleShape.numDimensions() - 1).append(getUnits()); + + return Collections.singletonList(headShape); + } + + /** + * Gets the dense units + * + * @return the dense units + */ + public Integer getUnits() { + return units; + } + + /** + * Gets the use bias flag + * + * @return the use bias flag + */ + public boolean isUseBias() { + return useBias; + } + + /** + * Gets the seed + * + * @return the seed + */ + public long getSeed() { + return seed; + } + + /** + * Gets the kernel variable + * + * @return the kernel variable + */ + public Variable getKernel() { + return kernel; + } + + /** + * Gets the bias variable + * + * @return + */ + public Variable getBias() { + return bias; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java new file mode 100644 index 00000000000..e5685708c30 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java @@ -0,0 +1,555 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Squeeze; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that computes a dot product between samples in two tensors. + * + *
<p>
E.g. if applied to a list of two tensors a and b of shape + * (batch_size, n), the output will be a tensor of shape (batch_size, 1) where + * each entry i will be the dot product between `a[i]` and `b[i]`. + * + * @param the data type for the layer's weights and computation. + */ +public class Dot extends Merge { + private final int[] axes; + private final boolean normalize; + + private boolean reshapeRequired; + + + /** + * Creates a Layer that computes a dot product between samples in two tensors, using {@link + * Class#getSimpleName()} as the layer name, and no L2 Normalization. + * + * @param tf the TensorFlow Ops + * @param axes axes along which to take the dot product. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, int axes, Class type) { + this(tf, null, new int[] {axes}, false, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors, using {@link + * Class#getSimpleName()} as the layer name, and no L2 Normalization. + * + * @param tf the TensorFlow Ops + * @param axes axes along which to take the dot product. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, int axes, Class type, Options options) { + this(tf, null, new int[] {axes}, false, type, options); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors, using {@link + * Class#getSimpleName()} as the layer name and no L2 Normalization. + * + * @param tf the TensorFlow Ops + * @param axes axes along which to take the dot product. Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, int[] axes, Class type) { + this(tf, null, axes, false, type, null); + } + + + /** + * Creates a Layer that computes a dot product between samples in two tensors, using {@link + * Class#getSimpleName()} as the layer name and no L2 Normalization. + * + * @param tf the TensorFlow Ops + * @param axes axes along which to take the dot product. Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, int[] axes, Class type, Options options) { + this(tf, null, axes, false, type, options); + } + + + /** + * Creates a Layer that computes a dot product between samples in two tensors with no L2 + * Normalization. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, String name, int axes, Class type) { + this(tf, name, new int[] {axes}, false, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors with no L2 + * Normalization. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. 
+ * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, String name, int axes, Class type, Options options) { + this(tf, name, new int[] {axes}, false, type, options); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors with no L2 + * Normalization. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, String name, int[] axes, Class type) { + this(tf, name, axes, false, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors with no L2 + * Normalization. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. Should be one or two integers + * corresponding to the desired axis from the first input and the desired axis from the second + * input, respectively. Note that the size of the two selected axes must match. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, String name, int[] axes, Class type, Options options) { + this(tf, name, axes, false, type, options); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. + * @param normalize Whether to L2-normalize samples along the dot product axis before taking the + * dot product. If set to True, then the output of the dot product is the cosine proximity + * between the two samples. + * @param type the data type for the weights and computation + */ + public Dot(Ops tf, String name, int axes, boolean normalize, Class type) { + this(tf, name, new int[] {axes}, normalize, type, null); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. + * @param normalize Whether to L2-normalize samples along the dot product axis before taking the + * dot product. If set to True, then the output of the dot product is the cosine proximity + * between the two samples. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Dot(Ops tf, String name, int axes, boolean normalize, Class type, Options options) { + this(tf, name, new int[] {axes}, normalize, type, options); + } + + /** + * Creates a Layer that computes a dot product between samples in two tensors. + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param axes axes along which to take the dot product. 
Should be one or two integers
+ * corresponding to the desired axis from the first input and the desired axis from the second
+ * input, respectively. Note that the size of the two selected axes must match.
+ * @param normalize Whether to L2-normalize samples along the dot product axis before taking the
+ * dot product. If set to True, then the output of the dot product is the cosine proximity
+ * between the two samples.
+ * @param type the data type for the weights and computation
+ */
+ public Dot(Ops tf, String name, int[] axes, boolean normalize, Class type) {
+ this(tf, name, axes, normalize, type, null);
+ }
+
+ /**
+ * Creates a Layer that computes a dot product between samples in two tensors.
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+ * @param axes axes along which to take the dot product. Should be one or two integers
+ * corresponding to the desired axis from the first input and the desired axis from the second
+ * input, respectively. Note that the size of the two selected axes must match.
+ * @param normalize Whether to L2-normalize samples along the dot product axis before taking the
+ * dot product. If set to True, then the output of the dot product is the cosine proximity
+ * between the two samples.
+ * @param type the data type for the weights and computation
+ * @param options the layer's options
+ */
+ public Dot(Ops tf, String name, int[] axes, boolean normalize, Class type, Options options) {
+ super(tf, name, type, options);
+ if (axes.length < 1 || axes.length > 2) {
+ throw new IllegalArgumentException(
+ "Invalid format for axes - must only contain one or two elements.");
+ }
+ this.axes = axes;
+ this.normalize = normalize;
+ this.setSupportsMasking(true);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected void build(List inputShapes) {
+ // Used purely for shape validation.
+ if (inputShapes.size() != 2) {
+ throw new IllegalArgumentException("A Dot layer should be called on exactly 2 inputs");
+ }
+ Shape shape1 = inputShapes.get(0);
+ Shape shape2 = inputShapes.get(1);
+ if (shape1.isUnknown() || shape2.isUnknown()) {
+ return;
+ }
+ int[] newAxes;
+ if (axes.length == 1) {
+ newAxes = new int[2];
+ if (axes[0] < 0) {
+ newAxes[0] = Math.floorMod(axes[0], shape1.numDimensions());
+ newAxes[1] = Math.floorMod(axes[0], shape2.numDimensions());
+ } else {
+ newAxes[0] = axes[0];
+ newAxes[1] = axes[0];
+ }
+ } else {
+ newAxes = axes;
+ }
+ // use the resolved axes; indexing the raw axes array would fail when only one axis was given
+ if (shape1.size(newAxes[0]) != shape2.size(newAxes[1])) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Dimension incompatibility %s != %s. Layer shapes: %s, %s. Chosen axes: %s",
+ shape1.size(newAxes[0]),
+ shape2.size(newAxes[1]),
+ shape1,
+ shape2,
+ Arrays.toString(newAxes)));
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ protected Operand mergeFunction(List> inputs) {
+ Ops tf = getTF();
+ if (inputs.size() != 2) {
+ throw new IllegalArgumentException("A Dot layer should be called on exactly 2 inputs");
+ }
+ Operand input1 = inputs.get(0);
+ Operand input2 = inputs.get(1);
+ int[] newAxes = new int[2];
+ if (axes.length == 1) {
+ if (axes[0] < 0) {
+ newAxes[0] = Math.floorMod(axes[0], input1.shape().numDimensions());
+ newAxes[1] = Math.floorMod(axes[0], input2.shape().numDimensions());
+ } else {
+ newAxes[0] = axes[0];
+ newAxes[1] = axes[0];
+ }
+ } else {
+ for (int i = 0; i < axes.length; i++) {
+ if (axes[i] < 0) {
+ // resolve each axis against its own input's rank
+ newAxes[i] = Math.floorMod(axes[i], inputs.get(i).shape().numDimensions());
+ } else {
+ newAxes[i] = axes[i];
+ }
+ }
+ }
+ if (normalize) {
+ // normalize each input along its own resolved dot product axis
+ input1 = Losses.l2Normalize(tf, input1, new int[] {newAxes[0]});
+ input2 = Losses.l2Normalize(tf, input2, new int[] {newAxes[1]});
+ }
+ return batchDot(input1, input2, newAxes);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public List> computeMask(
+ List> inputs, List> masks) {
+ return null;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public List computeOutputShape(List inputShapes) {
+ if (inputShapes.size() != 2) {
+ throw new IllegalArgumentException("A Dot layer should be called on a list of 2 inputs.");
+ }
+ Shape shape1 = inputShapes.get(0);
+ Shape shape2 = inputShapes.get(1);
+ // resolve the axes without mutating the stored axes array
+ int[] lAxes = new int[2];
+ if (axes.length == 1) {
+ lAxes[0] = Math.floorMod(axes[0], shape1.numDimensions());
+ lAxes[1] = Math.floorMod(axes[0], shape2.numDimensions());
+ } else {
+ lAxes[0] = Math.floorMod(axes[0], shape1.numDimensions());
+ lAxes[1] = Math.floorMod(axes[1], shape2.numDimensions());
+ }
+
+ // pop(lAxes[0]) from shape1
+ Shape head1 = shape1.take(lAxes[0]);
+ long remainder = shape1.numDimensions() - (lAxes[0] + 1);
+ if (remainder > 0) {
+ head1 = head1.append(shape1.takeLast((int) remainder));
+ }
+ shape1 = head1;
+
+ // pop(lAxes[1]) from shape2
+ Shape head2 = shape2.take(lAxes[1]);
+ remainder = shape2.numDimensions() - (lAxes[1] + 1);
+ if (remainder > 0) {
+ head2 = head2.append(shape2.takeLast((int) remainder));
+ }
+ shape2 = head2;
+ if (shape2.numDimensions() > 0) {
+ // pop(0), the batch dimension of the second input
+ shape2 = shape2.takeLast(shape2.numDimensions() - 1);
+ }
+ Shape outputShape = shape1.append(shape2);
+
+ if (outputShape.numDimensions() == 1) {
+ // Shape.append returns a new Shape, so the result must be reassigned
+ outputShape = outputShape.append(1);
+ }
+ return Collections.singletonList(outputShape);
+ }
+
+ /**
+ * Computes the batch-wise dot product.
+ *
+ *
<p>
batchDot is used to compute the dot product of x and y
+ * when x and y are data in batch, i.e. in a shape of
+ * (batch_size, :). batchDot results in an Operand with fewer dimensions than
+ * the input. If the number of dimensions is reduced to 1, we use expandDims
+ * to make sure that the number of dimensions is at least 2.
+ *
+ * @param x Operand with numDimensions >= 2.
+ * @param y Operand with numDimensions >= 2.
+ * @param dotAxes the axes to perform the dot product over.
+ * @return an Operand with shape equal to the concatenation of x's shape (less the
+ * dimension that was summed over) and y's shape (less the batch dimension and
+ * the dimension that was summed over). If the final rank is 1, the result is reshaped to
+ * (batch_size, 1).
+ */
+ private Operand batchDot(
+ Operand x, Operand y, int[] dotAxes) {
+ Ops tf = getTF();
+ Operand tX = cast(tf, x, getType());
+ Operand tY = cast(tf, y, getType());
+
+ Shape xShape = tX.shape();
+ Shape yShape = tY.shape();
+
+ int xRank = xShape.numDimensions();
+ int yRank = yShape.numDimensions();
+
+ if (xRank < 2 || yRank < 2) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Cannot do batch_dot on inputs with rank < 2. Received inputs with shapes %s and %s.",
+ xShape, yShape));
+ }
+
+ int xBatchSize = (int) xShape.size(0);
+ int yBatchSize = (int) yShape.size(0);
+ if (xBatchSize != Shape.UNKNOWN_SIZE && yBatchSize != Shape.UNKNOWN_SIZE) {
+ if (xBatchSize != yBatchSize) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Cannot do batchDot on inputs with different batch sizes. Received inputs with shapes %s and %s.",
+ xShape, yShape));
+ }
+ }
+
+ if (dotAxes == null) {
+ dotAxes = new int[2];
+ if (yRank == 2) {
+ dotAxes[0] = xRank - 1;
+ dotAxes[1] = yRank - 1;
+ } else {
+ dotAxes[0] = xRank - 1;
+ dotAxes[1] = yRank - 2;
+ }
+ } else if (dotAxes.length == 1) {
+ dotAxes = new int[] {dotAxes[0], dotAxes[0]};
+ }
+
+ if (dotAxes[0] < 0) {
+ dotAxes[0] = Math.floorMod(dotAxes[0], xRank);
+ }
+ if (dotAxes[1] < 0) {
+ dotAxes[1] = Math.floorMod(dotAxes[1], yRank);
+ }
+ if (dotAxes[0] == 0 || dotAxes[1] == 0) {
+ throw new IllegalArgumentException(
+ "Cannot perform batch_dot over axis 0. If your inputs are not batched, add a dummy batch dimension to your inputs using tf.expandDims(x, 0)");
+ }
+
+ int a0 = dotAxes[0];
+ int a1 = dotAxes[1];
+ int d1 = (int) xShape.size(a0);
+ int d2 = (int) yShape.size(a1);
+
+ if (d1 != Shape.UNKNOWN_SIZE && d2 != Shape.UNKNOWN_SIZE && d1 != d2) {
+ // the format string has seven placeholders, so all seven arguments must be supplied
+ throw new IllegalArgumentException(
+ String.format(
+ "Cannot do batch_dot on inputs with shapes %s and %s with axes %s. x.shape[%d] != %d, y.shape[%d] != %d",
+ xShape, yShape, Arrays.toString(dotAxes), a0, d1, a1, d2));
+ }
+
+ int origXRank = xRank;
+ int origYRank = yRank;
+ if (xRank == 2) {
+ tX = tf.expandDims(tX, tf.constant(1));
+ xRank++;
+ a0++;
+ }
+ if (yRank == 2) {
+ tY = tf.expandDims(tY, tf.constant(2));
+ yRank += 1;
+ }
+
+ // move x's dimension to be reduced to the last axis: [0 .. a0-1, a0+1 .. xRank-1, a0]
+ if (a0 != xRank - 1) {
+ int[] pattern = new int[xRank];
+ for (int i = 0; i < a0; i++) {
+ pattern[i] = i;
+ }
+ for (int i = a0 + 1, j = a0; i < xRank; i++) {
+ pattern[j++] = i;
+ }
+ pattern[xRank - 1] = a0;
+ tX = tf.linalg.transpose(tX, tf.constant(pattern));
+ }
+ // move y's dimension to be reduced to axis 1: [0, a1, 1 .. yRank-1, skipping a1]
+ if (a1 != 1) {
+ int[] pattern = new int[yRank];
+ pattern[0] = 0;
+ pattern[1] = a1;
+ for (int i = 1, j = 2; i < yRank; i++) {
+ if (i == a1) { // skip a1, it has already been placed in slot 1
+ continue;
+ }
+ pattern[j++] = i;
+ }
+ tY = tf.linalg.transpose(tY, tf.constant(pattern));
+ }
+
+ // normalize both inputs to rank 3.
+ boolean xSquashed = false;
+ Operand xMidShape = null;
+ if (xRank > 3) {
+ org.tensorflow.op.core.Shape tmpShape = tf.shape(tX, TInt64.class);
+ // the middle dimensions of x (all but the first and last), kept to restore the final shape
+ xMidShape =
+ tf.shape.take(
+ tf.shape.takeLast(tmpShape, tf.constant((long) (xRank - 1)), TInt64.class),
+ tf.constant((long) (xRank - 2)),
+ TInt64.class);
+
+ Operand squashedShape =
+ tf.stack(
+ Arrays.asList(
+ tf.shape.size(tmpShape, tf.constant(0L), TInt64.class),
+ tf.constant(Shape.UNKNOWN_SIZE),
+ tf.shape.size(tmpShape, tf.constant((long) (xRank - 1)), TInt64.class)));
+ tX = tf.reshape(tX, squashedShape);
+ xSquashed = true;
+ }
+
+ boolean ySquashed = false;
+ Operand yTrailDims = null;
+ if (yRank > 3) {
+ yTrailDims =
+ tf.shape.takeLast(
+ tf.shape(tY, TInt64.class), tf.constant((long) (yRank - 2)), TInt64.class);
+
+ // the sizes must come from tY, which may have been transposed above
+ Operand squashedShape =
+ tf.stack(
+ Arrays.asList(
+ tf.shape.size(tY, tf.constant(0L), TInt64.class),
+ tf.shape.size(tY, tf.constant(1L), TInt64.class),
+ tf.constant(-1L)));
+ tY = tf.reshape(tY, squashedShape);
+ ySquashed = true;
+ }
+
+ Operand result = org.tensorflow.framework.op.linalg.MatMul.matmul(getTF().scope(), tX, tY);
+ boolean doReshape = false;
+ Operand outputShape = tf.shape(result, TInt64.class);
+
+ if (xSquashed && xMidShape != null) {
+ outputShape =
+ tf.concat(
+ Arrays.asList(
+ tf.shape.size(outputShape, tf.constant(0L), TInt64.class),
+ xMidShape,
+ tf.shape.size(outputShape, tf.constant(-1L), TInt64.class)),
+ tf.constant(0));
+ doReshape = true;
+ }
+
+ if (ySquashed && yTrailDims != null) {
+
+ outputShape =
+ tf.concat(
+ Arrays.asList(
+ tf.slice(outputShape, tf.constant(0), tf.constant(outputShape.rank() - 1)),
+ yTrailDims),
+ tf.constant(0));
+ doReshape = true;
+ }
+
+ if (doReshape) {
+ result = tf.reshape(result, outputShape);
+ }
+
+ if (origXRank == 2) {
+ result = tf.squeeze(result, Squeeze.axis(Collections.singletonList(1L)));
+ } else if (origYRank == 2) {
+ result = tf.squeeze(result, Squeeze.axis(Collections.singletonList(-1L)));
+ }
+ return result;
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java
new file mode 100644
index 00000000000..340ff278ac9
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java
@@ -0,0 +1,256 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.RandomUniform; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Applies Dropout to the input. + * + *

The Dropout layer randomly sets input units to 0 with a frequency of rate at each step during + * training time, which helps prevent overfitting. Inputs not set to 0 are scaled up by 1/(1 - rate) + * such that the sum over all inputs is unchanged. + * + *
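+ * <p>For example, a usage sketch (the {@code input} operand and the seed value are
+ * illustrative, not part of this class):
+ *
+ * <pre>{@code
+ * Dropout<TFloat32> dropout = new Dropout<>(tf, 0.5f, 1001L, TFloat32.class);
+ * Operand<TFloat32> trainOut = dropout.call(input, true, TFloat32.class);  // ~50% dropped, rest scaled by 2
+ * Operand<TFloat32> inferOut = dropout.call(input, false, TFloat32.class); // identity
+ * }</pre>
+ *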

Note that the Dropout layer only applies when training is set to true, so no values
+ * are dropped during inference. When fitting a model, training is set to true
+ * automatically; in other contexts, set the training parameter explicitly to true when
+ * calling the layer.
+ *
+ *

(This is in contrast to setting trainable=false for a Dropout layer. trainable does not
+ * affect the layer's behavior, as Dropout does not have any variables/weights that can be
+ * frozen during training.)
+ *
+ * @param <T> the data type for the layer's weights and computation.
+ * @see Hinton G, et al. 2012, Improving neural networks
+ *     by preventing co-adaptation of feature detectors
+ */
+public class Dropout extends Layer {
+
+  private final float rate;
+  private final Shape noiseShape;
+  private final long seed;
+
+  /**
+   * Creates a Dropout layer with no noiseShape; a unique name will be generated based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param rate A number between 0 and 1. Fraction of the input units to drop.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   */
+  public Dropout(Ops tf, float rate, long seed, Class type) {
+
+    this(tf, null, rate, null, seed, type, null);
+  }
+
+  /**
+   * Creates a Dropout layer with no noiseShape; a unique name will be generated based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param rate A number between 0 and 1. Fraction of the input units to drop.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public Dropout(Ops tf, float rate, long seed, Class type, Options options) {
+
+    this(tf, null, rate, null, seed, type, options);
+  }
+
+  /**
+   * Creates a Dropout layer; a unique name will be generated based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param rate A number between 0 and 1. Fraction of the input units to drop.
+   * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout
+   *     mask that will be multiplied with the input. For instance, if your inputs have shape
+   *     (batch_size, timesteps, features) and you want the dropout mask to be the same for all
+   *     timesteps, you can use noiseShape=(batch_size, 1, features). May be null.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   */
+  public Dropout(Ops tf, float rate, Shape noiseShape, long seed, Class type) {
+    this(tf, null, rate, noiseShape, seed, type, null);
+  }
+
+  /**
+   * Creates a Dropout layer; a unique name will be generated based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param rate A number between 0 and 1. Fraction of the input units to drop.
+   * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout
+   *     mask that will be multiplied with the input. For instance, if your inputs have shape
+   *     (batch_size, timesteps, features) and you want the dropout mask to be the same for all
+   *     timesteps, you can use noiseShape=(batch_size, 1, features). May be null.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public Dropout(Ops tf, float rate, Shape noiseShape, long seed, Class type, Options options) {
+    this(tf, null, rate, noiseShape, seed, type, options);
+  }
+
+  /**
+   * Creates a Dropout layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based
+   *     on {@link Class#getSimpleName()}.
+   * @param rate A number between 0 and 1. Fraction of the input units to drop.
+   * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout
+   *     mask that will be multiplied with the input. For instance, if your inputs have shape
+   *     (batch_size, timesteps, features) and you want the dropout mask to be the same for all
+   *     timesteps, you can use noiseShape=(batch_size, 1, features). May be null.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   */
+  public Dropout(
+      Ops tf,
+      String name,
+      float rate,
+      Shape noiseShape,
+      long seed,
+      Class type) {
+    this(tf, name, rate, noiseShape, seed, type, null);
+  }
+
+  /**
+   * Creates a Dropout layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based
+   *     on {@link Class#getSimpleName()}.
+   * @param rate A number between 0 and 1. Fraction of the input units to drop.
+   * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout
+   *     mask that will be multiplied with the input. For instance, if your inputs have shape
+   *     (batch_size, timesteps, features) and you want the dropout mask to be the same for all
+   *     timesteps, you can use noiseShape=(batch_size, 1, features). May be null.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public Dropout(
+      Ops tf,
+      String name,
+      float rate,
+      Shape noiseShape,
+      long seed,
+      Class type,
+      Options options) {
+    super(tf, name, true, type, options);
+    if (rate < 0 || rate >= 1)
+      throw new IllegalArgumentException("The rate must be in the range [0, 1).");
+    this.rate = rate;
+    this.noiseShape = noiseShape;
+    this.seed = seed;
+    setSupportsMasking(true);
+  }
+
+  /** {@inheritDoc} */
+  @SuppressWarnings("unchecked")
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+
+    Ops tf = getTF();
+    List> outputs = new ArrayList<>();
+
+    for (Operand input : inputs) {
+
+      Operand output;
+      // non-floating inputs are promoted to TFloat64 before the dropout math
+      if (!TFloating.class.isAssignableFrom(input.type())) {
+        output = cast(tf, input, TFloat64.class);
+      } else {
+        output = (Operand) input;
+      }
+
+      if (training) {
+        Operand rateV = cast(tf, tf.constant(rate), getType());
+
+        Operand noise =
+            noiseShape == null ?
tf.shape(input, TInt64.class) : tf.constant(noiseShape); + + Operand tOutput = cast(getTF(), output, getType()); + tOutput = dropout(tOutput, rateV, noise, seed); + + outputs.add(cast(getTF(), tOutput, resultType)); + } else { + outputs.add(cast(getTF(), output, resultType)); + } + } + return callPostProcess(outputs, training); + } + + /** {@inheritDoc} */ + @Override + public List computeOutputShape(List inputShapes) { + return inputShapes; + } + + /** + * Computes dropout: randomly sets elements to zero to prevent overfitting. + * + * @param input the input + * @param rate the drop out rate, the probability that each element is dropped. For example, + * setting rate=0.1 would drop 10% of input elements. + * @param noiseShape the noise shape representing the shape for randomly generated keep/drop + * flags. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @return an Operand with the same shape as the input with a percentage of entries dropped out. + */ + /* TODO, this is defined as tf.dropout() in python and is an nn op. Do we want it here? */ + private Operand dropout( + Operand input, Operand rate, Operand noiseShape, long seed) { + Ops tf = getTF(); + + Operand one = cast(tf, tf.constant(1.), input.type()); + Operand keepProb = tf.math.sub(one, rate); + Operand scale = tf.math.div(one, keepProb); + Operand ret = tf.math.mul(input, scale); + + Operand randomTensor = + tf.random.randomUniform(noiseShape, input.type(), RandomUniform.seed(seed)); + Operand keepMask = tf.math.greaterEqual(randomTensor, rate); + ret = tf.math.mul(ret, cast(tf, keepMask, ret.type())); + return tf.reshape(ret, tf.shape(input)); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java new file mode 100644 index 00000000000..b0d453135e6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java @@ -0,0 +1,109 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Exponential Linear Unit layer. + * + *

It follows:
+ *
+ * <pre>
+ *     f(x) = alpha * (exp(x) - 1.) for x &lt; 0
+ *     f(x) = x                     for x &gt;= 0
+ * </pre>
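+ *
+ * <p>For example, a usage sketch (the {@code input} operand is illustrative):
+ *
+ * <pre>{@code
+ * ELU<TFloat32> layer = new ELU<>(tf, TFloat32.class);
+ * Operand<TFloat32> y = layer.call(input, TFloat32.class);
+ * }</pre>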
+ *
+ * @param <T> the data type for the layer's weights and computation.
+ */
+public class ELU extends Layer {
+  public static float DEFAULT_ALPHA = 1.0f;
+
+  private final float alpha;
+
+  /**
+   * Creates an ELU layer with a unique name generated based on {@link Class#getSimpleName()} and
+   * {@link #DEFAULT_ALPHA} for the alpha value.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public ELU(Ops tf, Class type) {
+    this(tf, null, DEFAULT_ALPHA, type, null);
+  }
+
+  /**
+   * Creates an ELU layer with {@link #DEFAULT_ALPHA} for the alpha value.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public ELU(Ops tf, String name, Class type) {
+    this(tf, name, DEFAULT_ALPHA, type, null);
+  }
+
+  /**
+   * Creates an ELU layer with a unique name generated based on {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param alpha the scale for the negative section. Must be >= 0.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public ELU(Ops tf, float alpha, Class type, Options options) {
+    this(tf, null, alpha, type, options);
+  }
+
+  /**
+   * Creates an ELU layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param alpha the scale for the negative section. Must be >= 0.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public ELU(Ops tf, String name, float alpha, Class type, Options options) {
+    super(tf, name, true, type, options);
+    this.alpha = alpha;
+    setSupportsMasking(true);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+
+    org.tensorflow.framework.activations.ELU elu =
+        new org.tensorflow.framework.activations.ELU<>(getTF(), alpha);
+    List> tInputs = convertList(inputs, getType());
+    List> results = new ArrayList<>();
+    tInputs.forEach(tInput -> results.add(cast(getTF(), elu.call(tInput), resultType)));
+    return callPostProcess(results, training);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java
new file mode 100644
index 00000000000..f9f858d6953
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java
@@ -0,0 +1,197 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.TensorFormat; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Flattens the input. Does not affect the batch size. + * + *
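+ * <p>For example, a usage sketch (the {@code input} operand, of shape (batch, 2, 3), is
+ * illustrative):
+ *
+ * <pre>{@code
+ * Flatten<TFloat32> flatten = new Flatten<>(tf, TFloat32.class);
+ * // result has shape (batch, 6)
+ * Operand<TFloat32> result = flatten.call(input, TFloat32.class);
+ * }</pre>
+ *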

Note: If inputs are shaped (batch,) without a feature axis, then flattening adds an extra
+ * channel dimension and the output shape is (batch, 1).
+ *
+ * @param <T> the data type for the layer's weights and computation.
+ */
+public class Flatten extends Layer {
+  private static final int FLATTEN_INPUT_LENGTH = 1;
+  private final TensorFormat dataFormat;
+
+  /**
+   * Creates a Flatten layer with a unique name generated based on {@link Class#getSimpleName()}
+   * and {@link TensorFormat#NHWC} for the data format
+   *
+   * @param tf the TensorFlow Ops.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Flatten(Ops tf, Class type) {
+    this(tf, null, TensorFormat.NHWC, type, null);
+  }
+
+  /**
+   * Creates a Flatten layer with {@link TensorFormat#NHWC} for the data format
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Flatten(Ops tf, String name, Class type) {
+    this(tf, name, TensorFormat.NHWC, type, null);
+  }
+
+  /**
+   * Creates a Flatten layer with a unique name generated based on {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC}
+   *     corresponds to inputs with shape (batch, ..., channels)
+   *     while {@link TensorFormat#NCHW} corresponds to inputs with shape
+   *     (batch, channels, ...).
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Flatten(Ops tf, TensorFormat dataFormat, Class type) {
+    this(tf, null, dataFormat, type, null);
+  }
+
+  /**
+   * Creates a Flatten layer with a unique name generated based on {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC}
+   *     corresponds to inputs with shape (batch, ..., channels)
+   *     while {@link TensorFormat#NCHW} corresponds to inputs with shape
+   *     (batch, channels, ...).
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Flatten(Ops tf, TensorFormat dataFormat, Class type, Options options) {
+    this(tf, null, dataFormat, type, options);
+  }
+
+  /**
+   * Creates a Flatten layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC}
+   *     corresponds to inputs with shape (batch, ..., channels)
+   *     while {@link TensorFormat#NCHW} corresponds to inputs with shape
+   *     (batch, channels, ...).
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Flatten(Ops tf, String name, TensorFormat dataFormat, Class type) {
+    this(tf, name, dataFormat, type, null);
+  }
+
+  /**
+   * Creates a Flatten layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC}
+   *     corresponds to inputs with shape (batch, ..., channels)
+   *     while {@link TensorFormat#NCHW} corresponds to inputs with shape
+   *     (batch, channels, ...).
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Flatten(Ops tf, String name, TensorFormat dataFormat, Class type, Options options) {
+    super(tf, name, true, type, options);
+    this.dataFormat = dataFormat;
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+    // this layer only accepts one input
+    if (inputs == null || inputs.size() != 1)
+      throw new IllegalArgumentException("Flatten layer: only accepts 1 input");
+
+    Operand input = inputs.get(0);
+    if (!isBuilt()) build(input.shape());
+    Shape shape = input.shape();
+    int rank = shape.numDimensions();
+    if (this.dataFormat == TensorFormat.NCHW) {
+      if (rank != Shape.UNKNOWN_SIZE && rank > 1) {
+        // move the channels dimension (axis 1) to the end: (0, 2, ..., rank - 1, 1)
+        long[] permutation = new long[rank];
+        permutation[0] = 0;
+        for (int i = 2; i < rank; i++) permutation[i - 1] = i;
+        permutation[rank - 1] = 1;
+        input = getTF().linalg.transpose(input, getTF().constant(permutation));
+      }
+    }
+
+    if (rank == 1) {
+      input = getTF().expandDims(input, getTF().constant(1));
+    } else {
+      Operand flattenedShape;
+      long[] dims = shape.asArray();
+      if (dims != null) {
+        long batchDim = dims[0];
+        long[] nonBatchDims = new long[dims.length - 1];
+        System.arraycopy(dims, 1, nonBatchDims, 0, nonBatchDims.length);
+        Shape nonBatchShape = Shape.of(nonBatchDims);
+        if (!nonBatchShape.hasUnknownDimension()) {
+          long lastDim = 1;
+          for (long dim : nonBatchDims) lastDim *= dim;
+          flattenedShape = getTF().constant(Shape.of(-1L, lastDim));
+        } else if (batchDim != Shape.UNKNOWN_SIZE) {
+          flattenedShape = getTF().constant(Shape.of(batchDim, -1L));
+        } else {
+          // both the batch size and at least one other dimension are unknown; flatten dynamically
+          Operand batchDimension =
+              getTF().shape.size(input, getTF().constant(0L), TInt64.class);
+          flattenedShape = getTF().shape.append(batchDimension, getTF().constant(-1L));
+        }
+        input = getTF().reshape(input, flattenedShape);
+      }
+    }
+    return callPostProcess(Collections.singletonList(cast(getTF(), input, resultType)), training);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List computeOutputShape(List inputShapes) {
+    if (inputShapes == null || inputShapes.size() != 1)
+      throw new IllegalArgumentException("Flatten layer: there must be one input shape");
+    if (!isBuilt()) build(inputShapes);
+    Shape inputShape = inputShapes.get(0);
+    long lastDim = 1L;
+    for (int i = 1; i < inputShape.numDimensions(); i++) {
+      lastDim *= inputShape.size(i);
+    }
+    // creates a new shape of (batchSize, rest)
+    Shape newShape = Shape.of(inputShape.size(0));
+    newShape = newShape.append(lastDim);
+    return Collections.singletonList(newShape);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java
new file mode 100644
index 00000000000..73707add188
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java
@@ -0,0 +1,173 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.ParameterizedTruncatedNormal; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Apply multiplicative 1-centered Gaussian noise. + * + *

As it is a regularization layer, it is only active at training time.
+ *
+ * @param <T> the data type for the layer's weights and computation.
+ */
+public class GaussianDropout extends Layer {
+
+  private final float rate;
+  private final long seed;
+
+  /**
+   * Creates a GaussianDropout layer; a unique name will be generated based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param rate A number between 0 and 1. Drop probability. The multiplicative noise will have
+   *     standard deviation: sqrt(rate / (1 - rate)).
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   */
+  public GaussianDropout(Ops tf, float rate, long seed, Class type) {
+
+    this(tf, null, rate, seed, type, null);
+  }
+
+  /**
+   * Creates a GaussianDropout layer; a unique name will be generated based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param rate A number between 0 and 1. Drop probability. The multiplicative noise will have
+   *     standard deviation: sqrt(rate / (1 - rate)).
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public GaussianDropout(Ops tf, float rate, long seed, Class type, Options options) {
+
+    this(tf, null, rate, seed, type, options);
+  }
+
+  /**
+   * Creates a GaussianDropout layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based
+   *     on {@link Class#getSimpleName()}.
+   * @param rate A number between 0 and 1. Drop probability. The multiplicative noise will have
+   *     standard deviation: sqrt(rate / (1 - rate)).
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   */
+  public GaussianDropout(
+      Ops tf, String name, float rate, long seed, Class type) {
+    this(tf, name, rate, seed, type, null);
+  }
+
+  /**
+   * Creates a GaussianDropout layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based
+   *     on {@link Class#getSimpleName()}.
+   * @param rate A number between 0 and 1. Drop probability. The multiplicative noise will have
+   *     standard deviation: sqrt(rate / (1 - rate)).
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public GaussianDropout(
+      Ops tf, String name, float rate, long seed, Class type, Options options) {
+    super(tf, name, true, type, options);
+    if (rate < 0 || rate >= 1)
+      throw new IllegalArgumentException("The rate must be in the range [0, 1).");
+    this.rate = rate;
+    this.seed = seed;
+    setSupportsMasking(true);
+  }
+
+  /** {@inheritDoc} */
+  @SuppressWarnings("unchecked")
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+
+    Ops tf = getTF();
+    List> outputs = new ArrayList<>();
+
+    for (Operand input : inputs) {
+
+      Operand output = cast(tf, input, getType());
+
+      if (training) {
+
+        Operand rateV = cast(tf, tf.constant(rate), getType());
+
+        output = dropout(output, rateV, seed);
+        outputs.add(output);
+      } else {
+        outputs.add(output);
+      }
+    }
+    return callPostProcess(convertTo(outputs, resultType), training);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List computeOutputShape(List inputShapes) {
+    return inputShapes;
+  }
+
+  /**
+   * Applies multiplicative 1-centered Gaussian noise, scaled so that the expected value of each
+   * element is unchanged.
+   *
+   * @param input the input
+   * @param rate the drop probability, used to derive the noise standard deviation
+   *     sqrt(rate / (1 - rate)).
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @return an Operand with the same shape as the input with the noise applied.
+   */
+  /* TODO, this is defined as tf.dropout() in python and is an nn op. Do we want it here? */
+  private Operand dropout(Operand input, Operand rate, long seed) {
+    Ops tf = getTF();
+
+    Operand one = cast(tf, tf.constant(1), input.type());
+    Operand zero = cast(tf, tf.constant(0), input.type());
+    Operand keepProb = tf.math.sub(one, rate);
+    Operand stdDev = tf.math.sqrt(tf.math.div(rate, keepProb));
+
+    // noise drawn from a truncated normal with mean 1 and the derived standard deviation,
+    // truncated to [0, 1]
+    Operand randomNormal =
+        tf.random.parameterizedTruncatedNormal(
+            tf.shape(input), one, stdDev, zero, one, ParameterizedTruncatedNormal.seed(seed));
+
+    return tf.math.mul(input, randomNormal);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java
new file mode 100644
index 00000000000..ce71d7e5805
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java
@@ -0,0 +1,170 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.random.ParameterizedTruncatedNormal; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Apply additive zero-centered Gaussian noise. + * + *

This is useful to mitigate overfitting (you could see it as a form of random data
+ * augmentation). Gaussian Noise (GN) is a natural choice as a corruption process for real valued
+ * inputs.
+ *
+ * @param <T> the data type for the layer's weights and computation.
+ */
+public class GaussianNoise extends Layer {
+
+  private final float stddev;
+  private final long seed;
+
+  /**
+   * Creates a GaussianNoise layer; a unique name will be generated based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param stddev the standard deviation of the noise distribution.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   */
+  public GaussianNoise(Ops tf, float stddev, long seed, Class type) {
+
+    this(tf, null, stddev, seed, type, null);
+  }
+
+  /**
+   * Creates a GaussianNoise layer; a unique name will be generated based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param stddev the standard deviation of the noise distribution.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public GaussianNoise(Ops tf, float stddev, long seed, Class type, Options options) {
+
+    this(tf, null, stddev, seed, type, options);
+  }
+
+  /**
+   * Creates a GaussianNoise layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based
+   *     on {@link Class#getSimpleName()}.
+   * @param stddev the standard deviation of the noise distribution.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   */
+  public GaussianNoise(
+      Ops tf, String name, float stddev, long seed, Class type) {
+    this(tf, name, stddev, seed, type, null);
+  }
+
+  /**
+   * Creates a GaussianNoise layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based
+   *     on {@link Class#getSimpleName()}.
+   * @param stddev the standard deviation of the noise distribution.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options.
+   */
+  public GaussianNoise(
+      Ops tf, String name, float stddev, long seed, Class type, Options options) {
+    super(tf, name, true, type, options);
+    this.stddev = stddev;
+    this.seed = seed;
+    setSupportsMasking(true);
+  }
+
+  /** {@inheritDoc} */
+  @SuppressWarnings("unchecked")
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+
+    Ops tf = getTF();
+    List> outputs = new ArrayList<>();
+
+    for (Operand input : inputs) {
+
+      Operand output = cast(tf, input, getType());
+
+      if (training) {
+
+        Operand stddevV = cast(tf, tf.constant(stddev), getType());
+
+        output = addNoise(output, stddevV, seed);
+        outputs.add(output);
+      } else {
+        outputs.add(output);
+      }
+    }
+    return callPostProcess(convertTo(outputs, resultType), training);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List computeOutputShape(List inputShapes) {
+    return inputShapes;
+  }
+
+  /**
+   * Applies additive zero-centered Gaussian noise to the input.
+   *
+   * @param input the input
+   * @param stdDev the standard deviation of the noise distribution.
+   * @param seed the seed for random number generation. An initializer created with a given seed
+   *     will always produce the same random tensor for a given shape and data type.
+   * @return an Operand with the same shape as the input with the noise added.
+   */
+  private Operand addNoise(Operand input, Operand stdDev, long seed) {
+    Ops tf = getTF();
+
+    Operand one = cast(tf, tf.constant(1), input.type());
+    Operand zero = cast(tf, tf.constant(0), input.type());
+
+    // noise drawn from a truncated normal with mean 0 and the given standard deviation,
+    // truncated to [0, 1]
+    Operand randomNormal =
+        tf.random.parameterizedTruncatedNormal(
+            tf.shape(input), zero, stdDev, zero, one, ParameterizedTruncatedNormal.seed(seed));
+
+    return tf.math.add(input, randomNormal);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java
new file mode 100644
index 00000000000..adc2ec8f599
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java
@@ -0,0 +1,345 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.Placeholder;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+import java.util.Collections;
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Layer that handles model input.
+ *
+ * @param <T> the data type for the layer's calculations.
+ */
+public class Input extends Layer {
+
+  private final Class inputType;
+  private final boolean placeholder;
+  private final Operand output;
+
+  /**
+   * Creates an input layer using {@link Class#getSimpleName()} for the name.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param input The input
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Input(Ops tf, Operand input, Class type) {
+
+    this(tf, null, input, null, type, null);
+  }
+
+  /**
+   * Creates an input layer using {@link Class#getSimpleName()} for the name.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param input The input
+   * @param type the data type for the layer's weights and computation.
+   * @param options the Layer options
+   */
+  public Input(Ops tf, Operand input, Class type, Options options) {
+
+    this(tf, null, input, null, type, options);
+  }
+
+  /**
+   * Creates an input layer.
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer, if null, will generate a name based on {@link
+   *     Class#getSimpleName()}
+   * @param input The input
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Input(
+      Ops tf, String name, Operand input, Class type) {
+
+    this(tf, name, input, null, type, null);
+  }
+
+  /**
+   * Creates an input layer.
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer, if null, will generate a name based on {@link
+   *     Class#getSimpleName()}
+   * @param input The input
+   * @param type the data type for the layer's weights and computation.
+   * @param options the Layer options
+   */
+  public Input(
+      Ops tf, String name, Operand input, Class type, Options options) {
+
+    this(tf, name, input, null, type, options);
+  }
+
+  /**
+   * Creates an input layer using {@link Class#getSimpleName()} for the name.
+   *
+   * @param tf the TensorFlow Ops; must be set before the first call to the {@link #call} method.
+   * @param inputType the data type for the input and output, if null, input.type() is used
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Input(Ops tf, Class inputType, Class type) {
+    this(tf, null, null, inputType, type, null);
+  }
+
+  /**
+   * Creates an input layer using {@link Class#getSimpleName()} for the name.
+   *
+   * @param tf the TensorFlow Ops; must be set before the first call to the {@link #call} method.
+   * @param inputType the data type for the input and output, if null, input.type() is used
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options, may be null
+   */
+  public Input(Ops tf, Class inputType, Class type, Options options) {
+    this(tf, null, null, inputType, type, options);
+  }
+
+  /**
+   * Creates an input layer.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer, if null, will generate a name based on {@link
+   *     Class#getSimpleName()}
+   * @param inputType the data type for the input and output, if null, input.type() is used
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Input(
+      Ops tf, String name, Class inputType, Class type) {
+    this(tf, name, null, inputType, type, null);
+  }
+
+  /**
+   * Creates an input layer.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer, if null, will generate a name based on {@link
+   *     Class#getSimpleName()}
+   * @param inputType the data type for the input and output, if null, input.type() is used
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options, may be null
+   */
+  public Input(
+      Ops tf, String name, Class inputType, Class type, Options options) {
+    this(tf, name, null, inputType, type, options);
+  }
+
+  /**
+   * Creates an input layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer, if null, will generate a name based on {@link
+   *     Class#getSimpleName()}
+   * @param input The input
+   * @param inputType the data type for the input and output, if null, input.type() is used
+   * @param type the data type for the layer's weights and computation.
+   * @throws IllegalArgumentException if inputShape and either batchSize or batchInputShape are
+   *     not null, and if both inputShape and input are null.
+   */
+  public Input(
+      Ops tf,
+      String name,
+      Operand input,
+      Class inputType,
+      Class type) {
+    this(tf, name, input, inputType, type, null);
+  }
+
+  /**
+   * Creates an input layer
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer, if null, will generate a name based on {@link
+   *     Class#getSimpleName()}
+   * @param input The input
+   * @param inputType the data type for the input and output, if null, input.type() is used
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options, may be null
+   * @throws IllegalArgumentException if inputShape and either batchSize or batchInputShape are
+   *     not null, and if both inputShape and input are null.
+   */
+  public Input(
+      Ops tf,
+      String name,
+      Operand input,
+      Class inputType,
+      Class type,
+      Options options) {
+    super(tf, name, true, type, options);
+    Options c = getInstanceOptions();
+
+    if (inputType == null && input == null) {
+      throw new IllegalArgumentException("both input and inputType cannot be null");
+    }
+
+    if (input != null && inputType != null && !input.type().equals(inputType)) {
+      throw new IllegalArgumentException(
+          String.format("input.type() differs from inputType: %s vs. %s", input.type(), inputType));
+    }
+
+    if (c != null) {
+      if (c.inputShape != null
+          && (c.batchSize != null || c.batchInputShape != null)) {
+        throw new IllegalArgumentException(
+            "inputShape cannot be provided together with batchSize or batchInputShape.");
+      }
+    }
+
+    Shape lShape;
+
+    if (c != null && c.batchInputShape != null) {
+      lShape = c.batchInputShape.takeLast(c.batchInputShape.numDimensions() - 1);
+      setBatchInputShape(c.batchInputShape);
+      if (getBatchSize() == null) {
+        setBatchSize(c.batchInputShape.size(0));
+      }
+    } else {
+      if (input == null) {
+        lShape = (c == null || c.inputShape == null) ? Shape.of(Shape.UNKNOWN_SIZE) : c.inputShape;
+      } else {
+        lShape = (c == null || c.inputShape == null) ? input.shape() : c.inputShape;
+      }
+
+      setBatchSize((c == null || c.batchSize == null) ? Shape.UNKNOWN_SIZE : c.batchSize);
+
+      setBatchInputShape(Shape.of(getBatchSize()).append(lShape));
+    }
+    setInputShape(lShape);
+
+    this.inputType = inputType == null ? input.type() : inputType;
+    super.build(lShape);
+    if (input != null) {
+      output = input;
+      placeholder = false;
+    } else {
+      output = getTF().placeholder(this.inputType, Placeholder.shape(getBatchInputShape()));
+      placeholder = true;
+    }
+  }
+
+  /**
+   * Gets the input Operand. This is a convenience method to create the input for a Model.
+   *
+   * @param tf the TensorFlow Ops.
+ * @param type the data type for the layer's weights and computation. + * @param the data type for the layer's calculations. + * @return the output + */ + public static Operand input( + Ops tf, Class type) { + return input(tf, type, null); + } + + /** + * Gets the input Operand. This is a convenience method to create the input for a Model. + * + * @param tf the TensorFlow Ops. + * @param type the data type for the layer's weights and computation. + * @param options the Layer options + * @param the data type for the layer's calculations. + * @return the output + */ + public static Operand input( + Ops tf, Class type, Options options) { + Input layer = new Input<>(tf, type, type, options); + return layer.getOutput(type); + } + + /** + * Gets the input Operand. This is a convenience method to create the input for a Model. + * + * @param tf the TensorFlow Ops. + * @param input The input + * @param type the data type for the layer's weights and computation. + * @param options the Layer options + * @param the data type for the layer's calculations. + * @return the output + */ + public static Operand input( + Ops tf, Operand input, Class type, Options options) { + Input layer = new Input<>(tf, input, type, options); + return layer.getOutput(); + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + return callPostProcess(Collections.singletonList(getOutput(resultType)), training); + } + + /** + * Gets the output Operand. + * + *

Note: a calling class should call this method directly, rather than calling one of the + * {@link #call} methods + * + * @return the output Operand. + */ + public Operand getOutput() { + return output; + } + + /** + * Gets the output Operand. + * + *

Note: a calling class should call this method directly, rather than calling one of the + * {@link #call} methods + * + * @param resultType the output data type + * @return the output Operand. + */ + public Operand getOutput(Class resultType) { + + return cast(getTF(), output, resultType); + } + + /** + * Identifies whether the output is a placeholder or not. + * + * @return true, if the output represents a placeholder + */ + public boolean isPlaceholder() { + return placeholder; + } + + /** + * The data type expected by the input. + * + * @return The data type expected by the input. + */ + public Class getInputType() { + return inputType; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java new file mode 100644 index 00000000000..53575b64d3d --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java @@ -0,0 +1,203 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.BiFunction; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Wraps arbitrary Java Lambda as a Layer. + * + *

The Lambda layer exists so that arbitrary TensorFlow functions can be used when + * constructing Sequential models. Lambda layers are best suited for + * simple operations or quick experimentation. + * + *

The Java lambda function has the form x = function(tf, input). The first
+ * argument is the TensorFlow Ops, the second argument is the input Operand. For example:
+ *
+ * <pre>{@code

+ * Lambda<TFloat32> lambda =
+ *     new Lambda<>(tf, (ops, input) -> ops.math.mul(ops.constant(2f), input), TFloat32.class);
+ * }</pre>
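+ *
+ * <p>The input is cast to the layer's type before the function is applied, so the function
+ * body can assume that its second argument has the layer's data type.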
+ *
+ * @param <T> the data type for the layer's weights and computation.
+ */
+public class Lambda extends Layer {
+  private BiFunction, Operand> function;
+
+  /**
+   * Creates a Lambda layer with no function, generating a unique name based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Lambda(Ops tf, Class type) {
+    this(tf, null, null, type, null);
+  }
+
+  /**
+   * Creates a Lambda layer with no function, generating a unique name based on
+   * {@link Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Lambda(Ops tf, Class type, Options options) {
+    this(tf, null, null, type, options);
+  }
+
+  /**
+   * Creates a Lambda layer with no function.
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer, if null, generates a unique name based on {@link
+   *     Class#getSimpleName()}.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Lambda(Ops tf, String name, Class type) {
+    this(tf, name, null, type, null);
+  }
+
+  /**
+   * Creates a Lambda layer with no function.
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer, if null, generates a unique name based on {@link
+   *     Class#getSimpleName()}.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Lambda(Ops tf, String name, Class type, Options options) {
+    this(tf, name, null, type, options);
+  }
+
+  /**
+   * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName()}
+   *
+   * @param tf the TensorFlow Ops
+   * @param function The Java lambda function in the form x = function(tf, input). The
+   *     first argument is the TensorFlow Ops, the second argument is the input Operand. If
+   *     function is null, then the input is returned unchanged.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Lambda(
+      Ops tf, BiFunction, Operand> function, Class type) {
+    this(tf, null, function, type, null);
+  }
+
+  /**
+   * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName()}
+   *
+   * @param tf the TensorFlow Ops
+   * @param function The Java lambda function in the form x = function(tf, input). The
+   *     first argument is the TensorFlow Ops, the second argument is the input Operand. If
+   *     function is null, then the input is returned unchanged.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Lambda(
+      Ops tf, BiFunction, Operand> function, Class type, Options options) {
+    this(tf, null, function, type, options);
+  }
+
+  /**
+   * Creates a Lambda layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer, if null, generates a unique name based on {@link
+   *     Class#getSimpleName()}.
+   * @param function the Java lambda function in the form x = function(tf, input). The
+   *     first argument is the TensorFlow Ops, the second argument is the input Operand. If
+   *     function is null, then the input is returned unchanged.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public Lambda(
+      Ops tf,
+      String name,
+      BiFunction, Operand> function,
+      Class type) {
+    this(tf, name, function, type, null);
+  }
+
+  /**
+   * Creates a Lambda layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer, if null, generates a unique name based on {@link
+   *     Class#getSimpleName()}.
+   * @param function the Java lambda function in the form x = function(tf, input). The
+   *     first argument is the TensorFlow Ops, the second argument is the input Operand. If
+   *     function is null, then the input is returned unchanged.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Lambda(
+      Ops tf,
+      String name,
+      BiFunction, Operand> function,
+      Class type,
+      Options options) {
+    super(tf, name, true, type, options);
+    this.function = function;
+  }
+
+  /**
+   * Sets the lambda function
+   *
+   * @param function the Java lambda function in the form
+   *     x = function(tf, input). The first argument is the TensorFlow Ops, the second
+   *     argument is the input Operand.
+   */
+  public void setLambda(BiFunction, Operand> function) {
+    this.function = function;
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+    if (inputs.isEmpty()) {
+      return Collections.emptyList();
+    }
+    Ops tf = getTF();
+    List> outputs = new ArrayList<>();
+    for (Operand input : inputs) {
+      if (function != null) {
+        Operand tInput = cast(tf, input, getType());
+        Operand result = function.apply(tf, tInput);
+        outputs.add(result);
+      } else {
+        outputs.add(cast(tf, input, getType()));
+      }
+    }
+    return convertTo(outputs, resultType);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java
new file mode 100644
index 00000000000..8df54232e1f
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java
@@ -0,0 +1,940 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.constraints.Constraint; +import org.tensorflow.framework.initializers.Initializer; +import org.tensorflow.framework.layers.impl.InputSpec; +import org.tensorflow.framework.layers.impl.VariableDef; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * The base abstract class for Layers. + * + *

A layer is a callable object that takes as input one or more tensors and that outputs one or
+ * more tensors. It involves computation, defined in the {@link #call} method, and a state
+ * (weight variables), defined either in the constructor or in the first call to the
+ * {@link #call} method.
+ *
+ *
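+ * <p>For example, a minimal subclass sketch (this {@code Square} layer is illustrative, not
+ * part of the framework):
+ *
+ * <pre>{@code
+ * public class Square<T extends TFloating> extends Layer<T> {
+ *
+ *   public Square(Ops tf, Class<T> type) {
+ *     super(tf, null, true, type, null);
+ *   }
+ *
+ *   @Override
+ *   public <U extends TType> List<Operand<U>> call(
+ *       List<Operand<? extends TType>> inputs,
+ *       List<Operand<TBool>> masks,
+ *       boolean training,
+ *       Class<U> resultType) {
+ *     List<Operand<U>> results = new ArrayList<>();
+ *     for (Operand<? extends TType> input : inputs) {
+ *       Operand<T> t = cast(getTF(), input, getType());
+ *       results.add(cast(getTF(), getTF().math.mul(t, t), resultType));
+ *     }
+ *     return callPostProcess(results, training);
+ *   }
+ * }
+ * }</pre>
+ *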

Users will just instantiate a layer and then treat it as a callable.
+ *
+ * @param <T> the data type for the layer's weights and computation.
+ */
+public abstract class Layer {
+
+  private static final Map nameMap = new HashMap<>();
+  private final String name;
+  private final Class type;
+  private final List> weights = new ArrayList<>();
+  private final List> trainableWeights = new ArrayList<>();
+  private final List> nonTrainableWeights = new ArrayList<>();
+  private final List losses = new ArrayList<>();
+  private final List> metrics = new ArrayList<>();
+  private final Map, VariableDef> variableMap = new HashMap<>();
+  // Note: the layer captures its own sub-scoped Ops instance in the constructor; the idea
+  // is that the model can be built with the layers before the model uses them, probably
+  // during the model.compile phase.
+  private final Ops tf;
+  private boolean trainable;
+  // TODO change to Regularizer class
+  private Object activityRegularizer;
+  private boolean built;
+  private boolean stateful;
+  private boolean supportsMasking;
+  // These are the inputShapes as presented to build
+  private List inputShapes;
+  private List inputSpecs;
+  // These are the shapes/dimensions presented by Options.
+  private Shape batchInputShape;
+  private Long batchSize;
+  private Shape inputShape;
+  private Options instanceOptions;
+
+  /**
+   * Creates the base Layer class
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer, if null will generate a name using the method
+   *     {@link #genName()}
+   * @param trainable whether the layer's variables should be trainable or not.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public Layer(Ops tf, String name, boolean trainable, Class type, Options options) {
+    this.name = name == null ? genName() : name;
+    this.setTrainable(trainable);
+    this.type = type;
+    this.batchSize = Shape.UNKNOWN_SIZE;
+    loadOptions(options);
+    this.tf = tf.withSubScope(getName());
+  }
+
+  private void loadOptions(Options options) {
+    if (options != null) {
+      instanceOptions = options;
+      if (instanceOptions.batchInputShape != null) {
+        this.batchInputShape = instanceOptions.batchInputShape;
+      }
+      if (instanceOptions.batchSize != null) {
+        this.batchSize = instanceOptions.batchSize;
+      }
+      if (instanceOptions.inputShape != null) {
+        this.inputShape = instanceOptions.inputShape;
+      }
+      if (instanceOptions.activityRegularizer != null) {
+        this.activityRegularizer = instanceOptions.activityRegularizer;
+      }
+      if (instanceOptions.metrics != null) {
+        this.metrics.addAll(instanceOptions.metrics);
+      }
+      if (instanceOptions.losses != null) {
+        this.losses.addAll(instanceOptions.losses);
+      }
+    }
+  }
+
+  /**
+   * Generates a unique name by appending an integer value to the {@link Class#getSimpleName} in
+   * the form {@link Class#getSimpleName}_&lt;identifier&gt;, e.g. Dense_1.
+   * The first call to generate a unique name will only return {@link Class#getSimpleName}
+   * without the suffix, e.g. Dense.
+   *
+   * @return the generated name for the class.
+   */
+  private String genName() {
+    String base = getClass().getSimpleName();
+    Integer id = nameMap.get(base);
+    if (id == null) {
+      nameMap.put(base, 0);
+      return base;
+    } else {
+      id++;
+      nameMap.put(base, id);
+      return String.format("%s_%d", base, id);
+    }
+  }
+
+  /**
+   * Invokes the layer's algorithm using a single input, returning a single output. Training mode
+   * is true.
+   *
+   *

This is a convenience call on top of {@link #call}. + * + * @param input the input Operand + * @return the output Operand, or null if no output is generated from the layer's logic. + */ + public Operand call(Operand input) { + return call(input, null, true, getType()); + } + + /** + * Invokes the layer's algorithm using a single input, returning a single output. Training mode is + * true. + * + *
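+ * <p>For example (sketch; {@code layer} is any built {@code Layer<TFloat32>} and {@code input}
+ * an existing operand):
+ *
+ * <pre>{@code
+ * Operand<TFloat32> output = layer.call(input, TFloat32.class);
+ * }</pre>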

This is a convenience call on top of {@link #call}. + * + * @param input the input Operand + * @param type the data type for the result + * @return the output Operand, or null if no output is generated from the layer's logic. + */ + public Operand call(Operand input, Class type) { + return call(input, null, true, type); + } + + /** + * Invokes the layer's algorithm using a single input, returning a single output. + * + *

This is a convenience call on top of {@link #call}. + * + * @param input the input Operand + * @param training true for training mode, false for inference mode + * @param type the data type for the result + * @return the output Operand, or null if no output is generated from the layer's logic. + */ + public Operand call( + Operand input, boolean training, Class type) { + return call(input, null, training, type); + } + + /** + * Invokes the layer's algorithm using a single input, returning a single output. + * + *

This is a convenience call on top of {@link #call}. + * + * @param input the input Operand + * @param mask the mask to apply to the result, may be null + * @param training true for training mode, false for inference mode + * @param type the data type for the result + * @return the output Operand, or null if no output is generated from the layer's logic. + */ + public Operand call( + Operand input, Operand mask, boolean training, Class type) { + List> result = + call(Collections.singletonList(input), Collections.singletonList(mask), training, type); + return result != null ? result.get(0) : null; + } + + /** + * Invokes the layer's algorithm. Training mode is true. + * + * @param inputs the input Operands + * @param type the data type for the result + * @return the output Operands + */ + public List> call( + List> inputs, Class type) { + return call(inputs, null, true, type); + } + + /** + * Invokes the layer's logic using a list of inputs, returning a list of outputs. + * + * @param inputs the input Operands + * @param masks a list of masks, one for each input, to apply to the result, may be null + * @param training true for training mode, false for inference mode + * @param type the data type for the result + * @return the output Operands. + */ + public abstract List> call( + List> inputs, + List> masks, + boolean training, + Class type); + + /** + * Post processes a layer's call result + * + * @param inputs the input Operands + * @param training true for training mode, false for inference mode + * @return the output Operands. + */ + protected List> callPostProcess( + List> inputs, boolean training) { + return handleActivityRegister(inputs); + } + + /** + * Converts a list of inputs to a new list of the internal data type defined for this layer. + * + * @param inputs the inputs. + * @return the new list converted to the new type. + */ + protected List> convertList(List> inputs) { + return convertList(inputs, getType()); + } + + /** + * Converts a list of inputs to a new list of the specified result type. + * + * @param inputs the inputs. + * @param resultType the data type for the result. + * @return the new list converted to the new type. + */ + protected List> convertList( + List> inputs, Class resultType) { + List> result = new ArrayList<>(); + inputs.forEach(input -> result.add(cast(getTF(), input, resultType))); + return result; + } + + /** + * Converts a list of inputs of this layer's data type to a new list of the given type + * + * @param inputs the inputs. + * @param newType the new type. + * @param the data type for the new type. + * @return the new list converted to the new type. + */ + protected List> convertTo( + List> inputs, Class newType) { + List> result = new ArrayList<>(); + inputs.forEach(input -> result.add(cast(getTF(), input, newType))); + return result; + } + + private List> handleActivityRegister(List> inputs) { + if (this.activityRegularizer != null) { + // TODO activityRegularizer + return inputs; + } else { + return inputs; + } + } + + /** + * Creates the variables of the layer (optional, for subclass implementers). This is a method that + * implementers of subclasses of Layer or Model can override if they + * need a state-creation step in-between layer instantiation and layer call. This is typically + * used to create the weights of Layer subclasses. + * + *
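+ * <p>For example, a subclass with a single trainable weight might override the {@code List}
+ * form roughly as follows (a sketch only; {@code kernel}, {@code units}, {@code initializer}
+ * and {@code seed} are assumed fields of the hypothetical subclass):
+ *
+ * <pre>{@code
+ * protected void build(List<Shape> inputShapes) {
+ *   long lastDim = inputShapes.get(0).size(1); // assumes rank-2 (batch, features) inputs
+ *   kernel = addWeight("kernel", Shape.of(lastDim, units), initializer, null, true, seed);
+ *   super.build(inputShapes);
+ * }
+ * }</pre>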

This method is a convenience method that calls {@link #build(List)}. + * + * @param inputShape the shapes of the inputs, one per input + */ + protected void build(Shape... inputShape) { + build(Arrays.asList(inputShape)); + } + + /** + * Creates the variables of the layer (optional, for subclass implementers). This is a method that + * implementers of subclasses of Layer or Model can override if they + * need a state-creation step in-between layer instantiation and layer call. This is typically + * used to create the weights of Layer subclasses. + * + * @param inputShapes the shapes of the inputs, one per input + * @throws IllegalStateException if the TensorFlow Ops is null. + */ + protected void build(List inputShapes) { + if (tf == null) throw new IllegalStateException("The TensorFlow Ops has not been set yet"); + built = true; + this.inputShapes = inputShapes; + } + + /** + * Computes the output shape of the layer. + * + *

This implementation calls {@link #build(List)} if not already called, and returns the input + * shapes as the output shapes; for example, an input shape of (batchSize, 10) yields an output + * shape of (batchSize, 10). Subclasses may want to alter this default behavior. + * + *

If the layer has not been built, this method will call {@link #build(List)} on the layer. + * This assumes that the layer will later be used with inputs that match the input shape provided + * here. + * + * @param inputShapes the input shapes, one per input + * @return the output shapes, one per output + */ + public List computeOutputShape(List inputShapes) { + if (!built) build(inputShapes); + return inputShapes; + } + + /** + * Gets the unique name for this layer + * + * @return the unique name for this layer + */ + public String getName() { + return name; + } + + /** + * Gets the trainable setting + * + * @return true, if this layer is trainable + */ + public boolean isTrainable() { + return trainable; + } + + /** + * Sets the trainable indicator + * + * @param trainable the trainable indicator + */ + public void setTrainable(boolean trainable) { + this.trainable = trainable; + } + + /** + * Gets the data type for the layer's weights and computation. + * + * @return the data type for the layer's weights and computation. + */ + public Class getType() { + return type; + } + + /** + * Gets the layer's weights + * + * @return the layer's weights + */ + public List> getWeights() { + return weights; + } + + /** + * Sets the layer's weights, replacing any existing weights + * + * @param weights the layer's new weights + */ + public void setWeights(List> weights) { + this.weights.clear(); + this.weights.addAll(weights); + } + + /** + * Gets the layer's trainable weights + * + * @return the layer's trainable weights + */ + public List> getTrainableWeights() { + return trainableWeights; + } + + /** + * Gets the layer's non-trainable weights + * + * @return the layer's non-trainable weights + */ + public List> getNonTrainableWeights() { + return nonTrainableWeights; + } + + /** + * Adds a weight to the layer + * + * @param name the variable's name + * @param shape the variable's shape + * @param initializer the variable initializer + * @param constraint the constraint to apply to the variable, may be null + * @param trainable whether the variable should be part of the layer's "trainableWeights" + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and type. + * @return the variable created and added to the layer + * @throws IllegalStateException if the property {@link #tf} has not been set yet. + */ + public Variable addWeight( + String name, + Shape shape, + Initializer initializer, + Constraint constraint, + boolean trainable, + long seed) { + if (tf == null) { + throw new IllegalStateException("Parameter \"tf\" has not been set"); + } + VariableDef variableDef = + new VariableDef<>(tf, name, shape, initializer, constraint, trainable, seed, getType()); + + Variable variable = variableDef.getVariable(); + + variableMap.put(variable, variableDef); + weights.add(variable); + if (trainable) trainableWeights.add(variable); + else nonTrainableWeights.add(variable); + return variable; + } + + /** + * Adds a weight to the layer + * + * @param name the variable's name + * @param variable the variable to add + * @param initializer the variable initializer + * @param constraint the constraint to apply to the variable, may be null + * @param trainable whether the variable should be part of the layer's "trainableWeights" + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and type. + * @return the variable added to the layer + * @throws IllegalStateException if the property {@link #tf} has not been set yet.
+ */ + public Variable addWeight( + String name, + Variable variable, + Initializer initializer, + Constraint constraint, + boolean trainable, + long seed) { + if (tf == null) { + throw new IllegalStateException("Parameter \"tf\" has not been set"); + } + if (variable == null) { + throw new IllegalStateException("Parameter \"variable\" has not been set"); + } + VariableDef variableDef = + new VariableDef<>(tf, name, variable, initializer, constraint, trainable, seed); + variableMap.put(variable, variableDef); + weights.add(variable); + if (trainable) trainableWeights.add(variable); + else nonTrainableWeights.add(variable); + return variable; + } + + /** + * Gets the Operands that initializes all the weights + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and type. + * @return the Operands that initializes all the weights + */ + public List> initializeWeights(long seed) { + List> result = new ArrayList<>(); + weights.forEach(w -> result.add(initializeWeight(w, seed))); + return result; + } + + /** + * Creates an Operand that initializes a weight + * + * @param weight the weight to initialize + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and type. + * @return the Operand that initializes the weight + * @throws IllegalArgumentException if the weight does not have a registered initializer + */ + public Operand initializeWeight(Variable weight, long seed) { + VariableDef varDef = variableMap.get(weight); + if (varDef == null) { // this should not happen if addWeight was used to create/add the weight + addWeight(null, weight, null, null, true, seed); + varDef = variableMap.get(weight); + } + return varDef.init(); + } + + /** + * Computes an output mask tensor. + * + * @param inputs the input Operands + * @param masks the mask Operands. + * @return null or a list of Operands, one for each output from the layer, + */ + @SuppressWarnings("UnusedParameters") + public List> computeMask( + List> inputs, List> masks) { + // the default implementation merely returns the masks. 
+ if (isSupportsMasking()) { + if (masks == null) return null; + return masks.stream().map(m -> cast(getTF(), m, TBool.class)).collect(Collectors.toList()); + } + if (masks != null && !masks.isEmpty()) { + throw new IllegalArgumentException( + String.format("%s does not support masking, but was passed a mask", getName())); + } + + return null; + } + + /** + * Gets the Losses assigned to this layer + * + * @return the Losses assigned to this layer + */ + public List getLosses() { + return losses; + } + + /** + * Adds a loss to this layer + * + * @param loss the loss to add + */ + public void addLoss(Loss loss) { + losses.add(loss); + } + + /** + * Adds losses to this layer + * + * @param losses the losses to add + */ + public void addLosses(List losses) { + this.losses.addAll(losses); + } + + /** + * Gets the metrics assigned to this layer + * + * @return the metrics assigned to this layer + */ + public List> getMetrics() { + return metrics; + } + + /** + * Adds a metric to this layer + * + * @param metric the metric to add + */ + public void addMetric(Metric metric) { + metrics.add(metric); + } + + /** + * Adds metrics to this layer + * + * @param metrics the metrics to add + */ + public void addMetrics(List> metrics) { + this.metrics.addAll(metrics); + } + + /** + * Determines whether or not the build method has been called. + * + * @return true, if the build method has been called. + */ + public boolean isBuilt() { + return built; + } + + /** + * Sets the build indicator + * + * @param built the build indicator + */ + public void setBuilt(boolean built) { + this.built = built; + } + + /** + * Gets the input shapes, one per input + * + * @return the input shapes, one per input + */ + public List getInputShapes() { + return inputShapes; + } + + /** + * Sets the input shapes, one per input + * + * @param inputShapes the input shapes + */ + public void setInputShapes(List inputShapes) { + this.inputShapes = inputShapes; + } + + /** + * Adds an inputSpec + * + * @param inputSpec the inputSpec + */ + public void addInputSpec(InputSpec inputSpec) { + if (inputSpecs == null) { + inputSpecs = new ArrayList<>(); + } + inputSpecs.add(inputSpec); + } + + /** + * Gets the inputSpecs, one per input + * + * @return the inputSpecs, one per input + */ + public List getInputSpecs() { + return inputSpecs; + } + + /** + * Sets the inputSpecs, one per input + * + * @param inputSpecs the inputSpecs + */ + public void setInputSpecs(List inputSpecs) { + this.inputSpecs = inputSpecs; + } + + /** + * Gets the {@link #tf} property + * + * @return the {@link #tf} property + */ + public Ops getTF() { + return tf; + } + + /** + * Gets the stateful property. + * + *

A stateful layer is a layer whose updates are run during inference too, for instance + * stateful RNNs. + * + * @return true, if this layer is stateful + */ + public boolean isStateful() { + return stateful; + } + + /** + * Sets the stateful property. + * + *

A stateful layer is a layer whose updates are run during inference too, for instance + * stateful RNNs. + * + * @param stateful true, if this layer is stateful. + */ + public void setStateful(boolean stateful) { + this.stateful = stateful; + } + + /** + * Gets the batch input shape + * + * @return the batch input shape + */ + public Shape getBatchInputShape() { + return batchInputShape; + } + + /** + * Sets the batch input shape + * + * @param batchInputShape the batch input shape + */ + public void setBatchInputShape(Shape batchInputShape) { + this.batchInputShape = batchInputShape; + } + + /** + * Gets the batch size + * + * @return the batch size + */ + public Long getBatchSize() { + return batchSize; + } + + /** + * Sets the batch size + * + * @param batchSize the batch size + */ + public void setBatchSize(Long batchSize) { + this.batchSize = batchSize; + } + + /** + * Gets the input shape for this layer + * + * @return the input shape for this layer + */ + public Shape getInputShape() { + return inputShape; + } + + /** + * Sets the input shape for this layer + * + * @param inputShape the input shape for this layer + */ + public void setInputShape(Shape inputShape) { + this.inputShape = inputShape; + } + + /** + * Gets the options instance for this layer. + * + * @return the options instance for this layer. + */ + public Options getInstanceOptions() { + return instanceOptions; + } + + /** + * Gets the activity Regularizer + * + * @return the activity Regularizer + */ + // TODO change to Regularizer class + public Object getActivityRegularizer() { + return activityRegularizer; + } + + /** + * Sets the activity Regularizer + * + * @param activityRegularizer the activity Regularizer + */ + // TODO change to Regularizer class + public void setActivityRegularizer(Object activityRegularizer) { + this.activityRegularizer = activityRegularizer; + } + + /** + * Gets the indicator that this layer supports masking. + * + * @return the indicator that this layer supports masking. + */ + public boolean isSupportsMasking() { + return supportsMasking; + } + + /** + * Sets the indicator that this layer supports masking. + * + * @param supportsMasking the indicator that this layer supports masking. + */ + public void setSupportsMasking(boolean supportsMasking) { + this.supportsMasking = supportsMasking; + } + + /** + * Assigns a value to the variable + * + * @param variable the variable to assign to + * @param value the value to assign + * @return the operand that assigns the value to this variable + * @throws IllegalStateException if the variable is not known. + */ + public Operand assign(Variable variable, Operand value) { + VariableDef varDef = variableMap.get(variable); + if (varDef == null) { + throw new IllegalStateException(String.format("Variable %s was not found.", variable)); + } + return varDef.assign(value); + } + + /** + * Adds a value to the variable + * + * @param variable the variable to add to + * @param value the value to add + * @return the operand that adds the value to this variable + * @throws IllegalStateException if the variable is not known.
+ */ + public Operand assignAdd(Variable variable, Operand value) { + VariableDef varDef = variableMap.get(variable); + if (varDef == null) { + throw new IllegalStateException(String.format("Variable %s was not found.", variable)); + } + return varDef.assignAdd(value); + } + + /** + * Subtracts a value from the variable + * + * @param variable the variable to subtract from + * @param value the value to subtract + * @return the operand that subtracts the value from this variable + * @throws IllegalStateException if the variable is not known. + */ + public Operand assignSub(Variable variable, Operand value) { + VariableDef varDef = variableMap.get(variable); + if (varDef == null) { + throw new IllegalStateException(String.format("Variable %s was not found.", variable)); + } + return varDef.assignSub(value); + } + + /** Optional attributes for {@link Layer} */ + public static class Options { + protected Shape inputShape; + protected Shape batchInputShape; + protected Long batchSize; + protected List> metrics; + protected List losses; + // TODO change to Regularizer class + protected Object activityRegularizer; + + public static Options create() { + return new Options(); + } + + /** + * Sets the inputShape + * + * @param inputShape the input shape for the layer + * @return this Options instance + */ + public Layer.Options inputShape(Shape inputShape) { + this.inputShape = inputShape; + return this; + } + + /** + * Sets the batchSize + * + * @param batchSize the batch size for the layer + * @return this Options instance + */ + public Layer.Options batchSize(Long batchSize) { + this.batchSize = batchSize; + return this; + } + + /** + * Sets the batch input shape + * + * @param batchInputShape the batch input shape for the layer + * @return this Options instance + */ + public Layer.Options batchInputShape(Shape batchInputShape) { + this.batchInputShape = batchInputShape; + return this; + } + + /** + * Sets the activityRegularizer + * + * @param activityRegularizer the activity Regularizer + * @return this Options instance + */ + // TODO change to Regularizer class + public Layer.Options activityRegularizer(Object activityRegularizer) { + this.activityRegularizer = activityRegularizer; + return this; + } + + /** + * Adds a metric + * + * @param metric the metric + * @return this Options instance + */ + public Layer.Options metric(Metric metric) { + if (this.metrics == null) { + this.metrics = new ArrayList<>(); + } + metrics.add(metric); + return this; + } + + /** + * Adds metrics + * + * @param metrics the metrics to add + * @return this Options instance + */ + public Layer.Options metrics(List> metrics) { + if (this.metrics == null) { + this.metrics = new ArrayList<>(metrics); + } else { + this.metrics.addAll(metrics); + } + return this; + } + + /** + * Adds a loss + * + * @param loss the Loss + * @return this Options instance + */ + public Layer.Options loss(Loss loss) { + if (losses == null) { + losses = new ArrayList<>(); + } + losses.add(loss); + return this; + } + + /** + * Adds losses + * + * @param losses the losses to add + * @return this Options instance + */ + public Layer.Options losses(List losses) { + if (this.losses == null) { + this.losses = new ArrayList<>(losses); + } else { + this.losses.addAll(losses); + } + return this; + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java new file mode 100644 index 00000000000..57a5d23b291 --- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java @@ -0,0 +1,108 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.activations.ReLU; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +/** + * Leaky version of a Rectified Linear Unit. + * + *

It allows a small gradient when the unit is not active: + * + *

+ *     f(x) = alpha * x if x < 0
+ *     f(x) = x if x >= 0
+ * 
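+ * <p>For example, a usage sketch (assumes an existing {@code Operand<TFloat32> input}; the
+ * layer name and alpha value are illustrative):
+ *
+ * <pre>{@code
+ * LeakyReLU<TFloat32> leaky =
+ *     new LeakyReLU<>(tf, "leaky_relu", 0.2f, TFloat32.class,
+ *         Layer.Options.create().inputShape(Shape.of(10)));
+ * Operand<TFloat32> output = leaky.call(input, TFloat32.class);
+ * }</pre>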
+ * + * @param the data type for the layer's weights and computation. + */ +public class LeakyReLU extends Layer { + public static final float DEFAULT_ALPHA = 0.3f; + + private final float alpha; + + /** + * Creates a LeakyReLU Layer with a unique name generated based on {@link Class#getSimpleName()} + * and {@link #DEFAULT_ALPHA} for the alpha value. + * + * @param tf the TensorFlow Ops. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public LeakyReLU(Ops tf, Class type, Options options) { + this(tf, null, DEFAULT_ALPHA, type, options); + } + + /** + * Creates a LeakyReLU Layer with {@link #DEFAULT_ALPHA} for the alpha value. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param type the data type for the layer's weights and computation. + */ + public LeakyReLU(Ops tf, String name, Class type) { + this(tf, name, DEFAULT_ALPHA, type, null); + } + + /** + * Creates a LeakyReLU Layer with a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops. + * @param alpha Negative slope coefficient. Must be >= 0. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public LeakyReLU(Ops tf, float alpha, Class type, Options options) { + this(tf, null, alpha, type, options); + } + + /** + * Creates a LeakyReLU Layer + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param alpha Negative slope coefficient. Must be >= 0. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public LeakyReLU(Ops tf, String name, float alpha, Class type, Options options) { + super(tf, name, true, type, options); + this.alpha = alpha; + setSupportsMasking(true); + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + + ReLU reLU = new ReLU<>(getTF(), alpha, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); + List> tInputs = convertList(inputs); + List> results = new ArrayList<>(); + tInputs.forEach(input -> results.add(reLU.call(input))); + return callPostProcess(convertTo(results, resultType), training); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java new file mode 100644 index 00000000000..ea09fafaf98 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that computes the maximum (element-wise) of a list of inputs. + * + *
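+ * <p>For example (sketch; {@code a} and {@code b} are assumed to be existing operands of the
+ * same shape):
+ *
+ * <pre>{@code
+ * Maximum<TFloat32> max = new Maximum<>(tf, TFloat32.class);
+ * Operand<TFloat32> result = max.call(Arrays.asList(a, b), TFloat32.class).get(0);
+ * }</pre>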

It takes as input a list of tensors, all of the same shape, and returns a single tensor (also + * of the same shape). + * + * @param the data type for the layer's weights and computation. + */ +public class Maximum extends Merge { + + /** + * Creates a Maximum Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + */ + public Maximum(Ops tf, Class type) { + this(tf, null, type, null); + } + + /** + * Creates a Maximum Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Maximum(Ops tf, Class type, Options options) { + this(tf, null, type, options); + } + + /** + * Creates a Maximum Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Maximum(Ops tf, String name, Class type) { + this(tf, name, type, null); + } + + /** + * Creates a Maximum Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Maximum(Ops tf, String name, Class type, Options options) { + super(tf, name, type, options); + } + + /** {@inheritDoc} */ + @Override + protected Operand mergeFunction(List> inputs) { + Ops tf = getTF(); + Operand output = cast(tf, tf.identity(inputs.get(0)), getType()); + for (int i = 1; i < inputs.size(); i++) { + output = tf.math.maximum(output, cast(tf, inputs.get(i), getType())); + } + return output; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java new file mode 100644 index 00000000000..a1fe012689c --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java @@ -0,0 +1,92 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that computes the minimum (element-wise) of a list of inputs. + * + *

It takes as input a list of tensors, all of the same shape, and returns a single tensor (also + * of the same shape). + * + * @param the data type for the layer's weights and computation. + */ +public class Minimum extends Merge { + + /** + * Creates a Minimum Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + */ + public Minimum(Ops tf, Class type) { + this(tf, null, type, null); + } + + /** + * Creates a Minimum Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Minimum(Ops tf, Class type, Options options) { + this(tf, null, type, options); + } + + /** + * Creates a Minimum Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Minimum(Ops tf, String name, Class type) { + this(tf, name, type, null); + } + + /** + * Creates a Minimum Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Minimum(Ops tf, String name, Class type, Options options) { + super(tf, name, type, options); + } + + /** {@inheritDoc} */ + @Override + protected Operand mergeFunction(List> inputs) { + Ops tf = getTF(); + Operand output = cast(tf, tf.identity(inputs.get(0)), getType()); + for (int i = 1; i < inputs.size(); i++) { + output = tf.math.minimum(output, cast(tf, inputs.get(i), getType())); + } + return output; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java new file mode 100644 index 00000000000..fbb190cb00b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that multiplies (element-wise) a list of inputs. + * + *

It takes as input a list of tensors, all of the same shape, and returns a single tensor (also + * of the same shape). + * + * @param the data type for the layer's weights and computation. + */ +public class Multiply extends Merge { + + /** + * Creates a Multiply Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + */ + public Multiply(Ops tf, Class type) { + this(tf, null, type); + } + + /** + * Creates a Multiply Layer using {@link Class#getSimpleName()} as the layer name. + * + * @param tf the TensorFlow Ops + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Multiply(Ops tf, Class type, Options options) { + this(tf, null, type, options); + } + + /** + * Creates a Multiply Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + */ + public Multiply(Ops tf, String name, Class type) { + this(tf, name, type, null); + } + + /** + * Creates a Multiply Layer + * + * @param tf the TensorFlow Ops + * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} + * @param type the data type for the weights and computation + * @param options the layer's options + */ + public Multiply(Ops tf, String name, Class type, Options options) { + super(tf, name, type, options); + } + + /** {@inheritDoc} */ + @Override + protected Operand mergeFunction(List> inputs) { + Ops tf = getTF(); + Operand output = cast(tf, tf.identity(inputs.get(0)), getType()); + for (int i = 1; i < inputs.size(); i++) { + output = tf.math.mul(output, cast(tf, inputs.get(i), getType())); + } + return output; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java new file mode 100644 index 00000000000..74307780ce0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java @@ -0,0 +1,233 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +/** + * Rectified Linear Unit activation layer. + * + *

With default values, it returns element-wise {@code max(x, 0)} + * + *

Otherwise, it follows: + * + *

+ *    f(x) = max_value if x >= max_value
+ *     f(x) = x if threshold <= x < max_value
+ *     f(x) = negative_slope * (x - threshold) otherwise
+ * 
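+ * <p>For example, with {@code negativeSlope = 0.5f}, {@code maxValue = 6f} and
+ * {@code threshold = 1f} (illustrative values only):
+ *
+ * <pre>
+ *     input:  -2.0   0.0   1.0   3.0   8.0
+ *     output: -1.5  -0.5   1.0   3.0   6.0
+ * </pre>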
+ * + * @param the data type for the layer's weights and computation. + */ +public class ReLU extends Layer { + public static final float DEFAULT_MAX_VALUE = Float.NaN; + public static final float DEFAULT_NEGATIVE_SLOPE = 0; + public static final float DEFAULT_THRESHOLD = 0; + + private final float maxValue; + private final float negativeSlope; + private final float threshold; + + /** + * Creates a ReLU Layer with a unique name generated based on {@link Class#getSimpleName()} and + * using {@link #DEFAULT_NEGATIVE_SLOPE} for the negative slope, {@link #DEFAULT_MAX_VALUE} as the + * maximum value and {@link #DEFAULT_THRESHOLD} as the threshold. + * + * @param tf the TensorFlow Ops + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public ReLU(Ops tf, Class type, Options options) { + this(tf, null, DEFAULT_NEGATIVE_SLOPE, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, options); + } + + /** + * Creates a ReLU Layer using {@link #DEFAULT_NEGATIVE_SLOPE} for the negative slope, {@link + * #DEFAULT_MAX_VALUE} as the maximum value and {@link #DEFAULT_THRESHOLD} as the threshold. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param type the data type for the layer's weights and computation. + */ + public ReLU(Ops tf, String name, Class type) { + this(tf, name, DEFAULT_NEGATIVE_SLOPE, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, null); + } + + /** + * Creates a ReLU Layer using {@link #DEFAULT_NEGATIVE_SLOPE} for the negative slope, {@link + * #DEFAULT_MAX_VALUE} as the maximum value and {@link #DEFAULT_THRESHOLD} as the threshold. + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public ReLU(Ops tf, String name, Class type, Options options) { + this(tf, name, DEFAULT_NEGATIVE_SLOPE, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, options); + } + + /** + * Creates a ReLU Layer with a unique name generated based on {@link Class#getSimpleName()}, + * using {@link #DEFAULT_MAX_VALUE} as the maximum value and {@link #DEFAULT_THRESHOLD} as the + * threshold. + * + * @param tf the TensorFlow Ops. + * @param negativeSlope Negative slope coefficient. Must be >= 0. + * @param type the data type for the layer's weights and computation. + * @throws IllegalArgumentException if negativeSlope is < 0 + */ + public ReLU(Ops tf, float negativeSlope, Class type) { + this(tf, null, negativeSlope, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, null); + } + + /** + * Creates a ReLU Layer with a unique name generated based on {@link Class#getSimpleName()}, + * using {@link #DEFAULT_MAX_VALUE} as the maximum value and {@link #DEFAULT_THRESHOLD} as the + * threshold. + * + * @param tf the TensorFlow Ops. + * @param negativeSlope Negative slope coefficient. Must be >= 0. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options.
+ * @throws IllegalArgumentException if negativeSlope is < 0 + */ + public ReLU(Ops tf, float negativeSlope, Class type, Options options) { + this(tf, null, negativeSlope, DEFAULT_MAX_VALUE, DEFAULT_THRESHOLD, type, options); + } + + /** + * Creates a ReLU Layer with a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param negativeSlope Negative slope coefficient. Must be >= 0. + * @param maxValue Maximum activation value. Must be >= 0. + * @param threshold Threshold value for thresholded activation. + * @param type the data type for the layer's weights and computation. + * @throws IllegalArgumentException if maxValue or negativeSlope is < 0 + */ + public ReLU( + Ops tf, + float negativeSlope, + float maxValue, + float threshold, + Class type) { + this(tf, null, negativeSlope, maxValue, threshold, type, null); + } + + /** + * Creates a ReLU Layer with a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param negativeSlope Negative slope coefficient. Must be >= 0. + * @param maxValue Maximum activation value. Must be >= 0. + * @param threshold Threshold value for thresholded activation. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + * @throws IllegalArgumentException if maxValue or negativeSlope is < 0 + */ + public ReLU( + Ops tf, + float negativeSlope, + float maxValue, + float threshold, + Class type, + Options options) { + this(tf, null, negativeSlope, maxValue, threshold, type, options); + } + + /** + * Creates a ReLU Layer + * + * @param tf the TensorFlow Ops + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param negativeSlope Negative slope coefficient. Must be >= 0. + * @param maxValue Maximum activation value. Must be >= 0. + * @param threshold Threshold value for thresholded activation. + * @param type the data type for the layer's weights and computation. + * @throws IllegalArgumentException if maxValue or negativeSlope is < 0 + */ + public ReLU( + Ops tf, + String name, + float negativeSlope, + float maxValue, + float threshold, + Class type) { + this(tf, name, negativeSlope, maxValue, threshold, type, null); + } + + /** + * Creates a ReLU Layer + * + * @param tf the TensorFlow Ops + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param negativeSlope Negative slope coefficient. Must be >= 0. + * @param maxValue Maximum activation value. Must be >= 0. + * @param threshold Threshold value for thresholded activation. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options.
+ * @throws IllegalArgumentException if maxValue or negativeSlope is < 0 + */ + public ReLU( + Ops tf, + String name, + float negativeSlope, + float maxValue, + float threshold, + Class type, + Options options) { + super(tf, name, true, type, options); + if (!Float.isNaN(maxValue) && maxValue < 0) { + throw new IllegalArgumentException("maxValue must be >= 0, got " + maxValue); + } + if (negativeSlope < 0) { + throw new IllegalArgumentException("negativeSlope must be >= 0, got " + negativeSlope); + } + + this.maxValue = maxValue; + this.negativeSlope = negativeSlope; + this.threshold = threshold; + setSupportsMasking(true); + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + + org.tensorflow.framework.activations.ReLU reLU = + new org.tensorflow.framework.activations.ReLU<>( + getTF(), negativeSlope, maxValue, threshold); + List> tInputs = convertList(inputs); + List> results = new ArrayList<>(); + tInputs.forEach(input -> results.add(reLU.call(input))); + return callPostProcess(convertTo(results, resultType), training); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java new file mode 100644 index 00000000000..32594d72f8a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java @@ -0,0 +1,123 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Repeats the input {@code repeatCount} times. + * + * @param the data type for the layer's weights and computation. + */ +public class RepeatVector extends Layer { + + private final int repeatCount; + + /** + * Creates a RepeatVector with a unique name generated based on {@link Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param repeatCount the repetition factor. + * @param type the data type for the layer's weights and computation. + */ + public RepeatVector(Ops tf, int repeatCount, Class type) { + this(tf, null, repeatCount, type, null); + } + + /** + * Creates a RepeatVector with a unique name generated based on {@link Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param repeatCount the repetition factor. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options.
+ */ + public RepeatVector(Ops tf, int repeatCount, Class type, Options options) { + this(tf, null, repeatCount, type, options); + } + + /** + * Creates a RepeatVector + * + * @param tf the TensorFlow Ops + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param repeatCount the repetition factor. + * @param type the data type for the layer's weights and computation. + */ + public RepeatVector(Ops tf, String name, int repeatCount, Class type) { + this(tf, name, repeatCount, type, null); + } + + /** + * Creates a RepeatVector + * + * @param tf the TensorFlow Ops + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param repeatCount the repetition factor. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public RepeatVector(Ops tf, String name, int repeatCount, Class type, Options options) { + super(tf, name, true, type, options); + this.repeatCount = repeatCount; + } + + /** + * Repeats each input {@code repeatCount} times along a new second dimension; for example, an + * input of shape (numSamples, features) becomes (numSamples, repeatCount, features). + * + * @param inputs the input Operands, 2D tensor of shape (num_samples, features) + * @param masks a list of masks, one for each input, to apply to the result, may be null + * @param training true for training mode, false for inference mode + * @param resultType the result type + * @param the data type of the result + * @return a 3D tensor of shape (num_samples, repeatCount, features) + */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + if (inputs.isEmpty()) { + return Collections.emptyList(); + } + Ops tf = getTF(); + List> outputs = new ArrayList<>(); + for (Operand input : inputs) { + if (input.shape().numDimensions() != 2) { + throw new IllegalArgumentException("RepeatVector inputs must be rank 2."); + } + Operand output = input; + Operand one = tf.constant(1); + output = tf.expandDims(output, tf.constant(1)); + Operand pattern = tf.stack(Arrays.asList(one, tf.constant(repeatCount), one)); + output = tf.tile(output, pattern); + outputs.add(output); + } + return convertList(outputs, resultType); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Reshape.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Reshape.java new file mode 100644 index 00000000000..0641c501612 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Reshape.java @@ -0,0 +1,110 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that reshapes inputs into the given shape; for example, an input of shape (batchSize, 6) + * with a target shape of (3, 2) is reshaped to (batchSize, 3, 2). + * + * @param the data type for the layer's weights and computation. + */ +public class Reshape extends Layer { + + private final Shape targetShape; + + /** + * Creates a Reshape layer with a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param targetShape the target shape. Does not include the batch size dimension. + * @param type the data type for the layer's weights and computation. + */ + public Reshape(Ops tf, Shape targetShape, Class type) { + this(tf, null, targetShape, type, null); + } + + /** + * Creates a Reshape layer with a unique name generated based on {@link + * Class#getSimpleName()}. + * + * @param tf the TensorFlow Ops + * @param targetShape the target shape. Does not include the batch size dimension. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public Reshape(Ops tf, Shape targetShape, Class type, Options options) { + this(tf, null, targetShape, type, options); + } + + /** + * Creates a Reshape layer. + * + * @param tf the TensorFlow Ops + * @param name the name of this layer + * @param targetShape the target shape. Does not include the batch size dimension. + * @param type the data type for the layer's weights and computation. + */ + public Reshape(Ops tf, String name, Shape targetShape, Class type) { + this(tf, name, targetShape, type, null); + } + + /** + * Creates a Reshape layer. + * + * @param tf the TensorFlow Ops + * @param name the name of this layer + * @param targetShape the target shape. Does not include the batch size dimension. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public Reshape(Ops tf, String name, Shape targetShape, Class type, Options options) { + super(tf, name, true, type, options); + this.targetShape = targetShape; + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + if (inputs.isEmpty()) { + return Collections.emptyList(); + } + Ops tf = getTF(); + Operand input = inputs.get(0); + long batchSize = input.shape().size(0); + Shape newShape = targetShape.prepend(batchSize); + List> result = new ArrayList<>(); + Operand newShapeOp = tf.constant(newShape); + inputs.forEach(inp -> result.add(tf.reshape(cast(tf, inp, getType()), newShapeOp))); + return callPostProcess(convertList(result, resultType), training); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java new file mode 100644 index 00000000000..fe954ae0177 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java @@ -0,0 +1,118 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.op.math.ReduceLogSumExp; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TType; + +import java.util.ArrayList; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Softmax activation function. */ +public class Softmax extends Layer { + + private final int[] axes; + + /** + * Creates a Softmax layer + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param axes axes along which the softmax normalization is applied. + * @param type the data type for the layer's weights and computation. + */ + public Softmax(Ops tf, String name, int[] axes, Class type) { + this(tf, name, axes, type, null); + } + + /** + * Creates a Softmax layer + * + * @param tf the TensorFlow Ops. + * @param name the unique name for this layer. If null, a unique name will be generated based on + * {@link Class#getSimpleName()}. + * @param axes axes along which the softmax normalization is applied. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options. + */ + public Softmax(Ops tf, String name, int[] axes, Class type, Options options) { + super(tf, name, true, type, options); + this.axes = axes; + } + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + Ops tf = getTF(); + // TODO mask + + List> results = new ArrayList<>(); + + for (int i = 0; i < inputs.size(); i++) { + Operand input = cast(tf, inputs.get(i), getType()); + Operand result; + if (masks != null && !masks.isEmpty()) { + // Since the mask is 1.0 for positions we want to attend and 0.0 for + // masked positions, this operation will create a tensor which is 0.0 for + // positions we want to attend and -1e9 for masked positions. + Operand mask = masks.get(i); + Operand one = cast(tf, tf.constant(1), getType()); + + Operand adder = + tf.math.mul(tf.math.sub(one, cast(tf, mask, getType())), largeCompatibleNegative()); + // Since we are adding it to the raw scores before the softmax, this is + // effectively the same as removing these entirely.
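+ // For example, with the default value of -1e9, a masked logit x effectively becomes
+ // (x - 1e9), and exp(x - 1e9) underflows to zero in the softmax normalization.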
+ + input = tf.math.add(input, adder); + } + if (axes.length > 1) { + result = + tf.math.exp( + tf.math.sub(input, ReduceLogSumExp.reduceLogSumExp(tf.scope(), input, axes, true))); + } else { + result = org.tensorflow.framework.op.nn.Softmax.softmax(tf.scope(), input, axes[0]); + } + results.add(result); + } + return callPostProcess(convertTo(results, resultType), training); + } + + /** + * Gets a large negative value that is compatible with the layer's data type + * + * @return a large negative value compatible with the data type + */ + private Operand largeCompatibleNegative() { + Ops tf = getTF(); + if (getType() == TFloat16.class) { + return cast(tf, tf.constant(-0xffdc), getType()); + } else { + return cast(tf, tf.constant(-1e9), getType()); + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java new file mode 100644 index 00000000000..7f14012a465 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java @@ -0,0 +1,104 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.Merge; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Layer that subtracts two inputs. + * + *
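+ * <p>For example (sketch; {@code a} and {@code b} are assumed to be existing operands of the
+ * same shape):
+ *
+ * <pre>{@code
+ * Subtract<TFloat32> diff = new Subtract<>(tf, TFloat32.class);
+ * Operand<TFloat32> result =
+ *     diff.call(Arrays.asList(a, b), TFloat32.class).get(0);  // a - b
+ * }</pre>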

+ * It takes as input a list of exactly two tensors of the same shape and returns a single
+ * tensor, (inputs[0] - inputs[1]), also of the same shape.
+ *
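+ * <p>For example, a minimal sketch ({@code a} and {@code b} are hypothetical operands of the
+ * same shape and type):
+ *
+ * <pre>{@code
+ * Subtract<TFloat64> subtract = new Subtract<>(tf, TFloat64.class);
+ * Operand<TFloat64> diff = subtract.call(Arrays.asList(a, b), TFloat64.class).get(0);
+ * }</pre>
+ *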

+ * @param  the data type for the layer's weights and computation.
+ */
+public class Subtract extends Merge {
+
+  /**
+   * Creates a Subtract layer using {@link Class#getSimpleName()} as the layer name.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights and computation
+   */
+  public Subtract(Ops tf, Class type) {
+    this(tf, null, type, null);
+  }
+
+  /**
+   * Creates a Subtract layer using {@link Class#getSimpleName()} as the layer name.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the weights and computation
+   * @param options the layer's options
+   */
+  public Subtract(Ops tf, Class type, Options options) {
+    this(tf, null, type, options);
+  }
+
+  /**
+   * Creates a Subtract layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+   * @param type the data type for the weights and computation
+   */
+  public Subtract(Ops tf, String name, Class type) {
+    this(tf, name, type, null);
+  }
+
+  /**
+   * Creates a Subtract layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()}
+   * @param type the data type for the weights and computation
+   * @param options the layer's options
+   */
+  public Subtract(Ops tf, String name, Class type, Options options) {
+    super(tf, name, type, options);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  protected void build(List inputShapes) {
+    if (inputShapes.size() != 2) {
+      throw new IllegalArgumentException("A Subtract layer should be called on exactly 2 inputs");
+    }
+    super.build(inputShapes);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  protected Operand mergeFunction(List> inputs) {
+    if (inputs.size() != 2) {
+      throw new IllegalArgumentException("A Subtract layer should be called on exactly 2 inputs");
+    }
+    Ops tf = getTF();
+    Operand output = cast(tf, tf.identity(inputs.get(0)), getType());
+    return tf.math.sub(output, cast(tf, inputs.get(1), getType()));
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java
new file mode 100644
index 00000000000..fd9c458436a
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java
@@ -0,0 +1,116 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TType;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Thresholded Rectified Linear Unit.
+ *

+ * It follows:
+ *

+ *     f(x) = x for x > theta
+ *     f(x) = 0 otherwise
+ * 
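+ *
+ * <p>For example, a minimal sketch (the input operand {@code x} is hypothetical):
+ *
+ * <pre>{@code
+ * ThresholdedReLU<TFloat32> relu = new ThresholdedReLU<>(tf, TFloat32.class);
+ * Operand<TFloat32> y = relu.call(x, TFloat32.class); // y = x where x > theta, else 0
+ * }</pre>
+ *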
+ *
+ * @param  the data type for the layer's weights and computation.
+ */
+public class ThresholdedReLU extends Layer {
+  public static float DEFAULT_THETA = 1.0f;
+
+  private final float theta;
+
+  /**
+   * Creates a ThresholdedReLU Layer with a unique name generated based on {@link
+   * Class#getSimpleName()} and {@link #DEFAULT_THETA} for the theta value.
+   *
+   * @param tf the TensorFlow Ops
+   * @param type the data type for the layer's weights and computation.
+   */
+  public ThresholdedReLU(Ops tf, Class type) {
+    this(tf, null, DEFAULT_THETA, type, null);
+  }
+
+  /**
+   * Creates a ThresholdedReLU Layer with {@link #DEFAULT_THETA} for the theta value.
+   *
+   * @param tf the TensorFlow Ops.
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param type the data type for the layer's weights and computation.
+   */
+  public ThresholdedReLU(Ops tf, String name, Class type) {
+    this(tf, name, DEFAULT_THETA, type, null);
+  }
+
+  /**
+   * Creates a ThresholdedReLU Layer with a unique name generated based on {@link
+   * Class#getSimpleName()}.
+   *
+   * @param tf the TensorFlow Ops
+   * @param theta Threshold location of activation. Must be >= 0.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   */
+  public ThresholdedReLU(Ops tf, float theta, Class type, Options options) {
+    this(tf, null, theta, type, options);
+  }
+
+  /**
+   * Creates a ThresholdedReLU Layer
+   *
+   * @param tf the TensorFlow Ops
+   * @param name the unique name for this layer. If null, a unique name will be generated based on
+   *     {@link Class#getSimpleName()}.
+   * @param theta Threshold location of activation. Must be >= 0.
+   * @param type the data type for the layer's weights and computation.
+   * @param options the layer's options.
+   * @throws IllegalArgumentException if theta is &lt; 0.
+   */
+  public ThresholdedReLU(Ops tf, String name, float theta, Class type, Options options) {
+    super(tf, name, true, type, options);
+    if (theta < 0) {
+      throw new IllegalArgumentException("theta must be >= 0, got " + theta);
+    }
+    this.theta = theta;
+    setSupportsMasking(true);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List> call(
+      List> inputs,
+      List> masks,
+      boolean training,
+      Class resultType) {
+    Ops tf = getTF();
+    Operand tTheta = cast(tf, tf.constant(theta), getType());
+    List> tInputs = convertList(inputs);
+    List> results = new ArrayList<>();
+    tInputs.forEach(
+        input ->
+            results.add(tf.math.mul(input, cast(tf, tf.math.greater(input, tTheta), getType()))));
+    return callPostProcess(convertTo(results, resultType), training);
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java
new file mode 100644
index 00000000000..b7bedc907c6
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java
@@ -0,0 +1,473 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers.impl; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Specifies the rank, data type and shape of every input to a layer. + * + *

These objects enable the layer to run input compatibility checks for input structure, input + * rank, input shape, and input data type. + * + *
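+ * <p>A minimal usage sketch (the operand {@code input} and the layer name are hypothetical):
+ *
+ * <pre>{@code
+ * InputSpec spec = new InputSpec(InputSpec.Options.create().rank(2).axesMap(1, 64L));
+ * spec.assertInputCompatibility(input, "dense"); // throws IllegalArgumentException on mismatch
+ * }</pre>
+ *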

A {@link Shape#UNKNOWN_SIZE} entry in a shape is compatible with any dimension, a {@link + * Shape#unknown()} shape is compatible with any shape. + */ +public class InputSpec { + private Class dataType; + private Shape shape; + private Integer rank; + private Integer maxRank; + private Integer minRank; + private Map axes; + private boolean allowLastAxisSqueeze; + + public InputSpec() {} + + public InputSpec(Options options) { + dataType = options.dataType; + rank = options.rank; + maxRank = options.maxRank; + minRank = options.minRank; + axes = options.axes; + allowLastAxisSqueeze = options.allowLastAxisSqueeze; + + if (options.shape != null && options.shape.numDimensions() != Shape.UNKNOWN_SIZE) { + shape = options.shape; + rank = shape.numDimensions(); + } + + if (axes != null && (rank != null || maxRank != null)) { + maxRank = rank != null ? rank : maxRank; + + Integer maxAxis = axes.keySet().stream().max(Long::compare).get(); + if (maxAxis >= maxRank) { + throw new IllegalArgumentException( + String.format( + "Axis %d is greater than the maximum allowed value: %d, %s", + maxAxis, maxRank, shape)); + } + } + } + + /** + * Returns a Shape object that matches the shape specifications. + * + *
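+ * <p>For example (illustrative): given a rank of 3 and an axes map of {2: 64}, this method
+ * returns the shape (UNKNOWN_SIZE, UNKNOWN_SIZE, 64).
+ *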

+   * If the InputSpec's {@link #shape} or expected {@link #rank} is defined, this method will
+   * return a fully or partially-known shape. Otherwise, the returned Shape is {@link
+   * Shape#unknown()}.
+   *
+   * @return the generated shape
+   */
+  public Shape toShape() {
+    if (rank == null && shape == null) {
+      return Shape.unknown();
+    } else if (shape != null) {
+      return shape;
+    } else {
+      long[] dims = new long[rank];
+      Arrays.fill(dims, Shape.UNKNOWN_SIZE);
+      if (axes != null) {
+        for (Integer key : axes.keySet()) {
+          int dimIdx = Math.floorMod(key, rank);
+          dims[dimIdx] = axes.get(key);
+        }
+      }
+      return Shape.of(dims);
+    }
+  }
+
+  /**
+   * Checks compatibility between the layer and provided inputs.
+   *
+   * @param input the input to check.
+   * @param layerName layer name for error message formatting.
+   * @param  the data type for the input.
+   * @throws IllegalArgumentException if the provided input's shape is not compatible with this
+   *     InputSpec.
+   */
+  public void assertInputCompatibility(Operand input, String layerName) {
+    Shape staticShape = input.shape();
+
+    if (staticShape.numDimensions() != Shape.UNKNOWN_SIZE) {
+      if (rank != null && !isAllowLastAxisSqueeze()) {
+        if (staticShape.numDimensions() != rank) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "Input of %s is incompatible with the layer: expected rank = %d, found rank = %d. Full shape received: %s",
+                  layerName, rank, staticShape.numDimensions(), staticShape));
+        }
+      }
+      if (maxRank != null) {
+        if (staticShape.numDimensions() > maxRank) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "Input of %s is incompatible with the layer: expected max rank = %d, found rank = %d.",
+                  layerName, maxRank, staticShape.numDimensions()));
+        }
+      }
+      if (minRank != null) {
+        if (staticShape.numDimensions() < minRank) {
+          throw new IllegalArgumentException(
+              String.format(
+                  "Input of %s is incompatible with the layer: expected min rank = %d, found rank = %d.",
+                  layerName, minRank, staticShape.numDimensions()));
+        }
+      }
+
+      if (dataType != null && !dataType.equals(input.type())) {
+        throw new IllegalArgumentException(
+            String.format(
+                "Input of %s is incompatible with the layer: expected data type = %s, found data type = %s.",
+                layerName, dataType.getSimpleName(), input.type().getSimpleName()));
+      }
+
+      // Check each expected axis value against the input's static shape.
+      if (axes != null) {
+        axes.forEach(
+            (x, v) -> {
+              if (staticShape.size(x) != Shape.UNKNOWN_SIZE && staticShape.size(x) != v) {
+                throw new IllegalArgumentException(
+                    String.format(
+                        "Input of %s is incompatible with the layer: expected axis = %d of input shape to have value %d, but received input with shape %s",
+                        layerName, x, v, staticShape));
+              }
+            });
+      }
+
+      // Check shape.
+      if (shape != null) {
+        Shape specShape = shape;
+        Shape inputShape = staticShape;
+        if (isAllowLastAxisSqueeze()) {
+          if (inputShape.size(inputShape.numDimensions() - 1) == 1) {
+            inputShape = inputShape.take(inputShape.numDimensions() - 1);
+          }
+          if (specShape.size(specShape.numDimensions() - 1) == 1) {
+            specShape = specShape.take(specShape.numDimensions() - 1);
+          }
+        }
+        for (int i = 0; i < specShape.numDimensions(); i++) {
+          if (specShape.size(i) != Shape.UNKNOWN_SIZE
+              && inputShape.size(i) != Shape.UNKNOWN_SIZE
+              && specShape.size(i) != inputShape.size(i)) {
+            throw new IllegalArgumentException(
+                String.format(
+                    "Input of %s is incompatible with the layer: expected shape = %s, found shape = %s",
+                    layerName, shape, staticShape));
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Gets the expected Data Type of the input.
+   *
+   * @return the expected Data Type of the input.
+   */
+  public Class getDataType() {
+    return dataType;
+  }
+
+  /**
+   * Sets the expected Data Type of the input.
+   *
+   * @param dataType the expected Data Type of the input.
+   */
+  public void setDataType(Class dataType) {
+    this.dataType = dataType;
+  }
+
+  /**
+   * Gets the Dictionary mapping integer axes to a specific dimension value.
+   *
+   * @return the Dictionary mapping integer axes to a specific dimension value.
+   */
+  public Map getAxesMap() {
+    return axes;
+  }
+
+  /**
+   * Sets the Dictionary mapping integer axes to a specific dimension value.
+   *
+   * @param axes the Dictionary mapping integer axes to a specific dimension value.
+   */
+  public void setAxesMap(Map axes) {
+    this.axes = axes;
+  }
+
+  /**
+   * Gets the expected shape of the input (may include {@link Shape#UNKNOWN_SIZE} for unchecked
+   * axes). Includes the batch size.
+   *
+   * @return the expected shape of the input including batch size.
+   */
+  public Shape getShape() {
+    return shape;
+  }
+
+  /**
+   * Sets the expected shape of the input (may include {@link Shape#UNKNOWN_SIZE} for unchecked
+   * axes). Includes the batch size.
+   *
+   * @param shape the expected shape of the input including batch size.
+   */
+  public void setShape(Shape shape) {
+    this.shape = shape;
+  }
+
+  /**
+   * Gets the expected rank of the input
+   *
+   * @return the expected rank of the input
+   */
+  public Integer getRank() {
+    return rank;
+  }
+
+  /**
+   * Sets the expected rank of the input
+   *
+   * @param rank the expected rank of the input
+   */
+  public void setRank(Integer rank) {
+    this.rank = rank;
+  }
+
+  /**
+   * Gets the maximum rank of the input.
+   *
+   * @return the maximum rank of the input.
+   */
+  public Integer getMaxRank() {
+    return maxRank;
+  }
+
+  /**
+   * Sets the maximum rank of the input.
+   *
+   * @param maxRank the maximum rank of the input.
+   */
+  public void setMaxRank(Integer maxRank) {
+    this.maxRank = maxRank;
+  }
+
+  /**
+   * Gets the minimum rank of the input.
+   *
+   * @return the minimum rank of the input.
+   */
+  public Integer getMinRank() {
+    return minRank;
+  }
+
+  /**
+   * Sets the minimum rank of the input.
+   *
+   * @param minRank the minimum rank of the input.
+   */
+  public void setMinRank(Integer minRank) {
+    this.minRank = minRank;
+  }
+
+  /**
+   * Gets the allow last axis squeeze indicator for the input. If true, then allow inputs of rank
+   * N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the
+   * last axis of the spec is 1.
+   *
+   * @return the allow last axis squeeze indicator
+   */
+  public boolean isAllowLastAxisSqueeze() {
+    return allowLastAxisSqueeze;
+  }
+
+  /**
+   * Sets the allow last axis squeeze indicator for the input. If true, then allow inputs of rank
+   * N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the
+   * last axis of the spec is 1.
+   *
+   * @param allowLastAxisSqueeze the allow last axis squeeze indicator
+   */
+  public void setAllowLastAxisSqueeze(boolean allowLastAxisSqueeze) {
+    this.allowLastAxisSqueeze = allowLastAxisSqueeze;
+  }
+
+  /** Optional attributes for {@link InputSpec} */
+  public static class Options {
+
+    private Class dataType;
+    private Shape shape;
+    private Integer rank;
+    private Integer maxRank;
+    private Integer minRank;
+    private Map axes;
+    private boolean allowLastAxisSqueeze;
+
+    public static Options create() {
+      return new Options();
+    }
+
+    /**
+     * Sets the expected Data Type of the input.
+     *
+     * @return this Options instance.
+ */ + public Options dataType(Class dataType) { + this.dataType = dataType; + return this; + } + + /** + * Sets the expected shape of the input (may include {@link Shape#UNKNOWN_SIZE} for unchecked + * axes). Includes the batch size. + * + * @return this Options instance. + */ + public Options shape(Shape shape) { + this.shape = shape; + return this; + } + + /** + * Sets the expected rank of the input + * + * @return this Options instance. + */ + public Options rank(Integer rank) { + this.rank = rank; + return this; + } + + /** + * Sets the maximum rank of the input. + * + * @return this Options instance. + */ + public Options maxRank(Integer maxRank) { + this.maxRank = maxRank; + return this; + } + + /** + * Sets the minimum rank of the input. + * + * @return this Options instance. + */ + public Options minRank(Integer minRank) { + this.minRank = minRank; + return this; + } + /** + * Sets the Dictionary mapping integer axes to a specific dimension value. + * + * @return this Options instance. + */ + public Options axesMap(Map axes) { + this.axes = axes; + return this; + } + /** + * Sets the Dictionary mapping integer axes to a specific dimension value. + * + * @return this Options instance. + */ + public Options axesMap(Integer key, Long dim) { + if (this.axes == null) { + this.axes = new HashMap<>(); + } + this.axes.put(key, dim); + return this; + } + + /** + * Sets the allow last axis squeeze indicator for the input. If true, then allow inputs of rank + * N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the + * last axis of the spec is 1. + * + * @return this Options instance. + */ + public Options allowLastAxisSqueeze(boolean allowLastAxisSqueeze) { + this.allowLastAxisSqueeze = allowLastAxisSqueeze; + return this; + } + /** + * Gets the expected Data Type of the input. + * + * @return the expected Data Type of the input. + */ + public Class getDataType() { + return dataType; + } + /** + * Gets the expected shape of the input (may include {@link Shape#UNKNOWN_SIZE} for unchecked + * axes). Includes the batch size. + * + * @return the expected shape of the input including batch size. + */ + public Shape getShape() { + return shape; + } + /** + * Gets the expected rank of the input + * + * @return the expected rank of the input + */ + public Integer getRank() { + return rank; + } + /** + * Gets the maximum rank of the input. + * + * @return the maximum rank of the input. + */ + public Integer getMaxRank() { + return maxRank; + } + /** + * Gets the minimum rank of the input. + * + * @return the minimum rank of the input. + */ + public Integer getMinRank() { + return minRank; + } + + /** + * Gets the Dictionary mapping integer axes to a specific dimension value. + * + * @return the Dictionary mapping integer axes to a specific dimension value. + */ + public Map getAxesMap() { + return axes; + } + /** + * Gets the allow last axis squeeze indicator for the input. If true, then allow inputs of rank + * N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the + * last axis of the spec is 1. 
+     *
+     * @return the allow last axis squeeze indicator
+     */
+    public boolean isAllowLastAxisSqueeze() {
+      return allowLastAxisSqueeze;
+    }
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java
new file mode 100644
index 00000000000..c78b3ccebdb
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java
@@ -0,0 +1,382 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers.impl;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.layers.Layer;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Generic abstract merge layer for element-wise merge functions.
+ *
+ * @param  the data type for the layer's weights and computation.
+ */
+public abstract class Merge extends Layer {
+
+  private boolean reshapeRequired;
+
+  /**
+   * Creates a Merge base class using {@link Class#getSimpleName()} for the name.
+   *
+   * @param tf the TensorFlow Ops, may not be null before the first call to the {@link #call}
+   *     method.
+   * @param type the data type for the weights and computation
+   */
+  protected Merge(Ops tf, Class type) {
+    this(tf, null, true, type, null);
+  }
+
+  /**
+   * Creates a Merge base class using {@link Class#getSimpleName()} for the name.
+   *
+   * @param tf the TensorFlow Ops, may not be null before the first call to the {@link #call}
+   *     method.
+   * @param type the data type for the weights and computation
+   * @param options the layer's options
+   */
+  protected Merge(Ops tf, Class type, Options options) {
+    this(tf, null, true, type, options);
+  }
+
+  /**
+   * Creates a Merge base class.
+   *
+   * @param tf the TensorFlow Ops, may not be null before the first call to the {@link #call}
+   *     method.
+   * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for
+   *     the name.
+   * @param type the data type for the weights and computation
+   */
+  protected Merge(Ops tf, String name, Class type) {
+    this(tf, name, true, type, null);
+  }
+
+  /**
+   * Creates a Merge base class.
+   *
+   * @param tf the TensorFlow Ops, may not be null before the first call to the {@link #call}
+   *     method.
+ * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for + * the name. + * @param type the data type for the weights and computation + * @param options the layer's options + */ + protected Merge(Ops tf, String name, Class type, Options options) { + + this(tf, name, true, type, options); + } + + /** + * Creates the base Layer class + * + * @param tf the TensorFlow Ops, may not be null. + * @param name the unique name for this layer, if null will use {@link Class#getSimpleName()} for + * the name. + * @param trainable whether the layer's variables should be trainable or not. + * @param type the data type for the layer's weights and computation. + * @param options the layer's options + */ + protected Merge(Ops tf, String name, boolean trainable, Class type, Options options) { + super(tf, name, trainable, type, options); + this.setSupportsMasking(true); + } + + /** {@inheritDoc} */ + @Override + public List> computeMask( + List> inputs, List> masks) { + if (masks == null || masks.isEmpty()) { + return null; + } + if (inputs.size() != masks.size()) { + throw new IllegalArgumentException("The lists inputs and masks should have the same length."); + } + + boolean allNull = true; + for (Operand m : masks) { + if (m != null) { + allNull = false; + break; + } + } + if (allNull) { + return null; + } + + final Ops tf = getTF(); + List> rMasks = + masks.stream() + .map(m -> cast(getTF(), m, TBool.class)) + .map(m -> tf.expandDims(m, tf.constant(0))) + .collect(Collectors.toList()); + + Operand concat = tf.concat(rMasks, tf.constant(0)); + Operand bool = cast(tf, concat, TBool.class); + return Collections.singletonList(tf.reduceAll(bool, tf.constant(0))); + } + + /** + * Computes the merged result + * + * @param inputs the inputs + * @return the merged result + */ + protected abstract Operand mergeFunction( + List> inputs); + + /** {@inheritDoc} */ + @Override + public List> call( + List> inputs, + List> masks, + boolean training, + Class resultType) { + Ops tf = getTF(); + + if (reshapeRequired) { + List> reshapedInputs = new ArrayList<>(); + List inputDimensions = new ArrayList<>(); + inputs.forEach(s -> inputDimensions.add(s.shape().numDimensions())); + if (!inputDimensions.contains((int) Shape.UNKNOWN_SIZE)) { + // If ranks of all inputs are available, + // we simply expand each of them at axis=1 + // until all of them have the same rank. + int maxDimension = Collections.max(inputDimensions); + for (Operand input : inputs) { + int numDims = input.shape().numDimensions(); + for (int i = numDims; i < maxDimension; i++) { + input = tf.expandDims(input, tf.constant(1)); + } + Operand tInput = cast(getTF(), input, getType()); + reshapedInputs.add(tInput); + } + Operand result = cast(tf, mergeFunction(reshapedInputs), resultType); + return Collections.singletonList(result); + + } else { + // Transpose all inputs so that batch size is the last dimension. + // (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... 
, batch_size) + boolean transposed = false; + for (Operand input : inputs) { + Operand tInput = cast(getTF(), input, getType()); + int nDims = tInput.shape().numDimensions(); + if (nDims == Shape.UNKNOWN_SIZE) { + org.tensorflow.op.core.Shape tShape = tf.shape(tInput); + Operand batchSize = tf.shape.size(tShape, tf.constant(0)); + Operand remainderShape = + tf.shape.takeLast(tShape, tf.math.sub(tf.rank(tInput), tf.constant(1))); + Operand newShape = + tf.shape.append(remainderShape, tf.expandDims(batchSize, tf.constant(-1))); + + Operand transposedInput = + tf.reshape( + tInput, + tf.shape.append(batchSize, tf.reduceProd(remainderShape, tf.constant(0)))); + + transposedInput = tf.linalg.transpose(transposedInput, tf.constant(new int[] {1, 0})); + transposedInput = tf.reshape(transposedInput, newShape); + reshapedInputs.add(transposedInput); + transposed = true; + + } else if (nDims > 1) { + int[] perms = new int[nDims]; + for (int i = 1; i < nDims - 1; i++) { + perms[i - 1] = i; + } + perms[nDims - 1] = 0; + reshapedInputs.add(tf.linalg.transpose(tInput, tf.constant(perms))); + } else { + reshapedInputs.add(tInput); + } + } + Operand result = cast(tf, mergeFunction(reshapedInputs), resultType); + + if (transposed) { + int nDim = result.shape().numDimensions(); + if (nDim == Shape.UNKNOWN_SIZE) { + org.tensorflow.op.core.Shape rShape = tf.shape(result); + Operand batchSize = tf.shape.takeLast(rShape, tf.constant(1)); + Operand baseShape = + tf.shape.take(rShape, tf.math.sub(tf.rank(result), tf.constant(1))); + Operand newShape = tf.shape.append(batchSize, baseShape); + result = + tf.reshape( + result, + tf.concat( + Arrays.asList(tf.constant(new int[] {-1}), batchSize), tf.constant(0))); + result = tf.linalg.transpose(result, tf.constant(new int[] {1, 0})); + result = tf.reshape(result, newShape); + } else if (nDim > 1) { + int[] perms = new int[nDim]; + perms[0] = nDim - 1; + for (int i = 0; i < nDim - 1; i++) { + perms[i + 1] = i; + } + result = tf.linalg.transpose(result, tf.constant(perms)); + } + } + return callPostProcess(Collections.singletonList(result), training); + } + } else { + List> tInputs = new ArrayList<>(); + inputs.forEach(i -> tInputs.add(cast(getTF(), i, getType()))); + Operand merged = cast(tf, mergeFunction(tInputs), resultType); + + return callPostProcess(Collections.singletonList(merged), training); + } + } + + /** {@inheritDoc} */ + @Override + protected void build(List inputShapes) { + if (inputShapes == null || inputShapes.size() <= 1) { + throw new IllegalArgumentException( + String.format( + "A merge layer should be called on a list of at least 2 inputs. Got %d inputs", + inputShapes == null ? 0 : inputShapes.size())); + } + Set batchSizes = new HashSet<>(); + inputShapes.forEach(s -> batchSizes.add(s.size(0))); + if (batchSizes.size() > 1) { + throw new IllegalArgumentException( + String.format( + "Can not merge tensors with different batch sizes. 
Got tensors with shapes: %s",
+                Arrays.toString(inputShapes.toArray())));
+    }
+
+    Shape inputShape = inputShapes.get(0);
+    Shape outputShape = inputShape.takeLast(inputShape.numDimensions() - 1);
+    Shape shape;
+    for (int i = 1; i < inputShapes.size(); i++) {
+      shape = inputShapes.get(i);
+      outputShape = computeElementWiseOpOutputShape(outputShape, shape);
+    }
+
+    Set ranks = new HashSet<>();
+    inputShapes.forEach(s -> ranks.add(s.numDimensions()));
+    boolean hasUnknown = false;
+    for (Shape s : inputShapes) {
+      if (s.isUnknown()) {
+        hasUnknown = true;
+        break;
+      }
+    }
+    reshapeRequired = hasUnknown || ranks.size() > 1;
+    super.build(inputShapes);
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public List computeOutputShape(List inputShapes) {
+    Shape outputShape;
+    if (inputShapes.isEmpty() || inputShapes.get(0) == null) {
+      outputShape = Shape.of();
+    } else {
+      Shape shape1 = inputShapes.get(0);
+      if (shape1.numDimensions() > 0) {
+        outputShape = shape1.takeLast(shape1.numDimensions() - 1);
+      } else {
+        outputShape = Shape.of();
+      }
+    }
+    Shape shape;
+    for (int i = 1; i < inputShapes.size(); i++) {
+      Shape shapei = inputShapes.get(i);
+      if (shapei == null) {
+        shape = Shape.of();
+      } else {
+        if (shapei.numDimensions() > 0) {
+          shape = shapei.takeLast(shapei.numDimensions() - 1);
+        } else {
+          shape = Shape.of();
+        }
+      }
+      outputShape = computeElementWiseOpOutputShape(outputShape, shape);
+    }
+
+    Set batchSizes = new HashSet<>();
+    for (Shape s : inputShapes) {
+      if (s != null) {
+        batchSizes.add(s.size(0));
+      }
+    }
+    if (batchSizes.size() == 1) {
+      outputShape = outputShape.prepend(batchSizes.toArray(new Long[1])[0]);
+    } else {
+      outputShape = outputShape.prepend(Shape.UNKNOWN_SIZE);
+    }
+
+    return Collections.singletonList(outputShape);
+  }
+
+  /**
+   * Computes the shape of the resultant of an element-wise operation.
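+   *
+   * <p>For example (illustrative): shapes (4, 5) and (1, 5) broadcast to (4, 5); an
+   * unknown dimension in either operand yields an unknown dimension at that position.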
+   *
+   * @param shape1 Shape of the first tensor
+   * @param shape2 Shape of the second tensor
+   * @return expected output shape when an element-wise operation is carried out on 2 tensors with
+   *     shapes shape1 and shape2
+   */
+  protected Shape computeElementWiseOpOutputShape(Shape shape1, Shape shape2) {
+    if (shape2 == null) {
+      return shape1;
+    }
+    if (shape1.isUnknown() || shape2.isUnknown()) {
+      return Shape.unknown();
+    }
+    if (shape1.numDimensions() < shape2.numDimensions()) {
+      return computeElementWiseOpOutputShape(shape2, shape1);
+    }
+    Shape outputShape = shape1.take(shape1.numDimensions() - shape2.numDimensions());
+
+    for (int i = shape1.numDimensions() - shape2.numDimensions(), j = 0;
+        j < shape2.numDimensions();
+        j++, i++) {
+      if (shape1.size(i) == Shape.UNKNOWN_SIZE || shape2.size(j) == Shape.UNKNOWN_SIZE) {
+        outputShape = outputShape.append(Shape.UNKNOWN_SIZE);
+      } else if (shape1.size(i) == 1) {
+        outputShape = outputShape.append(shape2.size(j));
+      } else if (shape2.size(j) == 1) {
+        outputShape = outputShape.append(shape1.size(i));
+      } else if (shape1.size(i) != shape2.size(j)) {
+        throw new IllegalArgumentException(
+            String.format(
+                "Operands could not be broadcast together with shapes %s %s", shape1, shape2));
+      } else {
+        // Shape is immutable; append returns a new Shape, so the result must be reassigned.
+        outputShape = outputShape.append(shape1.size(i));
+      }
+    }
+    return outputShape;
+  }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/TensorFormat.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/TensorFormat.java
new file mode 100644
index 00000000000..e3bd6a3ea49
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/TensorFormat.java
@@ -0,0 +1,23 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.layers.impl;
+
+/* TODO remove after this enum is added to the API.
+ * PR: Created TensorFormat enum #191
+ */
+public enum TensorFormat {
+  NCHW,
+  NHWC
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java
new file mode 100644
index 00000000000..ef002bfb49f
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java
@@ -0,0 +1,213 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.layers.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.constraints.Constraint; +import org.tensorflow.framework.initializers.Glorot; +import org.tensorflow.framework.initializers.Initializer; +import org.tensorflow.framework.initializers.VarianceScaling; +import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; + +public class VariableDef { + private final Ops tf; + private final String name; + private final Shape shape; + private final Initializer initializer; + private final Constraint constraint; + private final boolean trainable; + private final Variable variable; + private final Operand initOperand; + private final Class type; + + public VariableDef( + Ops tf, + String name, + Shape shape, + Initializer initializer, + Constraint constraint, + boolean trainable, + long seed, + Class type) { + this.tf = tf.withName(name); + this.type = type; + this.name = name; + this.constraint = constraint; + this.trainable = trainable; + + this.shape = shape == null ? Shape.scalar() : shape; + this.initializer = initializer == null ? getDefaultInitializer(seed) : initializer; + initOperand = this.initializer.call(tf.constant(this.shape), type); + variable = tf.variable(initOperand); + } + + public VariableDef( + Ops tf, + String name, + Variable variable, + Initializer initializer, + Constraint constraint, + boolean trainable, + long seed) { + this.tf = tf.withName(name); + this.name = name == null ? variable.toString() : name; + this.constraint = constraint; + this.trainable = trainable; + this.variable = variable; + shape = variable.shape(); + type = variable.type(); + this.initializer = initializer == null ? getDefaultInitializer(seed) : initializer; + initOperand = this.initializer.call(tf.constant(this.shape), type); + } + + /** + * Initializes the variable + * + * @return the operand that initializes this variable + */ + public Operand init() { + return assign(initOperand); + } + + /** + * Assigns a value to the variable + * + * @param value the value to assign + * @return the operand that assigns the value to this variable + */ + public Operand assign(Operand value) { + // apply constraint if it exists + Operand tValue = constraint != null ? constraint.call(value) : value; + return tf.assign(variable, tValue); + } + + /** + * Adds a value to the variable + * + * @param value the value to add + * @return the operand that adds the value to this variable + */ + public Operand assignAdd(Operand value) { + Operand add = tf.assignAdd(variable, value); + // apply constraint if it exists + return constraint != null ? tf.assign(variable, constraint.call(add)) : add; + } + + /** + * Subtracts a value from the variable + * + * @param value the value to subtract + * @return the operand that subtracts the value from this variable + */ + public Operand assignSub(Operand value) { + Operand sub = tf.assignSub(variable, value); + // apply constraint if it exists + return constraint != null ? tf.assign(variable, constraint.call(sub)) : sub; + } + + /** + * Gets the default initializer based on type + * + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and type. 
+   * @return the default initializer
+   */
+  @SuppressWarnings("unchecked")
+  private Initializer getDefaultInitializer(long seed) {
+    Initializer initializer;
+    if (TFloating.class.isAssignableFrom(type)) {
+      // This cast produces a "Casting 'new Glorot<>(...)' to 'Initializer' is redundant"
+      // warning. It is ignored here: Glorot requires a TFloating type, which is guaranteed
+      // by the isAssignableFrom check above, and removing the cast causes a compile error.
+      initializer = (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed);
+    } else {
+      initializer = new Zeros<>(tf);
+    }
+    return initializer;
+  }
+
+  /**
+   * Gets the variable name
+   *
+   * @return the variable name
+   */
+  public String getName() {
+    return name;
+  }
+
+  /**
+   * Gets the variable shape
+   *
+   * @return the variable shape
+   */
+  public Shape getShape() {
+    return shape;
+  }
+
+  /**
+   * Gets the variable initializer
+   *
+   * @return the variable initializer
+   */
+  public Initializer getInitializer() {
+    return initializer;
+  }
+
+  /**
+   * Gets the variable constraint
+   *
+   * @return the variable constraint
+   */
+  public Constraint getConstraint() {
+    return constraint;
+  }
+
+  /**
+   * Gets the variable trainable indicator
+   *
+   * @return the variable trainable indicator
+   */
+  public boolean isTrainable() {
+    return trainable;
+  }
+
+  /**
+   * Gets the variable
+   *
+   * @return the variable
+   */
+  public Variable getVariable() {
+    return variable;
+  }
+
+  /**
+   * Gets the variable initialization operand.
+   *
+   * @return the variable initialization operand.
+   */
+  public Operand getInitOperand() {
+    return initOperand;
+  }
+
+  /**
+   * Gets the variable data type
+   *
+   * @return the variable data type
+   */
+  public Class getType() {
+    return type;
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ActivationTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ActivationTest.java
new file mode 100644
index 00000000000..0f779567ddc
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ActivationTest.java
@@ -0,0 +1,48 @@
+package org.tensorflow.framework.layers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.activations.ReLU;
+import org.tensorflow.framework.activations.Tanh;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+class ActivationTest {
+  private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
+
+  @Test
+  public void testReLU() {
+    float[][] input = {{1, -2}, {3, -4}, {-1, 2}, {-3, 4}};
+    float[][] expected = {{1, 0}, {3, 0}, {0, 2}, {0, 4}};
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      ReLU relu = new ReLU<>(tf);
+
+      Activation instance = new Activation<>(tf, relu, TFloat32.class);
+      Operand result = instance.call(tf.constant(input), true, TFloat32.class);
+
+      session.evaluate(tf.constant(expected), result);
+    }
+  }
+
+  /** Test of Tanh call method.
*/ + @Test + public void testCallDouble() { + double[] input = {1, -2, 3, -4, -5, 6, -7, 8}; + double[] expected = { + 0.76159416, -0.96402758, + 0.99505475, -0.9993293, + -0.9999092, 0.99998771, + -0.99999834, 0.99999977 + }; + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Tanh tanh = new Tanh<>(tf); + Activation instance = new Activation<>(tf, tanh, TFloat64.class); + Operand result = instance.call(tf.constant(input), false, TFloat64.class); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java new file mode 100644 index 00000000000..65d31e97945 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java @@ -0,0 +1,236 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +class AddTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + }; + double[][][] x3 = { + { + {0.90545522, 0.55172128, 0.87254455, 0.1396359, 0.1538656}, + {0.04276304, 0.9315817, 0.91360492, 0.00604873, 0.04174153}, + {0.60856471, 0.37386072, 0.68937889, 0.21272655, 0.65082257}, + {0.44925012, 0.29825938, 0.20043074, 0.84906101, 0.78397795} + }, + { + {0.70855776, 0.17650269, 0.02422264, 0.84612297, 0.72450389}, + {0.05133022, 0.61175015, 0.56296539, 0.66780478, 0.63326012}, + {0.11212696, 0.50675282, 0.58170013, 0.21101392, 0.83090424}, + {0.91830915, 0.42113009, 0.49795942, 0.2814478, 0.11920788} + } + }; + double[][][] xsum = { + { + {1.86943758, 1.39738503, 1.65498872, 1.65746556, 1.09127339}, + {0.9617033, 1.49612674, 2.22012576, 1.06387551, 1.48973533}, + {1.7272264, 1.39105585, 2.08711617, 1.33690942, 1.85540083}, + {2.00503786, 0.59602925, 1.00291149, 1.85067532, 1.56199948} + }, + { + {1.40478652, 0.67829538, 1.75703778, 2.28429285, 1.54299154}, + {1.1543616, 1.40308904, 
1.59185462, 1.28059728, 2.30713144}, + {1.40050067, 1.59518573, 1.89868217, 1.43297007, 2.22691399}, + {1.98149274, 1.7430799, 1.72313981, 1.40494213, 0.87391746} + } + }; + + @Test + public void testAdd() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i3 = + new Input<>( + tf, + "l3", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Add instance = new Add<>(tf, TFloat64.class); + List> resultList = + + instance.call( + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i3.getOutput(TFloat64.class)), + TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + Operand x3Op = tf.constant(x3); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0); + TFloat64 x3Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x3Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + feedMap.put(i3.getOutput(TFloat64.class), x3Tensor); + session.evaluate(tf.constant(xsum), result, feedMap); + } + } + } + + @Test + public void testMask() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i3 = + new Input<>( + tf, + "l3", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Add instance = new Add<>(tf, TFloat64.class); + List> inputs = + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i3.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null, null); + + List> result = instance.computeMask(inputs, mask); + assertNull(result); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + Operand x3Op = tf.constant(x3); + mask = Arrays.asList(x1Op, x2Op, x3Op); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0); + TFloat64 x3Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x3Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + feedMap.put(i3.getOutput(TFloat64.class), x3Tensor); + result = instance.computeMask(inputs, mask); + Boolean[] expected = new Boolean[(int) result.get(0).size()]; + Arrays.fill(expected, true); + session.evaluate(expected, result.get(0), feedMap); + } + } + } + + @Test + public void testMaskInvalidLengths() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession 
session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i3 = + new Input<>( + tf, + "l3", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Add instance = new Add<>(tf, TFloat64.class); + List> inputs = + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i3.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null); + instance.computeMask(inputs, mask); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AlphaDropoutTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AlphaDropoutTest.java new file mode 100644 index 00000000000..cc7972e97cd --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AlphaDropoutTest.java @@ -0,0 +1,53 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class AlphaDropoutTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testAlphaDropout() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + TestSession.setEpsilon(1e-4F); + Shape expectedShape = Shape.of(3, 2, 3); + float[][][] x = + new float[][][] { + {{0.14517927f, 0.2574964f, 0.2291325f}, {0.9145494f, 0.9378068f, 0.6827883f}}, + {{0.27121753f, 0.08317473f, 0.3770739f}, {0.25451255f, 0.18511271f, 0.5620538f}}, + {{0.40101776f, 0.25205433f, 0.05103926f}, {0.08764106f, 0.00593294f, 0.37244815f}} + }; + AlphaDropout instance = + new AlphaDropout<>( + tf, 0.2f, seed, TFloat32.class, Layer.Options.create().inputShape(Shape.of(3, 2, 3))); + Operand input = tf.constant(x); + + Operand result = instance.call(input, false, TFloat32.class); + assertEquals(expectedShape, result.shape()); + session.evaluate(tf.identity(input), result); + + float[][][] exp = { + {{-1.236160f, 0.535354f, 0.510425f}, {1.112841f, 1.133282f, 0.909145f}}, + {{0.547414f, 0.382143f, -1.236160f}, {-1.236160f, 0.471736f, -1.236160f}}, + {{0.661496f, 0.530571f, -1.236160f}, {0.386068f, 0.314254f, 0.636386f}} + }; + + Operand expected = tf.constant(exp); + result = instance.call(input, true, TFloat32.class); + + assertEquals(expectedShape, result.shape()); + + // NOTE: result can only be evaluated once, otherwise new random numbers + // will be generated and won't match the expected + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java new file mode 100644 index 00000000000..66e55a8deda --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java @@ -0,0 +1,107 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import 
org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +class AverageTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + }; + + double[][][] xavg = { + { + {0.48199118, 0.42283187, 0.39122208, 0.75891483, 0.46870389}, + {0.45947013, 0.28227252, 0.65326042, 0.52891339, 0.7239969}, + {0.55933084, 0.50859756, 0.69886864, 0.56209143, 0.60228914}, + {0.77789387, 0.14888493, 0.40124038, 0.50080716, 0.38901076} + }, + { + {0.34811438, 0.25089635, 0.86640757, 0.71908493, 0.40924383}, + {0.55151569, 0.39566945, 0.51444461, 0.30639626, 0.83693566}, + {0.64418686, 0.54421645, 0.65849102, 0.61097808, 0.69800487}, + {0.53159179, 0.66097491, 0.6125902, 0.56174716, 0.37735479} + } + }; + + @Test + public void testAverage() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Average instance = new Average<>(tf, TFloat64.class); + List> resultList = + instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xavg), result, feedMap); + } + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java new file mode 100644 index 00000000000..fd976b74f51 --- /dev/null +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java @@ -0,0 +1,209 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +class ConcatenateTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {5.67710153e-02, 5.68608495e-01, 6.94753423e-01, 7.06106392e-01, 9.55901476e-01}, + {1.16221311e-01, 2.77955841e-01, 8.48163908e-01, 6.65887805e-01, 8.48399407e-01}, + {1.32232733e-01, 6.07996978e-01, 5.04046847e-01, 9.79583238e-02, 6.71959629e-01}, + {9.69122927e-01, 2.65313461e-01, 7.25259997e-01, 2.95230608e-02, 2.68600949e-01} + }, + { + {9.26675552e-02, 9.11034266e-01, 9.42616405e-01, 1.76616001e-01, 4.35131783e-01}, + {3.42867908e-01, 4.42621793e-02, 1.86904412e-01, 2.30573118e-05, 1.40271865e-01}, + {9.92634263e-01, 3.50624173e-01, 9.53986246e-01, 6.98818650e-01, 9.82469750e-01}, + {7.84919140e-01, 5.03811516e-01, 2.99471974e-01, 4.13124006e-01, 1.67204622e-01} + } + }; + double[][][] x2 = { + { + {0.28151136, 0.99996448, 0.94123237, 0.92673981, 0.58165141}, + {0.41634875, 0.87652871, 0.52327084, 0.60899574, 0.97460049}, + {0.77076745, 0.46439171, 0.25499671, 0.18764164, 0.13748069}, + {0.19368776, 0.11778548, 0.55451791, 0.06335824, 0.63534461} + }, + { + {0.52078045, 0.85837043, 0.44845609, 0.69742864, 0.99834278}, + {0.23162816, 0.63328557, 0.24782906, 0.37476312, 0.16915018}, + {0.96264864, 0.97704619, 0.58534633, 0.87405632, 0.4750216}, + {0.73685149, 0.13915827, 0.23992944, 0.06455061, 0.30500096} + } + }; + + double[][][] x = { + { + {5.67710153e-02, 5.68608495e-01, 6.94753423e-01, 7.06106392e-01, 9.55901476e-01}, + {1.16221311e-01, 2.77955841e-01, 8.48163908e-01, 6.65887805e-01, 8.48399407e-01}, + {1.32232733e-01, 6.07996978e-01, 5.04046847e-01, 9.79583238e-02, 6.71959629e-01}, + {9.69122927e-01, 2.65313461e-01, 7.25259997e-01, 2.95230608e-02, 2.68600949e-01}, + {2.81511360e-01, 9.99964484e-01, 9.41232373e-01, 9.26739808e-01, 5.81651412e-01}, + {4.16348754e-01, 8.76528710e-01, 5.23270835e-01, 6.08995742e-01, 9.74600488e-01}, + {7.70767447e-01, 4.64391706e-01, 2.54996707e-01, 1.87641636e-01, 1.37480691e-01}, + {1.93687759e-01, 1.17785480e-01, 5.54517906e-01, 6.33582392e-02, 6.35344611e-01} + }, + { + {9.26675552e-02, 9.11034266e-01, 9.42616405e-01, 1.76616001e-01, 4.35131783e-01}, + {3.42867908e-01, 4.42621793e-02, 1.86904412e-01, 2.30573118e-05, 1.40271865e-01}, + {9.92634263e-01, 3.50624173e-01, 9.53986246e-01, 6.98818650e-01, 9.82469750e-01}, + {7.84919140e-01, 5.03811516e-01, 2.99471974e-01, 4.13124006e-01, 1.67204622e-01}, + {5.20780455e-01, 8.58370427e-01, 4.48456095e-01, 6.97428643e-01, 9.98342781e-01}, + {2.31628161e-01, 6.33285571e-01, 2.47829057e-01, 3.74763124e-01, 1.69150184e-01}, + {9.62648639e-01, 9.77046190e-01, 5.85346335e-01, 8.74056318e-01, 4.75021602e-01}, + {7.36851488e-01, 1.39158268e-01, 2.39929436e-01, 6.45506139e-02, 3.05000963e-01} + } + }; + + @Test + public void testConcatenate() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new 
Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Concatenate instance = new Concatenate<>(tf, 1, TFloat64.class); + List> resultList = + instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 8, 5}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(x), result, feedMap); + } + } + } + + @Test + public void testMask() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Concatenate instance = new Concatenate<>(tf, TFloat64.class); + List> inputs = + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null); + + List> result = instance.computeMask(inputs, mask); + assertNull(result); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + mask = Arrays.asList(x1Op, x2Op); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + result = instance.computeMask(inputs, mask); + Boolean[] expected = new Boolean[(int) result.get(0).size()]; + Arrays.fill(expected, true); + session.evaluate(expected, result.get(0), feedMap); + } + } + } + + @Test + public void testMaskInvalidMaskSize() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Concatenate instance = new Concatenate<>(tf, TFloat64.class); + List> inputs = + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null, null); + + List> result = instance.computeMask(inputs, mask); + assertNull(result); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + mask = Arrays.asList(x1Op, x2Op); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), 
x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + result = instance.computeMask(inputs, mask); + Boolean[] expected = new Boolean[(int) result.get(0).size()]; + Arrays.fill(expected, true); + session.evaluate(expected, result.get(0), feedMap); + } + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java new file mode 100644 index 00000000000..555afc85676 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java @@ -0,0 +1,625 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.constraints.MinMaxNorm; +import org.tensorflow.framework.constraints.NonNeg; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; + +import java.util.Collections; +import java.util.List; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class DenseTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + private final Random random = new Random(1001L); + + @Test + public void testShape3_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(3, 2); + int units = 3; + + Dense instance = + new Dense<>(tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + + float[][] data = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f} + }; + float[][] expected = { + {5.866010f, 8.906790f, 7.214075f}, + {5.409174f, 8.020676f, 7.272593f}, + {5.140971f, 8.056995f, 5.513147f} + }; + + List> weights = instance.getWeights(); + instance.setWeights(weights); + Operand input = tf.constant(data); + + Operand y = instance.call(input, TFloat32.class); + session.run(tf.init()); + + List computed = instance.computeOutputShape(Collections.singletonList(inputShape)); + assertEquals(1, computed.size()); + assertEquals(computed.get(0), y.shape()); + Shape expectedOutput = Shape.of(3, units); + assertEquals(expectedOutput, y.shape()); + session.evaluate(tf.constant(expected), y); + } + } + + @Test + public void testShape4_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(4, 2); + int units = 3; + + Dense instance = + new Dense<>(tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + + float[][] inputArray = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 
6.931125f}, + {7.716860f, 3.205337f} + }; + + List> weights = instance.getWeights(); + instance.setWeights(weights); + Operand input = tf.reshape(tf.constant(inputArray), tf.constant(inputShape)); + + Operand y = instance.call(input, TFloat32.class); + session.run(tf.init()); + List computedShapes = + instance.computeOutputShape(Collections.singletonList(input.shape())); + assertFalse(computedShapes.isEmpty()); + Shape computedShape = computedShapes.get(0); + Shape expectedShape = Shape.of(4, units); + assertEquals(expectedShape, computedShape); + assertEquals(expectedShape, y.shape()); + + float[][] expected = { + {5.866010f, 8.906790f, 7.214075f}, + {5.409174f, 8.020676f, 7.272593f}, + {5.140971f, 8.056996f, 5.513148f}, + {6.245262f, 9.327854f, 8.179358f} + }; + + session.evaluate(tf.constant(expected), y); + } + } + + @Test + public void testShapeN_N_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(Shape.UNKNOWN_SIZE, Shape.UNKNOWN_SIZE, 2); + int units = 3; + + Dense instance = + new Dense<>(tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + + Shape fullShape = Shape.of(5, 10, 2); + float[][][] data = { + { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f}, + {7.716860f, 3.205337f}, + {8.066205f, 2.362994f}, + {0.686355f, 8.934626f}, + {1.293296f, 9.073912f}, + {4.554000f, 0.347209f}, + {6.760708f, 8.464749f}, + {9.203295f, 6.147404f} + }, + { + {7.022987f, 3.022041f}, + {0.175645f, 7.057390f}, + {4.537057f, 3.270523f}, + {5.694380f, 0.481678f}, + {1.267088f, 4.573346f}, + {7.239103f, 2.671200f}, + {4.631621f, 1.366283f}, + {4.380660f, 0.902928f}, + {7.663558f, 8.725193f}, + {4.102549f, 2.243720f} + }, + { + {0.251945f, 1.804798f}, + {5.300526f, 7.791917f}, + {-0.071388f, 9.458032f}, + {7.492148f, 1.584492f}, + {6.854610f, 2.461785f}, + {4.187295f, 3.974617f}, + {-0.015711f, 1.355883f}, + {1.855492f, 7.734279f}, + {3.403170f, 7.473061f}, + {4.243813f, 6.584970f} + }, + { + {1.645227f, 0.730085f}, + {3.999032f, 5.628812f}, + {5.522727f, 3.001995f}, + {2.459637f, 9.221226f}, + {0.305633f, 9.156766f}, + {8.218584f, 7.329232f}, + {2.657161f, 3.237010f}, + {3.008971f, 7.147655f}, + {2.788105f, 2.895133f}, + {2.805755f, 3.646185f} + }, + { + {2.086996f, 5.481725f}, + {4.222548f, 4.396897f}, + {1.799221f, 7.522835f}, + {3.549520f, 9.244308f}, + {4.980303f, 0.475735f}, + {3.644282f, 0.544247f}, + {6.282454f, 8.306262f}, + {3.650939f, 1.386086f}, + {3.526051f, 1.671946f}, + {7.763572f, 6.653723f} + }, + }; + + float[][][] expected = { + { + {5.866010f, 8.906790f, 7.214075f}, + {5.409174f, 8.020676f, 7.272593f}, + {5.140971f, 8.056995f, 5.513147f}, + {6.245262f, 9.327854f, 8.179358f}, + {6.258242f, 9.272379f, 8.437642f}, + {2.918294f, 5.014478f, 1.708536f}, + {3.378673f, 5.693542f, 2.339058f}, + {3.263673f, 4.757503f, 4.651771f}, + {7.016673f, 10.908866f, 7.807479f}, + {8.083269f, 12.249319f, 10.018546f} + }, + { + {5.712371f, 8.539888f, 7.455799f}, + {2.050113f, 3.591536f, 0.978359f}, + {4.050456f, 6.154792f, 4.966178f}, + {4.093921f, 5.971835f, 5.822024f}, + {2.131000f, 3.489655f, 1.802051f}, + {5.766911f, 8.587945f, 7.634893f}, + {3.596069f, 5.328780f, 4.845973f}, + {3.294865f, 4.851680f, 4.539239f}, + {7.716053f, 11.944765f, 8.751445f}, + {3.467615f, 5.220104f, 4.409636f} + }, + { + {0.668335f, 1.127111f, 0.459879f}, + {5.816830f, 9.111765f, 6.252261f}, + {2.534012f, 4.504061f, 1.000444f}, + {5.646128f, 8.317189f, 7.767926f}, + {5.442162f, 
8.099133f, 7.221718f}, + {3.999421f, 6.142959f, 4.691792f}, + {0.359459f, 0.640173f, 0.137874f}, + {3.403915f, 5.611978f, 2.756518f}, + {4.409483f, 7.045343f, 4.294413f}, + {4.751827f, 7.462863f, 5.045105f} + }, + { + {1.344244f, 2.011289f, 1.749129f}, + {4.320304f, 6.753564f, 4.688737f}, + {4.662963f, 7.018229f, 5.934030f}, + {4.230494f, 6.940253f, 3.537062f}, + {2.714058f, 4.738264f, 1.348129f}, + {7.720919f, 11.828723f, 9.155255f}, + {2.733208f, 4.244021f, 3.058378f}, + {4.046294f, 6.490631f, 3.858251f}, + {2.730931f, 4.210578f, 3.152225f}, + {2.948380f, 4.591742f, 3.255287f} + }, + { + {2.949665f, 4.755452f, 2.735502f}, + {4.139306f, 6.382794f, 4.775392f}, + {3.306999f, 5.452967f, 2.675543f}, + {4.995176f, 8.049803f, 4.643537f}, + {3.595419f, 5.249314f, 5.098118f}, + {2.684487f, 3.936022f, 3.752738f}, + {6.640594f, 10.350204f, 7.305118f}, + {2.919088f, 4.350031f, 3.854963f}, + {2.910276f, 4.362474f, 3.760896f}, + {7.219775f, 11.043336f, 8.617791f} + } + }; + + List> weights = instance.getWeights(); + instance.setWeights(weights); + Operand input = tf.constant(data); + + Operand y = instance.call(input, TFloat32.class); + session.run(tf.init()); + + List computed = instance.computeOutputShape(Collections.singletonList(fullShape)); + assertEquals(1, computed.size()); + assertEquals(computed.get(0), y.shape()); + Shape expectedOutput = Shape.of(5, 10, units); + assertEquals(expectedOutput, y.shape()); + session.evaluate(tf.constant(expected), y); + } + } + + @Test + public void testShape3_4_5_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + int units = 3; + + Shape inputShape = Shape.of(3, 4, 5, 2); + + Dense instance = + new Dense<>(tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + assertEquals("Dense", instance.getName()); + session.run(tf.init()); + + float[][][][] data = { + { + { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f}, + {7.716860f, 3.205337f}, + {8.066205f, 2.362994f}, + }, + { + {0.686355f, 8.934626f}, + {1.293296f, 9.073912f}, + {4.554000f, 0.347209f}, + {6.760708f, 8.464749f}, + {9.203295f, 6.147404f}, + }, + { + {7.022987f, 3.022041f}, + {0.175645f, 7.057390f}, + {4.537057f, 3.270523f}, + {5.694380f, 0.481678f}, + {1.267088f, 4.573346f}, + }, + { + {7.239103f, 2.671200f}, + {4.631621f, 1.366283f}, + {4.380660f, 0.902928f}, + {7.663558f, 8.725193f}, + {4.102549f, 2.243720f}, + }, + }, + { + { + {0.251945f, 1.804798f}, + {5.300526f, 7.791917f}, + {-0.071388f, 9.458032f}, + {7.492148f, 1.584492f}, + {6.854610f, 2.461785f}, + }, + { + {4.187295f, 3.974617f}, + {-0.015711f, 1.355883f}, + {1.855492f, 7.734279f}, + {3.403170f, 7.473061f}, + {4.243813f, 6.584970f}, + }, + { + {1.645227f, 0.730085f}, + {3.999032f, 5.628812f}, + {5.522727f, 3.001995f}, + {2.459637f, 9.221226f}, + {0.305633f, 9.156766f}, + }, + { + {8.218584f, 7.329232f}, + {2.657161f, 3.237010f}, + {3.008971f, 7.147655f}, + {2.788105f, 2.895133f}, + {2.805755f, 3.646185f}, + }, + }, + { + { + {2.086996f, 5.481725f}, + {4.222548f, 4.396897f}, + {1.799221f, 7.522835f}, + {3.549520f, 9.244308f}, + {4.980303f, 0.475735f}, + }, + { + {3.644282f, 0.544247f}, + {6.282454f, 8.306262f}, + {3.650939f, 1.386086f}, + {3.526051f, 1.671946f}, + {7.763572f, 6.653723f}, + }, + { + {2.367239f, 3.317834f}, + {2.330428f, 9.358873f}, + {3.638705f, 5.096712f}, + {9.156695f, 4.436713f}, + {-0.416358f, 8.118915f}, + }, + { + {6.330701f, 6.326071f}, + {4.724874f, -0.368026f}, + {3.975863f, 0.017570f}, + {3.545376f, 
7.946171f}, + {-0.495031f, 7.853283f}, + } + } + }; + + float[][][][] expected = { + { + { + {5.866010f, 8.906790f, 7.214075f}, + {5.409174f, 8.020676f, 7.272593f}, + {5.140971f, 8.056995f, 5.513147f}, + {6.245262f, 9.327854f, 8.179358f}, + {6.258242f, 9.272379f, 8.437642f}, + }, + { + {2.918294f, 5.014478f, 1.708536f}, + {3.378673f, 5.693542f, 2.339058f}, + {3.263673f, 4.757503f, 4.651771f}, + {7.016673f, 10.908866f, 7.807479f}, + {8.083269f, 12.249319f, 10.018546f}, + }, + { + {5.712371f, 8.539888f, 7.455799f}, + {2.050113f, 3.591536f, 0.978359f}, + {4.050456f, 6.154792f, 4.966178f}, + {4.093921f, 5.971835f, 5.822024f}, + {2.131000f, 3.489655f, 1.802051f}, + }, + { + {5.766911f, 8.587945f, 7.634893f}, + {3.596069f, 5.328780f, 4.845973f}, + {3.294865f, 4.851680f, 4.539239f}, + {7.716053f, 11.944765f, 8.751445f}, + {3.467615f, 5.220104f, 4.409636f}, + } + }, + { + { + {0.668335f, 1.127111f, 0.459879f}, + {5.816830f, 9.111765f, 6.252261f}, + {2.534012f, 4.504061f, 1.000444f}, + {5.646128f, 8.317189f, 7.767926f}, + {5.442162f, 8.099133f, 7.221718f}, + }, + { + {3.999421f, 6.142959f, 4.691792f}, + {0.359459f, 0.640173f, 0.137874f}, + {3.403915f, 5.611978f, 2.756518f}, + {4.409483f, 7.045343f, 4.294413f}, + {4.751827f, 7.462863f, 5.045105f}, + }, + { + {1.344244f, 2.011289f, 1.749129f}, + {4.320304f, 6.753564f, 4.688737f}, + {4.662963f, 7.018229f, 5.934030f}, + {4.230494f, 6.940253f, 3.537062f}, + {2.714058f, 4.738264f, 1.348129f}, + }, + { + {7.720919f, 11.828723f, 9.155255f}, + {2.733208f, 4.244021f, 3.058378f}, + {4.046294f, 6.490631f, 3.858251f}, + {2.730931f, 4.210578f, 3.152225f}, + {2.948380f, 4.591742f, 3.255287f}, + } + }, + { + { + {2.949665f, 4.755452f, 2.735502f}, + {4.139306f, 6.382794f, 4.775392f}, + {3.306999f, 5.452967f, 2.675543f}, + {4.995176f, 8.049803f, 4.643537f}, + {3.595419f, 5.249314f, 5.098118f}, + }, + { + {2.684487f, 3.936022f, 3.752738f}, + {6.640594f, 10.350204f, 7.305118f}, + {2.919088f, 4.350031f, 3.854963f}, + {2.910276f, 4.362474f, 3.760896f}, + {7.219775f, 11.043336f, 8.617791f}, + }, + { + {2.553549f, 3.990942f, 2.773906f}, + {4.178188f, 6.876633f, 3.421808f}, + {3.924220f, 6.132985f, 4.263438f}, + {7.583528f, 11.374685f, 9.777319f}, + {1.928159f, 3.508506f, 0.499166f}, + }, + { + {6.133229f, 9.440765f, 7.129385f}, + {3.187190f, 4.583663f, 4.743713f}, + {2.771337f, 4.015369f, 4.028832f}, + {4.637676f, 7.417559f, 4.492103f}, + {1.800851f, 3.300701f, 0.389355f}, + } + } + }; + + List> weights = instance.getWeights(); + instance.setWeights(weights); + Operand input = tf.constant(data); + + Operand y = instance.call(input, TFloat32.class); + session.run(tf.init()); + + List computed = instance.computeOutputShape(Collections.singletonList(inputShape)); + assertEquals(1, computed.size()); + assertEquals(computed.get(0), y.shape()); + Shape expectedOutput = Shape.of(3, 4, 5, units); + assertEquals(expectedOutput, y.shape()); + session.evaluate(tf.constant(expected), y); + } + } + + @Test + public void testConstraintsNonNeg() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(3, 2); + int units = 3; + + NonNeg nonNeg = new NonNeg(tf); + + Dense instance = + new Dense<>( + tf, + "constraintTest", + units, + null, + true, + null, + null, + nonNeg, + nonNeg, + 1001L, + TFloat32.class, + Layer.Options.create().inputShape(inputShape)); + + float[][] data = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f} + }; + + float[][] constraintInput = { + {-1, 2, 5}, + {-2, 
4, -4} + }; + float[][] constraintExpected = { + {-0, 2, 5}, + {-0, 4, -0} + }; + + float[] biasConstraintInput = { -1, 2, 5 }; + float[] biasConstraintExpected = { -0, 2, 5 }; + + Operand input = tf.constant(data); + + Operand y = instance.call(input, TFloat32.class); + // initialize variables + session.run(tf.init()); + + List> weights = instance.getWeights(); + instance.setWeights(weights); + + // Test kernel + Variable kernel = instance.getKernel(); + Operand varUpdate = instance.assign(kernel, tf.constant(constraintInput)); + session.run(varUpdate); + session.evaluate(tf.constant(constraintExpected), kernel); + + // test bias + Variable bias = instance.getBias(); + assertEquals(Shape.of(units), bias.shape()); + varUpdate = instance.assignAdd(bias, tf.constant(biasConstraintInput)); + session.run(varUpdate); + session.evaluate(tf.constant(biasConstraintExpected), bias); + + } + } + + @Test + public void testConstraintsMinMaxNorm() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + Shape inputShape = Shape.of(3, 2); + int units = 3; + + MinMaxNorm minMaxNorm = new MinMaxNorm(tf); + + Dense instance = + new Dense<>( + tf, + "constraintTest", + units, + null, + true, + null, + null, + minMaxNorm, + minMaxNorm, + 1001L, + TFloat32.class, + Layer.Options.create().inputShape(inputShape)); + + float[][] data = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f} + }; + + float[][] constraintInput = { + {1, 0.5f, 2}, + {-2, 0.75f, 0} + }; + float[][] constraintExpected = { + { 0.447214f, 0.5f, 1}, + {-0.894427f, 0.75f, 0} + }; + + float[] biasConstraintInput = { -1, 2, 5 }; + float[] biasConstraintExpected = { -0.182574f,0.365148f, 0.912871f }; + + Operand input = tf.constant(data); + + Operand y = instance.call(input, TFloat32.class); + //initialize variables + session.run(tf.init()); + + List> weights = instance.getWeights(); + instance.setWeights(weights); + + // Test kernel + Variable kernel = instance.getKernel(); + Operand varUpdate = instance.assign(kernel, tf.constant(constraintInput)); + session.run(varUpdate); + session.evaluate(tf.constant(constraintExpected), kernel); + + // test bias + Variable bias = instance.getBias(); + assertEquals(Shape.of(units), bias.shape()); + + varUpdate = instance.assignAdd(bias, tf.constant(biasConstraintInput)); + session.run(varUpdate); + session.evaluate(tf.constant(biasConstraintExpected), bias); + + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java new file mode 100644 index 00000000000..27292222814 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java @@ -0,0 +1,112 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +class DotTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][] x1 = { + 
{0.04867243, 0.42833055, 0.57495679, 0.04191259}, {0.48993384, 0.80122145, 0.8199583, 0.0552641} + }; + double[][] x2 = { + {0.37530763, 0.65938955, 0.69901548, 0.87864686}, + {0.79027356, 0.29017831, 0.62662979, 0.34575866} + }; + + double[][] xdot = {{0.73943388}, {1.15259719}}; + + @Test + public void testDot() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, "l1", TFloat64.class, TFloat64.class, Layer.Options.create().inputShape(Shape.of(4))); + Input i2 = + new Input<>( + tf, "l2", TFloat64.class, TFloat64.class, Layer.Options.create().inputShape(Shape.of(4))); + Dot instance = new Dot<>(tf, 1, TFloat64.class); + List> resultList = + instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 1}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xdot), result, feedMap); + } + } + } + + @Test + public void testDotNegativeAxis() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, "l1", TFloat64.class, TFloat64.class, Layer.Options.create().inputShape(Shape.of(4))); + Input i2 = + new Input<>( + tf, "l2", TFloat64.class, TFloat64.class, Layer.Options.create().inputShape(Shape.of(4))); + Dot instance = new Dot<>(tf, new int[] {-1, -1}, TFloat64.class); + List> resultList = + instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 1}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xdot), result, feedMap); + } + } + } + + @Test + public void testDotComputeOutputShape() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Dot dot = new Dot<>(tf, -1, TFloat32.class); + + List outputShapes = + dot.computeOutputShape(Arrays.asList(Shape.of(4, 5), Shape.of(4, 5))); + assertFalse(outputShapes.isEmpty()); + assertArrayEquals(new long[] {4}, outputShapes.get(0).asArray()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DropoutTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DropoutTest.java new file mode 100644 index 00000000000..cf093bd7c6d --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DropoutTest.java @@ -0,0 +1,89 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import 
org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class DropoutTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testShape3_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + Shape expectedShape = Shape.of(3, 2); + Operand input = + tf.constant( + new float[][] { + {1.3463433f, 7.2481093f}, + {5.4018216f, 0.6772865f}, + {3.4442706f, 0.95697135f} + }); + + Dropout instance = new Dropout<>(tf, 0.5f, seed, TFloat32.class); + + // first pass, trainable is false, so there should be no dropout + Operand result = instance.call(input, false, TFloat32.class); + assertEquals(expectedShape, result.shape()); + session.evaluate(tf.identity(input), result); + + Operand expected = + tf.constant( + new float[][] { + {0f, 14.496219f}, + {0f, 0f}, + {0f, 0f} + }); + + // second pass, trainable is true, so there should be dropout + result = instance.call(input, true, TFloat32.class); + assertEquals(expectedShape, result.shape()); + session.evaluate(expected, result); + } + } + + @Test + public void testShape3_2Noise() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + Shape expectedShape = Shape.of(3, 2); + Operand input = + tf.constant( + new float[][] { + {1.3463433f, 7.2481093f}, + {5.4018216f, 0.6772865f}, + {3.4442706f, 0.95697135f} + }); + + Dropout instance = new Dropout<>(tf, 0.5f, Shape.of(3, 1), seed, TFloat32.class); + + Float[] expected = new Float[] {0f, 0f, 10.803643f, 1.354573f, 0f, 0f}; + + // trainable is true, so there should be dropout + Operand result = instance.call(input, true, TFloat32.class); + assertEquals(expectedShape, result.shape()); + // Note: this can only be evaluated once, or else the result will be updated with + // new values and will not match expected. 
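+ // As a sketch of the arithmetic behind `expected`, assuming the layer uses
+ // standard inverted dropout: values that survive are scaled by
+ // 1 / (1 - rate) = 2, so 5.4018216f * 2 = 10.803643f and
+ // 0.6772865f * 2 = 1.354573f, while the rows zeroed out by the (3, 1)
+ // noise shape become 0f.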
+ session.evaluate(expected, result); + } + } + + @Test + public void testSupportsMasking() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + long seed = 1001L; + Dropout instance = new Dropout<>(tf, 0.5f, seed, TFloat32.class); + assertTrue(instance.isSupportsMasking()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java new file mode 100644 index 00000000000..a0530e27e4a --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java @@ -0,0 +1,119 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class ELUTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCallAlpha0() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = 0f; + ELU instance = + new ELU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallAlpha0Point5() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = 0.5f; + ELU instance = + new ELU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallAlphaMinus1() { + + double[][][] expectedArray = { + { + {2.7085743, 8.254536, 9.754793, 1.1027353}, + {8.698364, 2.2781835, 8.608563, 1.4326588}, + {0.7584583, 5.604635, 7.3599877, 0.06365667} + }, + { + {4.8735523, 9.90222, 5.390144, 2.052634}, + {5.9165273, 0.9186602, 0.9137567, 0.5605333}, + {2.0804656, 8.537634, 6.403787, 5.8328476} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = -1.f; + ELU instance = + new 
ELU<>( + tf, alpha, TFloat64.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = + instance.call( + tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } +} + diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/FlattenTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/FlattenTest.java new file mode 100644 index 00000000000..0c360c6ecdf --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/FlattenTest.java @@ -0,0 +1,74 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.layers.impl.TensorFormat; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Collections; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class FlattenTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testCall() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape inputShape = Shape.of(1, 3, 2); + float[] a = {1F, 2F, 3F, 4F, 5F, 6F}; + Float[] expected = {1F, 2F, 3F, 4F, 5F, 6F}; + Shape expectedShape = Shape.of(1, 6); + Operand input = tf.reshape(tf.constant(a), tf.constant(inputShape)); + Flatten layer = new Flatten<>(tf, TFloat32.class); + Operand output = layer.call(input, TFloat32.class); + assertEquals(expectedShape, output.shape()); + session.evaluate(expected, output); + } + } + + @Test + public void testCallChannelsFirst() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[] a = { + 0.12911275f, + 0.16172077f, + 0.7024991f, + 0.3936557f, + 0.8216052f, + 0.04838822f, + 0.96763366f, + 0.1477106f, + 0.03416549f, + 0.40088153f + }; + Shape expectedShape = Shape.of(10, 1); + Operand input = tf.constant(a); + Flatten layer = new Flatten<>(tf, TensorFormat.NCHW, TFloat32.class); + Operand output = layer.call(input, TFloat32.class); + assertEquals(expectedShape, output.asOutput().shape()); + Operand expected = tf.expandDims(input, tf.constant(-1)); + session.evaluate(expected, output); + } + } + + /** Test of computeOutputShape method, of class Flatten. 
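+ * As a sketch of the expected arithmetic: Flatten keeps the batch dimension
+ * and collapses all remaining dimensions into one, so an input shape of
+ * (1, 3, 2) is expected to yield (1, 3 * 2) = (1, 6).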
+ */
+  @Test
+  public void testComputeOutputShape() {
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      Shape inputShape = Shape.of(1, 3, 2);
+      Shape expectedShape = Shape.of(1, 6);
+      Flatten<TFloat32> layer = new Flatten<>(tf, TFloat32.class);
+      List<Shape> computedShapes = layer.computeOutputShape(Collections.singletonList(inputShape));
+      assertEquals(expectedShape, computedShapes.get(0));
+    }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java
new file mode 100644
index 00000000000..be74d8c736b
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java
@@ -0,0 +1,63 @@
+package org.tensorflow.framework.layers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class GaussianDropoutTest {
+  private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
+
+  @Test
+  public void testShape3_2_3() {
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      long seed = 1001L;
+      Shape expectedShape = Shape.of(3, 2, 3);
+      Operand<TFloat64> input =
+          tf.constant(
+              new double[][][] {
+                {{3.22382299, 1.41224385, 7.265976}, {9.1436238, 6.15759347, 6.79954284}},
+                {{6.41459591, 2.16451569, 4.12015256}, {2.42915398, 2.27193001, 1.09604702}},
+                {{5.13626611, 4.34388458, 1.32951124}, {8.47118881, 6.70455732, 8.57420547}}
+              });
+
+      GaussianDropout<TFloat64> instance = new GaussianDropout<>(tf, 0.5f, seed, TFloat64.class);
+
+      // first pass, trainable is false, so there should be no dropout
+      Operand<TFloat64> result = instance.call(input, false, TFloat64.class);
+      assertEquals(expectedShape, result.shape());
+      session.evaluate(tf.identity(input), result);
+
+      Operand<TFloat64> expected =
+          tf.constant(
+              new double[][][] {
+                {{0.734139, 0.398936, 5.600555}, {0.817308, 6.134556, 2.074016}},
+                {{3.984928, 1.046533, 2.354332}, {1.876065, 1.218514, 1.014165}},
+                {{4.405925, 3.813551, 1.100304}, {4.984621, 1.846423, 2.097348}}
+              });
+
+      // second pass, trainable is true, so there should be dropout
+      result = instance.call(input, true, TFloat64.class);
+      assertEquals(expectedShape, result.shape());
+      session.evaluate(expected, result);
+    }
+  }
+
+  @Test
+  public void testSupportsMasking() {
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      long seed = 1001L;
+      GaussianDropout<TFloat32> instance = new GaussianDropout<>(tf, 0.5f, seed, TFloat32.class);
+      assertTrue(instance.isSupportsMasking());
+    }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java
new file mode 100644
index 00000000000..eb703c94ae2
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java
@@ -0,0 +1,65 @@
+package org.tensorflow.framework.layers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class GaussianNoiseTest {
+  private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
+
+  @Test
+  public void testShape3_2_3() {
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      long seed = 1001L;
+      Shape expectedShape = Shape.of(3, 2, 3);
+      Operand<TFloat64> input =
+          tf.constant(
+              new double[][][] {
+                {{3.451546, 1.694727, 8.036768}, {9.233009, 6.462616, 7.226611}},
+                {{6.644345, 2.605274, 4.224330}, {2.912649, 2.843349, 1.097698}},
+                {{5.672600, 5.269178, 2.187318}, {9.298789, 7.292978, 8.849604}}
+              });
+
+      GaussianNoise<TFloat64> instance = new GaussianNoise<>(tf, 1.f, seed, TFloat64.class);
+
+      // first pass, trainable is false, so no noise should be applied
+      Operand<TFloat64> result = instance.call(input, false, TFloat64.class);
+      assertEquals(expectedShape, result.shape());
+      session.evaluate(tf.identity(input), result);
+
+      Operand<TFloat64> expected =
+          tf.constant(
+              new double[][][] {
+                {{3.679269, 1.977210, 8.807560}, {9.322395, 6.767639, 7.653679}},
+                {{6.874095, 3.046032, 4.328507}, {3.396144, 3.414768, 1.099349}},
+                {{6.208934, 6.194471, 3.045125}, {10.126389, 7.881398, 9.125002}}
+              });
+
+      // second pass, trainable is true, so there should be noise applied
+      result = instance.call(input, true, TFloat64.class);
+      assertEquals(expectedShape, result.shape());
+      // cannot evaluate more than once, else it doesn't match expected
+      // because of random number generation.
+      session.evaluate(expected, result);
+    }
+  }
+
+  @Test
+  public void testSupportsMasking() {
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      long seed = 1001L;
+      GaussianNoise<TFloat32> instance = new GaussianNoise<>(tf, 0.5f, seed, TFloat32.class);
+      assertTrue(instance.isSupportsMasking());
+    }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java
new file mode 100644
index 00000000000..587e24da029
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java
@@ -0,0 +1,99 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Placeholder; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TType; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +public class InputTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + void call() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] array = new float[][] {{0, 1}, {2, 3}}; + Operand input = tf.constant(array); + Input instance = new Input<>(tf, input, TFloat32.class); + List> result = + instance.call(Collections.singletonList(input), null, false, TFloat32.class); + + assertNotNull(result); + assertEquals(1, result.size()); + + session.evaluate(input, result.get(0)); + } + } + + @Test + void getOutput() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] array = new float[][] {{0, 1}, {2, 3}}; + Operand input = tf.constant(array); + Input instance = new Input<>(tf, input, TFloat32.class); + Operand result = instance.getOutput(TFloat32.class); + + session.evaluate(input, result); + } + } + + @Test + void isPlaceholder() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] array = new float[][] {{0, 1}, {2, 3}}; + Operand input = tf.constant(array); + Input instance = + new Input<>( + tf, TFloat32.class, TFloat32.class, Layer.Options.create().inputShape(input.shape())); + + assertTrue(instance.isPlaceholder()); + Operand result = instance.getOutput(TFloat32.class); + assertTrue(result instanceof Placeholder); + try (TFloat32 inputTensor = + (TFloat32) session.getGraphSession().runner().fetch(input).run().get(0)) { + Map, Tensor> feedMap = + Collections.singletonMap(result, inputTensor); + session.evaluate(tf.constant(array), tf.identity(result), feedMap); + } + } + } + + @Test + void getInputType() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] array = new float[][] {{0, 1}, {2, 3}}; + Operand input = tf.constant(array); + Input instance = new Input<>(tf, input, TFloat32.class); + Operand result = instance.getOutput(TFloat32.class); + + session.evaluate(input, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java new file mode 100644 index 00000000000..689ab1163e9 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java @@ -0,0 +1,43 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.tensorflow.framework.utils.CastHelper.cast; + +class LambdaTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void 
testCallLambda() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(3, 2); + + Lambda instance = new Lambda<>(tf, TFloat32.class); + instance.setLamda((t, y) -> t.math.mul(cast(t, t.constant(2), y.type()), y)); + + double[][] array = { + {0.41448207, 0.71509451}, {0.21307868, 0.76890945}, {0.37533432, 0.7761148} + }; + double[][] expected = new double[array.length][array[0].length]; + for (int i = 0; i < array.length; i++) { + for (int j = 0; j < array[0].length; j++) { + expected[i][j] = array[i][j] * 2; + } + } + + Operand result = + instance.call(tf.dtypes.cast(tf.constant(array), TFloat64.class), TFloat64.class); + + assertEquals(shape, result.shape()); + session.evaluate(tf.constant(expected), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java new file mode 100644 index 00000000000..f5baa2f9574 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java @@ -0,0 +1,119 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +class LeakyReLUTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCallAlpha0() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = 0f; + LeakyReLU instance = + new LeakyReLU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallAlpha0Point5() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = 0.5f; + LeakyReLU instance = + new LeakyReLU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = 
instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallAlphaMinus1() { + + double[][][] expectedArray = { + { + {2.7085743, 8.254536, 9.754793, 1.1027353}, + {8.698364, 2.2781835, 8.608563, 1.4326588}, + {0.7584583, 5.604635, 7.3599877, 0.06365667} + }, + { + {4.8735523, 9.90222, 5.390144, 2.052634}, + {5.9165273, 0.9186602, 0.9137567, 0.5605333}, + {2.0804656, 8.537634, 6.403787, 5.8328476} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float alpha = -1.f; + LeakyReLU instance = + new LeakyReLU<>( + tf, alpha, TFloat64.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call( tf.constant(inputArray) , TFloat64.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java new file mode 100644 index 00000000000..ed35ab8408d --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java @@ -0,0 +1,106 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +class MaximumTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + }; + + double[][][] xmax = { + { + {0.82828211, 0.55677077, 0.7159566, 0.93377237, 0.61086578}, + {0.73234341, 0.39331301, 0.68069423, 0.96272026, 0.86098578}, + {0.99338463, 0.64175689, 0.74858191, 0.80589999, 0.94056888}, + {0.79376476, 0.24171677, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.56599224, 0.39567908, 0.89910993, 0.72514044, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.44414272, 0.96730508}, + {0.89191749, 0.73008498, 0.9177326, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.79244879, 0.63492784} + } + }; + + @Test + public void 
testMaximum() {
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      Input<TFloat64> i1 =
+          new Input<>(
+              tf,
+              "l1",
+              TFloat64.class,
+              TFloat64.class,
+              Layer.Options.create().inputShape(Shape.of(4, 5)));
+      Input<TFloat64> i2 =
+          new Input<>(
+              tf,
+              "l2",
+              TFloat64.class,
+              TFloat64.class,
+              Layer.Options.create().inputShape(Shape.of(4, 5)));
+      Maximum<TFloat64> instance = new Maximum<>(tf, TFloat64.class);
+      List<Operand<TFloat64>> resultList =
+          instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class);
+
+      Operand<TFloat64> result = resultList.get(0);
+
+      assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray());
+
+      Operand<TFloat64> x1Op = tf.constant(x1);
+      Operand<TFloat64> x2Op = tf.constant(x2);
+
+      try (TFloat64 x1Tensor =
+              (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0);
+          TFloat64 x2Tensor =
+              (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) {
+        Map<Operand<? extends TType>, Tensor> feedMap = new HashMap<>();
+        feedMap.put(i1.getOutput(TFloat64.class), x1Tensor);
+        feedMap.put(i2.getOutput(TFloat64.class), x2Tensor);
+        session.evaluate(tf.constant(xmax), result, feedMap);
+      }
+    }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java
new file mode 100644
index 00000000000..c5b7673a6d7
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java
@@ -0,0 +1,106 @@
+package org.tensorflow.framework.layers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.Tensor;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat64;
+import org.tensorflow.types.family.TType;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+
+class MinimumTest {
+  private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
+
+  double[][][] x1 = {
+    {
+      {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578},
+      {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578},
+      {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888},
+      {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332}
+    },
+    {
+      {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061},
+      {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508},
+      {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399},
+      {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174}
+    }
+  };
+  double[][][] x2 = {
+    {
+      {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201},
+      {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802},
+      {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939},
+      {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821}
+    },
+    {
+      {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704},
+      {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624},
+      {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576},
+      {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784}
+    }
+  };
+
+  double[][][] xmin = {
+    {
+      {0.13570025, 0.28889298, 0.06648757, 0.58405729, 0.32654201},
+      {0.18659685, 0.17123203, 0.62582661, 0.09510652, 0.58700802},
+      {0.12527705, 0.37543824, 0.64915537, 0.31828287, 0.26400939},
+      {0.76202298, 0.05605309, 0.0677271, 0.07027092, 0.29195821}
+    },
+    {
+      {0.13023652, 0.10611362, 0.83370522, 0.71302943, 0.08126704},
+      {0.48173969, 0.16509515, 0.21040572, 0.16864979, 0.70656624},
+      {0.39645622, 0.35834793, 0.39924944, 0.31897888, 0.56743576},
+      {0.36304201, 0.36696309, 0.60722209, 0.33104554, 0.11978174}
+    }
+  };
+
+  @Test
+  public void testMinimum() {
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      Input<TFloat64> i1 =
+          new Input<>(
+              tf,
+              "l1",
+              TFloat64.class,
+              TFloat64.class,
+              Layer.Options.create().inputShape(Shape.of(4, 5)));
+      Input<TFloat64> i2 =
+          new Input<>(
+              tf,
+              "l2",
+              TFloat64.class,
+              TFloat64.class,
+              Layer.Options.create().inputShape(Shape.of(4, 5)));
+      Minimum<TFloat64> instance = new Minimum<>(tf, TFloat64.class);
+      List<Operand<TFloat64>> resultList =
+          instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class);
+
+      Operand<TFloat64> result = resultList.get(0);
+
+      assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray());
+
+      Operand<TFloat64> x1Op = tf.constant(x1);
+      Operand<TFloat64> x2Op = tf.constant(x2);
+
+      try (TFloat64 x1Tensor =
+              (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0);
+          TFloat64 x2Tensor =
+              (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) {
+        Map<Operand<? extends TType>, Tensor> feedMap = new HashMap<>();
+        feedMap.put(i1.getOutput(TFloat64.class), x1Tensor);
+        feedMap.put(i2.getOutput(TFloat64.class), x2Tensor);
+        session.evaluate(tf.constant(xmin), result, feedMap);
+      }
+    }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java
new file mode 100644
index 00000000000..5e3e802b371
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java
@@ -0,0 +1,134 @@
+package org.tensorflow.framework.layers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.Tensor;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat64;
+import org.tensorflow.types.family.TType;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+
+class MultiplyTest {
+  private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
+
+  double[][][] x1 = {
+    {
+      {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578},
+      {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578},
+      {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888},
+      {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332}
+    },
+    {
+      {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061},
+      {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508},
+      {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399},
+      {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174}
+    }
+  };
+  double[][][] x2 = {
+    {
+      {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201},
+      {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802},
+      {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939},
+      {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821}
+    },
+    {
+      {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704},
+      {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624},
+      {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576},
+      {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784}
+    }
+  };
+  double[][][] x3 = {
+    {
+      {0.90545522, 0.55172128, 0.87254455, 0.1396359, 0.1538656},
+      {0.04276304, 0.9315817, 0.91360492, 0.00604873, 0.04174153},
+      {0.60856471, 0.37386072, 0.68937889, 0.21272655, 0.65082257},
+      {0.44925012, 0.29825938, 0.20043074, 0.84906101, 0.78397795}
+    },
+    {
+      {0.70855776, 0.17650269, 0.02422264, 0.84612297, 0.72450389},
+      {0.05133022, 0.61175015, 0.56296539, 0.66780478, 0.63326012},
+      {0.11212696, 0.50675282, 0.58170013, 0.21101392, 0.83090424},
+      {0.91830915, 0.42113009, 0.49795942, 0.2814478, 0.11920788}
+    }
+  };
+  double[][][] xmul = {
+    {
+      {0.10177144, 0.0887428, 0.04153505, 0.07615415, 0.03069209},
+      {0.0058437, 0.06273996, 0.38919256, 0.00055383, 0.0210964},
+      {0.07573484, 0.09007803, 0.33500089, 0.05456525, 0.16161162},
+      {0.27173657, 0.00404111, 0.00997398, 0.05556795, 0.11125445}
+    },
+    {
+      {0.05222982, 0.00741081, 0.01815711, 0.4374849, 0.04340629},
+      {0.01536318, 0.06324873, 0.0969503, 0.05002163, 0.4328112},
+      {0.03964879, 0.13257892, 0.21313739, 0.06077848, 0.39066002},
+      {0.23341656, 0.14758288, 0.18685326, 0.07383407, 0.00906609}
+    }
+  };
+
+  @Test
+  public void testMultiply() {
+    try (TestSession session = TestSession.createTestSession(tfMode)) {
+      Ops tf = session.getTF();
+      Input<TFloat64> i1 =
+          new Input<>(
+              tf,
+              "l1",
+              TFloat64.class,
+              TFloat64.class,
+              Layer.Options.create().inputShape(Shape.of(4, 5)));
+      Input<TFloat64> i2 =
+          new Input<>(
+              tf,
+              "l2",
+              TFloat64.class,
+              TFloat64.class,
+              Layer.Options.create().inputShape(Shape.of(4, 5)));
+      Input<TFloat64> i3 =
+          new Input<>(
+              tf,
+              "l3",
+              TFloat64.class,
+              TFloat64.class,
+              Layer.Options.create().inputShape(Shape.of(4, 5)));
+      Multiply<TFloat64> instance = new Multiply<>(tf, TFloat64.class);
+      List<Operand<TFloat64>> resultList =
+          instance.call(
+              Arrays.asList(
+                  i1.getOutput(TFloat64.class),
+                  i2.getOutput(TFloat64.class),
+                  i3.getOutput(TFloat64.class)), TFloat64.class);
+
+      Operand<TFloat64> result = resultList.get(0);
+
+      assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray());
+
+      Operand<TFloat64> x1Op = tf.constant(x1);
+      Operand<TFloat64> x2Op = tf.constant(x2);
+      Operand<TFloat64> x3Op = tf.constant(x3);
+
+      try (TFloat64 x1Tensor =
+              (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0);
+          TFloat64 x2Tensor =
+              (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0);
+          TFloat64 x3Tensor =
+              (TFloat64) session.getGraphSession().runner().fetch(x3Op).run().get(0)) {
+        Map<Operand<? extends TType>, Tensor> feedMap = new HashMap<>();
+        feedMap.put(i1.getOutput(TFloat64.class), x1Tensor);
+        feedMap.put(i2.getOutput(TFloat64.class), x2Tensor);
+        feedMap.put(i3.getOutput(TFloat64.class), x3Tensor);
+        session.evaluate(tf.constant(xmul), result, feedMap);
+      }
+    }
+  }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java
new file mode 100644
index 00000000000..2561eb2d2cd
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java
@@ -0,0 +1,128 @@
+package org.tensorflow.framework.layers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+class ReLUTest {
+  private final TestSession.Mode tfMode = TestSession.Mode.GRAPH;
+
+  float[][][] inputArray = {
+    {
+      {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f},
+      {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f},
+ {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCallMaxValue10() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float maxValue = 10f; + ReLU instance = + new ReLU<>( + tf, + ReLU.DEFAULT_NEGATIVE_SLOPE, + maxValue, + ReLU.DEFAULT_THRESHOLD, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallNegativeSlope() { + + float[][][] expectedArray = { + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float negativeSlope = 0.2f; + ReLU instance = + new ReLU<>( + tf, negativeSlope, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat32.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } + + @Test + public void testCallMaxValue6() { + + double[][][] expectedArray = { + { + {2.7085743, 6., 6., 1.1027353}, + {6., 2.2781835, 6., 1.4326588}, + {0.7584583, 5.604635, 6., 0.06365667} + }, + { + {4.8735523, 6., 5.390144, 2.052634}, + {5.9165273, 0.9186602, 0.9137567, 0.5605333}, + {2.0804656, 6., 6., 5.8328476} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float maxValue = 6f; + ReLU instance = + new ReLU<>( + tf, + ReLU.DEFAULT_NEGATIVE_SLOPE, + maxValue, + ReLU.DEFAULT_THRESHOLD, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = + instance.call( + tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java new file mode 100644 index 00000000000..aee2e9a7dd4 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java @@ -0,0 +1,47 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class RepeatVectorTest { + + private final TestSession.Mode tfMode = 
TestSession.Mode.GRAPH; + + @Test + public void testCall3_2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(3, 2); + RepeatVector instance = new RepeatVector(tf, 3, TFloat32.class); + + double[][] array = { + {0.41448207, 0.71509451}, {0.21307868, 0.76890945}, {0.37533432, 0.7761148} + }; + + Shape expectedShape = Shape.of(3, 3, 2); + + double[][][] expected = { + {{0.41448206, 0.7150945}, {0.41448206, 0.7150945}, {0.41448206, 0.7150945}}, + {{0.21307868, 0.76890945}, {0.21307868, 0.76890945}, {0.21307868, 0.76890945}}, + {{0.37533432, 0.7761148}, {0.37533432, 0.7761148}, {0.37533432, 0.7761148}} + }; + + + + Operand result = + instance.call( + tf.dtypes.cast(tf.constant(array), TFloat64.class), TFloat64.class); + + assertEquals(expectedShape, result.shape()); + session.evaluate(tf.constant(expected), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java new file mode 100644 index 00000000000..66a5f75d8f3 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java @@ -0,0 +1,121 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.*; + +class ReshapeTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + float[][][] inputArrayNN2 = { + { + {2.70857435f, 8.25453567f}, {9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f}, {8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f}, {7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f}, {5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f}, {0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f}, {6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCall43() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(4,3); + long batchSize = 2; + Reshape instance = new Reshape<>(tf, targetShape, + TFloat32.class, Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4)) ); + + Operand result = + instance.call( + tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + assertArrayEquals(targetShape.prepend(batchSize).asArray(), result.shape().asArray()); + + + } + } + + @Test + public void testCallUnknown_1() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(Shape.UNKNOWN_SIZE,1); + long batchSize = 2; + Reshape instance = new Reshape<>(tf, targetShape, + TFloat32.class, Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4)) ); + + Shape expectedShape = Shape.of(batchSize, 12, 1); + + Operand result = + instance.call( + tf.dtypes.cast(tf.constant(inputArray), 
TFloat64.class), TFloat64.class); + + assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); + + + } + } + + @Test + public void testCall1_Unknown_1() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(1, Shape.UNKNOWN_SIZE); + long batchSize = 2; + Reshape instance = new Reshape<>(tf, targetShape, + TFloat32.class, Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4)) ); + + Operand result = + instance.call( + tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + Shape expectedShape = Shape.of(batchSize, 1, 12); + assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); + + + } + } + + @Test + public void testCallUnknownUnknown2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(Shape.UNKNOWN_SIZE,1); + long batchSize = 2; + Reshape instance = new Reshape<>(tf, targetShape, + TFloat32.class, Layer.Options.create().inputShape(Shape.of(Shape.UNKNOWN_SIZE, Shape.UNKNOWN_SIZE, 2)) ); + + Operand result = + instance.call( + tf.dtypes.cast(tf.constant(inputArrayNN2), TFloat64.class), TFloat64.class); + + Shape expectedShape = Shape.of(batchSize, 12, 1); + assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); + + + } + } +} \ No newline at end of file diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SequentialLayersTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SequentialLayersTest.java new file mode 100644 index 00000000000..11879b85ff2 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SequentialLayersTest.java @@ -0,0 +1,67 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +import java.util.Arrays; +import java.util.List; + +public class SequentialLayersTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + /** Tests executing a thread through sequential layers. 
*/ + @Test + public void testSequentialLayers() { + long seed = 1001L; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] inputArray = { + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f}, + {7.716860f, 3.205337f} + }; + Operand input = tf.constant(inputArray); + List> sequencedLayers = + Arrays.asList( + new Input<>(tf, input, TFloat32.class), + new Dense<>(tf, 3, seed, TFloat32.class), + new Dropout<>(tf, 0.3f, seed, TFloat32.class), + new Flatten<>(tf, TFloat32.class)); + + Operand result = input; + for (Layer layer : sequencedLayers) { + result = layer.call(result, TFloat32.class); + } + session.run(tf.init()); + float[][] expected = + new float[][] { + {0f, 12.723986f, 0f}, + {0f, 0f, 0f}, + {7.344245f, 11.509995f, 7.875926f}, + {8.921803f, 0f, 0f} + }; + + session.evaluate(tf.constant(expected), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java new file mode 100644 index 00000000000..7057951ad84 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java @@ -0,0 +1,186 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +class SubtractTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] x1 = { + { + {0.13570025, 0.55677077, 0.06648757, 0.58405729, 0.61086578}, + {0.18659685, 0.39331301, 0.68069423, 0.09510652, 0.86098578}, + {0.99338463, 0.37543824, 0.74858191, 0.31828287, 0.94056888}, + {0.76202298, 0.05605309, 0.73475366, 0.9313434, 0.48606332} + }, + { + {0.13023652, 0.39567908, 0.89910993, 0.71302943, 0.73722061}, + {0.6212917, 0.62624375, 0.8184835, 0.16864979, 0.96730508}, + {0.39645622, 0.35834793, 0.39924944, 0.90297727, 0.82857399}, + {0.70014157, 0.95498672, 0.6179583, 0.33104554, 0.11978174} + } + }; + double[][][] x2 = { + { + {0.82828211, 0.28889298, 0.7159566, 0.93377237, 0.32654201}, + {0.73234341, 0.17123203, 0.62582661, 0.96272026, 0.58700802}, + {0.12527705, 0.64175689, 0.64915537, 0.80589999, 0.26400939}, + {0.79376476, 0.24171677, 0.0677271, 0.07027092, 0.29195821} + }, + { + {0.56599224, 0.10611362, 0.83370522, 0.72514044, 0.08126704}, + {0.48173969, 0.16509515, 0.21040572, 0.44414272, 0.70656624}, + {0.89191749, 0.73008498, 0.9177326, 0.31897888, 0.56743576}, + {0.36304201, 0.36696309, 0.60722209, 0.79244879, 0.63492784} + } + }; + + double[][][] xsub = { + { + {-0.69258186, 0.26787779, -0.64946903, -0.34971508, 0.28432377}, + {-0.54574656, 0.22208098, 0.05486762, -0.86761374, 0.27397776}, + {0.86810758, -0.26631865, 0.09942654, -0.48761712, 0.67655949}, + {-0.03174178, -0.18566368, 0.66702656, 0.86107248, 0.19410511} + }, + { + {-0.43575572, 0.28956546, 0.06540471, -0.01211101, 0.65595357}, + {0.13955201, 0.4611486, 0.60807778, -0.27549293, 0.26073884}, + {-0.49546127, -0.37173705, -0.51848316, 0.58399839, 0.26113823}, + {0.33709956, 0.58802363, 0.01073621, 
-0.46140325, -0.5151461} + } + }; + + @Test + public void testSubtract() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Subtract instance = new Subtract<>(tf, TFloat64.class); + List> resultList = + instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + + Operand result = resultList.get(0); + + assertArrayEquals(new long[] {Shape.UNKNOWN_SIZE, 4, 5}, result.shape().asArray()); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + session.evaluate(tf.constant(xsub), result, feedMap); + } + } + } + + @Test + public void testSubtractInvalidInputsLength() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Subtract instance = new Subtract<>(tf, TFloat64.class); + + // not used, should throw exception + List> resultList = + instance.call( + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class)), TFloat64.class); + } + }); + } + + @Test + public void testMask() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Input i1 = + new Input<>( + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + Input i2 = + new Input<>( + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4, 5))); + + Subtract instance = new Subtract<>(tf, TFloat64.class); + List> inputs = + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)); + List> mask = Arrays.asList(null, null); + List> result = instance.computeMask(inputs, mask); + assertNull(result); + + Operand x1Op = tf.constant(x1); + Operand x2Op = tf.constant(x2); + mask = Arrays.asList(x1Op, x2Op); + + try (TFloat64 x1Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x1Op).run().get(0); + TFloat64 x2Tensor = + (TFloat64) session.getGraphSession().runner().fetch(x2Op).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(i1.getOutput(TFloat64.class), x1Tensor); + feedMap.put(i2.getOutput(TFloat64.class), x2Tensor); + result = instance.computeMask(inputs, mask); + Boolean[] expected = new Boolean[(int) result.get(0).size()]; + Arrays.fill(expected, true); + session.evaluate(expected, result.get(0), feedMap); + } + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ThresholdedReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ThresholdedReLUTest.java new file 
mode 100644 index 00000000000..bc4ac177989 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ThresholdedReLUTest.java @@ -0,0 +1,56 @@ +package org.tensorflow.framework.layers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat64; + +class ThresholdedReLUTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + double[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } + }; + + @Test + public void testCallThetaPoint5() { + + double[][][] expectedArray = { + { + {2.7085743, 8.254536, 9.754793, 1.1027353}, + {8.698364, 2.2781835, 8.608563, 1.4326588}, + {0.7584583, 5.604635, 7.3599877, 0.} + }, + { + {4.8735523, 9.90222, 5.390144, 2.052634}, + {5.9165273, 0.9186602, 0.9137567, 0.5605333}, + {2.0804656, 8.537634, 6.403787, 5.8328476} + } + }; + + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + float theta = 0.5f; + ThresholdedReLU instance = + new ThresholdedReLU<>( + tf, theta, TFloat64.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + + Operand result = instance.call(tf.constant(inputArray), TFloat64.class); + + session.evaluate(tf.constant(expectedArray), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java new file mode 100644 index 00000000000..daee9a859c7 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java @@ -0,0 +1,72 @@ +package org.tensorflow.framework.layers.impl; + +import org.junit.jupiter.api.Test; +import org.tensorflow.ndarray.Shape; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.*; + +class InputSpecTest { + + @Test + public void testAxis() { + + InputSpec instance = + new InputSpec( + InputSpec.Options.create() + .shape(Shape.of(1, Shape.UNKNOWN_SIZE, 2, 3)) + .axesMap(3, 5L) + .axesMap(2, 2L)); + + assertThrows( + java.lang.IllegalArgumentException.class, + () -> { + InputSpec instance1 = + new InputSpec( + InputSpec.Options.create() + .shape(Shape.of(1, Shape.UNKNOWN_SIZE, 2, 3)) + .axesMap(4, 5L)); + }); + + } + + @Test + public void testDefinedShape() { + Shape expected = Shape.of(1, Shape.UNKNOWN_SIZE, 2, 3); + InputSpec instance = + new InputSpec(InputSpec.Options.create().shape(expected)); + assertArrayEquals(expected.asArray(), instance.toShape().asArray()); + } + + @Test + public void testDefinedRank() { + InputSpec instance = + new InputSpec(InputSpec.Options.create().rank(5)); + long[] dims = new long[5]; + Arrays.fill(dims, Shape.UNKNOWN_SIZE); + assertArrayEquals(dims, instance.toShape().asArray()); + + instance = new InputSpec(InputSpec.Options.create().rank(0)); + dims = new long[0]; + assertArrayEquals(dims, instance.toShape().asArray()); + + instance = new InputSpec(InputSpec.Options.create().rank(3).axesMap(1,3L).axesMap(-1,2L)); + dims = new long[] {Shape.UNKNOWN_SIZE, 3, 2}; + 
assertArrayEquals(dims, instance.toShape().asArray()); + } + + @Test + public void testUndefinedShapes() { + InputSpec instance = + new InputSpec(InputSpec.Options.create().maxRank(5)); + Shape genShaped = instance.toShape(); + assertTrue(genShaped.isUnknown()); + + instance = + new InputSpec(InputSpec.Options.create().minRank(5).maxRank(5)); + genShaped = instance.toShape(); + assertTrue(genShaped.isUnknown()); + + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/TensorDotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/TensorDotTest.java new file mode 100644 index 00000000000..be5d2495c65 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/TensorDotTest.java @@ -0,0 +1,186 @@ +package org.tensorflow.framework.layers.impl; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.exceptions.TFInvalidArgumentException; +import org.tensorflow.framework.op.math.TensorDot; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +class TensorDotTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + float[][] aArray = { + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + }; + + float[][][] bArray = {{{2, 3, 1}}}; + + @Test + public void testInvalidShape() { + for (TestSession.Mode tfMode : tfModes) + assertThrows( + TFInvalidArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] a = new float[][] {{1, 2}, {3, 4}}; + float[][] b = new float[][] {{1, 2}, {3, 4}, {5, 6}}; + + Operand aOp = tf.constant(a); + Operand bOp = tf.constant(b); + + TensorDot.tensordot(tf.scope(), aOp, bOp, new int[] {1, 0}); + } + }); + } + + @Test + public void testInvalidDynamicShape() { + assertThrows( + TFInvalidArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) { + Ops tf = session.getTF(); + + Operand aPH = tf.placeholder(TFloat32.class); + Operand bPH = tf.placeholder(TFloat32.class); + Operand axesPH = tf.placeholder(TInt32.class); + + float[][] a = new float[][] {{1, 2}, {3, 4}}; + float[][] b = new float[][] {{1, 2}, {3, 4}, {5, 6}}; + + Operand aOp = tf.constant(a); + Operand bOp = tf.constant(b); + Operand axesOp = tf.constant(new int[] {1, 0}); + + Operand output = TensorDot.tensordot(tf.scope(), aPH, bPH, axesPH); + + try (TFloat32 aTensor = + (TFloat32) session.getGraphSession().runner().fetch(aOp).run().get(0); + TFloat32 bTensor = + (TFloat32) session.getGraphSession().runner().fetch(bOp).run().get(0); + TInt32 axesTensor = + (TInt32) session.getGraphSession().runner().fetch(axesOp).run().get(0)) { + Map, Tensor> feedMap = new HashMap<>(); + feedMap.put(aPH, aTensor); + feedMap.put(bPH, bTensor); + feedMap.put(axesPH, axesTensor); + session.run(output, feedMap); + } + } + }); + } + + @Test + public void testInvalidAxes() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float[][] a = 
new float[][] {{1, 2}, {3, 4}}; + float[][] b = new float[][] {{1, 2}, {3, 4}}; + + Operand aOp = tf.constant(a); + Operand bOp = tf.constant(b); + assertThrows( + IllegalArgumentException.class, () -> TensorDot.tensordot(tf.scope(), aOp, bOp, -1)); + assertThrows( + IllegalArgumentException.class, () -> TensorDot.tensordot(tf.scope(), aOp, bOp, 3)); + assertThrows( + IllegalArgumentException.class, + () -> TensorDot.tensordot(tf.scope(), aOp, bOp, new int[] {1, 0, 1})); + assertThrows( + Exception.class, () -> TensorDot.tensordot(tf.scope(), aOp, bOp, new int[] {0, 7})); + } + } + + @Test + public void testValidAxis1() { + Shape expectedShape = Shape.of(3, 1, 1); + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand expected = tf.constant(new float[][][] {{{6}}, {{6}}, {{6}}}); + Operand a = tf.constant(aArray); + Operand b = tf.constant(bArray); + Operand result = TensorDot.tensordot(tf.scope(), a, b, new int[] {1, 2}); + assertEquals(expectedShape, result.shape()); + session.evaluate(expected, result); + } + } + + @Test + public void testValidAxis2() { + + Shape expectedShape = Shape.of(3, 1, 1); + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand expected = tf.constant(new float[][][] {{{6}}, {{6}}, {{6}}}); + Operand a = tf.constant(aArray); + Operand b = tf.constant(bArray); + Operand result = TensorDot.tensordot(tf.scope(), a, b, new int[][] {{1}, {2}}); + assertEquals(expectedShape, result.shape()); + session.evaluate(expected, result); + } + } + + @Test + public void testValidAxis3() { + Shape expectedShape = Shape.of(3, 3, 1, 1, 3); + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand expected = + tf.constant( + new float[][][][][] { + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}} + }); + + Operand a = tf.constant(aArray); + Operand b = tf.constant(bArray); + Operand result = TensorDot.tensordot(tf.scope(), a, b, 0); + assertEquals(expectedShape, result.shape()); + session.evaluate(expected, result); + } + } + + @Test + public void testValidAxis4() { + Shape expectedShape = Shape.of(3, 3, 1, 1, 3); + // for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) { + Ops tf = session.getTF(); + Operand expected = + tf.constant( + new float[][][][][] { + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, + {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}} + }); + + Operand a = tf.constant(aArray); + Operand b = tf.constant(bArray); + Operand result = TensorDot.tensordot(tf.scope(), a, b, new int[][] {{}, {}}); + assertEquals(expectedShape, result.shape()); + session.evaluate(expected, result); + } + } +} From 84974495ea98a71427825ed6970b9547be581392 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 26 Apr 2021 17:26:50 -0400 Subject: [PATCH 24/31] Initial checkin --- .../op/nn/SoftmaxCrossEntropyWithLogits.java | 226 ------------------ 1 file changed, 226 deletions(-) delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java diff --git 
a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java deleted file mode 100644 index 7d59941f27a..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java +++ /dev/null @@ -1,226 +0,0 @@ -package org.tensorflow.framework.op.nn; - -import org.tensorflow.Operand; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.core.Concat; -import org.tensorflow.op.core.Constant; -import org.tensorflow.op.core.Range; -import org.tensorflow.op.core.Rank; -import org.tensorflow.op.core.Reshape; -import org.tensorflow.op.core.Slice; -import org.tensorflow.op.dtypes.Cast; -import org.tensorflow.op.linalg.Transpose; -import org.tensorflow.op.math.Sub; -import org.tensorflow.types.TBfloat16; -import org.tensorflow.types.TFloat16; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TNumber; - -import java.util.Arrays; -import java.util.List; - -// @Operator(group = "nn") -public class SoftmaxCrossEntropyWithLogits { - - /** - * Computes softmax cross entropy between logits and labels. - * - *
- * <p>Measures the probability error in discrete classification tasks in which the classes are
- * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is
- * labeled with one and only one label: an image can be a dog or a truck, but not both.
- *
- * <p>NOTE:
- *
- * <p>While the classes are mutually exclusive, their probabilities need not be. All that is
- * required is that each row of labels is a valid probability distribution. If they
- * are not, the computation of the gradient will be incorrect.
- *
- * <p>If using exclusive labels (wherein one and only one class is true at a time),
- * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits}
- *
- * <p>Usage:
- *
-   * <pre>
-   *   Operand<TFloat32> logits =
-   *       tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
-   *   Operand<TFloat32> labels =
-   *       tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
-   *   Operand<TFloat32> output =
-   *       tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
-   *   // output Shape = [2]
-   *   // dataType = FLOAT (1)
-   *   // values { 0.169846, 0.824745 }
-   * </pre>
- *
- * <p>
Backpropagation will happen into both logits and labels. To - * disallow backpropagation into labels, pass label tensors through - * tf.stopGradient before feeding it to this function. - * - * @param scope current scope - * @param labels Each vector along the class dimension should hold a valid probability - * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] - * , each row of labels[i] must be a valid probability distribution. - * @param logits Per-label activations, typically a linear output. These activation energies are - * interpreted as unnormalized log probabilities. - * @param axis The class dimension. -1 is the last dimension. - * @param the data type for the logits and return operand - * @param the data type for the labels - * @return the softmax cross entropy loss. Its type is the same as logits and its - * shape is the same as labels except that it does not have the last dimension of - * labels. - */ - @SuppressWarnings("unchecked") - @Endpoint(name = "softmaxCrossEntropyWithLogits") - public static Operand softmaxCrossEntropyWithLogits( - Scope scope, Operand labels, Operand logits, int axis) { - scope = scope.withSubScope("SoftmaxCrossEntropyWithLogits"); - axis = axis % logits.shape().numDimensions(); - if (axis < 0) { - axis += logits.shape().numDimensions(); - } - - if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { - Operand result = - softmaxCrossEntropyWithLogits( - scope, - Cast.create(scope, labels, TFloat32.class), - Cast.create(scope, logits, TFloat32.class), - axis); - return Cast.create(scope, result, logits.asOutput().type()); - } - - if (logits.asOutput().type() != labels.asOutput().type()) { - return softmaxCrossEntropyWithLogits( - scope, Cast.create(scope, labels, logits.asOutput().type()), logits, axis); - } - - Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); - Shape shape = logits.shape(); - - // Move the dim to the end if dim is not the last dimension. - if (axis != -1 && axis != logits.shape().numDimensions() - 1) { - logits = moveDimToEnd(scope, logits, axis, inputRank); - labels = moveDimToEnd(scope, labels, axis, inputRank); - } - - Operand tLabels; - if (labels.type() != logits.type()) { - tLabels = Cast.create(scope, labels, logits.type()); - } else { - // Unchecked warning checked in if statement. - tLabels = (Operand) labels; - } - - Shape inputShape = logits.shape(); - logits = flattenOuterDims(scope, logits); - tLabels = flattenOuterDims(scope, tLabels); - - org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits smax = - org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits.create(scope, logits, tLabels); - /* cannot use generic on cost, because cost may be recast later. 
*/ - Operand cost = smax.loss(); - Operand outputShape = - Slice.create( - scope, - Constant.tensorOf(scope, inputShape), - Constant.arrayOf(scope, 0L), - Constant.arrayOf(scope, inputShape.numDimensions() - 1L)); - cost = Reshape.create(scope, cost, outputShape); - if (scope.env().isGraph() && !shape.hasUnknownDimension()) { - long[] array = shape.asArray(); - if (array == null) { - array = new long[0]; - } - long[] newArray = new long[array.length - 1]; - if (axis < 0) { - axis = shape.numDimensions() + axis; - } - for (int i = 0; i < axis; i++) { - newArray[i] = shape.size(i); - } - for (int i = axis + 1; i < shape.numDimensions(); i++) { - newArray[i - 1] = shape.size(i); - } - cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newArray)); - } - - return cost; - } - - /** - * Flattens logits' outer dimensions and keep its last dimension. - * - * @param scope the TensorFlow scope - * @param logits the logits - * @param the type of logits - * @return the flattened logits - */ - private static Operand flattenOuterDims(Scope scope, Operand logits) { - Operand one = Constant.scalarOf(scope, 1L); - - Shape shape = logits.shape(); - int ndims = shape.numDimensions(); - if (!shape.hasUnknownDimension()) { - long product = 1L; - boolean productValid = true; - for (int i = ndims - 2; i >= 0; i--) { - long d = shape.size(i); - if (d == Shape.UNKNOWN_SIZE) { - productValid = false; - break; - } - product *= d; - } - if (productValid) { - return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.size(-1))); - } - } - - Operand rank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); - Operand rankMinusOne = Sub.create(scope, rank, one); - - Operand lastDimSize = - Slice.create( - scope, - org.tensorflow.op.core.Shape.create(scope, logits, TInt64.class), - rankMinusOne, - one); - Operand concat = - Concat.create( - scope, - Arrays.asList(Constant.arrayOf(scope, -1L), lastDimSize), - Constant.scalarOf(scope, 0)); - return Reshape.create(scope, logits, concat); - } - - /** - * Move the dim to the end if dimIndex is not the last dimension. - * - * @param scope The TensorFlow Scope - * @param input the input to reshape - * @param dimIndex the index to move - * @param rank the number of Dimensions in the tensor - * @param the data type of the tensor. 
- * @param the data type of the rank - * @return the reshaped input - */ - private static Operand moveDimToEnd( - Scope scope, Operand input, int dimIndex, Operand rank) { - Class rankType = rank.asOutput().type(); - Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankType); - List> concatList = - Arrays.asList( - Range.create( - scope, Cast.create(scope, Constant.scalarOf(scope, dimIndex), rankType), one, one), - Range.create( - scope, - Cast.create(scope, Constant.scalarOf(scope, dimIndex + 1), rankType), - one, - one)); - return Transpose.create( - scope, input, Concat.create(scope, concatList, Constant.scalarOf(scope, 0))); - } -} From b95c7503b4aa909bd2301f0558bda7b314380c7c Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 26 Apr 2021 17:27:31 -0400 Subject: [PATCH 25/31] Initial checkin --- .../framework/op/linalg/MatMul.java | 268 +++++++ .../framework/op/math/ReduceLogSumExp.java | 142 ++++ .../framework/op/math/TensorDot.java | 719 ++++++++++++++++++ .../tensorflow/framework/op/nn/Softmax.java | 119 +++ 4 files changed, 1248 insertions(+) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ReduceLogSumExp.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java new file mode 100644 index 00000000000..c4843db7266 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java @@ -0,0 +1,268 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op.linalg; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.math.Conj; +import org.tensorflow.op.sparse.SparseMatMul; +import org.tensorflow.op.train.BatchMatMul; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; + +/** + * Higher level operation for matMul that does logic before calling the low leve MatMul operations. + */ +/* TODO this is a higher level of abstraction from the low level matmul +it is defined as tf.matmul and tf.linalg.matmul in python. +Should this be defined here? */ +@Operator(group = "linalg") +public class MatMul { + + /** + * Multiplies matrix a by matrix b, producing a * b + * . + * + *
+   * <p>The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2
+   * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions
+   * specify matching batch size.
+   *
+   * <p>Both matrices must be of the same type. The supported types are: TFloat16,
+   * TFloat32, TFloat64, TInt32.
+   *
+   * <p>Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by
+   * setting one of the corresponding flags to true. These are false by
+   * default.
+   *
+   * <p>A simple 2-D tensor matrix multiplication:
+   *
+   * <pre>
+   *  Operand<TFloat32> a = tf.constant(new float[][] {{1, 2, 3}, {4, 5, 6}});
+   *  Operand<TFloat32> b = tf.constant(new float[][] {{7, 8}, {9, 10}, {11, 12}});
+   *  Operand<TFloat32> c = FMWLinalgOps.matmul(tf.scope(), a, b);
+   *
+   * </pre>
+   *
+   * <p>Note: This is matrix product, not element-wise product.
+   *
+   * @param scope the Tensorflow scope
+   * @param a an Operand of type TFloat16, TFloat32, TFloat64,
+   *     or TInt32, with a rank > 1
+   * @param b an Operand with same type and rank as a.
+   * @param <T> the data type of the Operands
+   * @return an Operand of the same type as a and b where each inner-most
+   * matrix is the product of the corresponding matrices in a and b.
+   * This is the matrix product not an element-wise product.
+   * @throws java.lang.IllegalArgumentException If transposeA and adjointA,
+   * or transposeB and adjointB are both set to `true`.
+   */
+  @Endpoint(name = "matmul")
+  public static Operand matmul(Scope scope, Operand a, Operand b) {
+    return matmul(scope, a, b, false, false, false, false, false, false);
+  }
+
+  /**
+   * Multiplies matrix a by matrix b, producing a * b.
+   * <p>
+   * The inputs must, following any transpositions, be tensors of rank >= 2
+   * where the inner 2 dimensions specify valid matrix multiplication
+   * dimensions, and any further outer dimensions specify matching batch size.
+   * <p>
+   * Both matrices must be of the same type. The supported types are:
+   * TFloat16, TFloat32, TFloat64, TInt32.
+   * <p>
+   * Either matrix can be transposed or adjointed (conjugated and transposed)
+   * on the fly by setting one of the corresponding flags to true. These are
+   * false by default.
+   *
+   * <p>Note: This is matrix product, not element-wise product.
+   * <p>
+   * A simple 2-D tensor matrix multiplication:
+   * <pre>
+   * //TODO
+   * </pre>
+   *
+   * @param scope the Tensorflow scope
+   * @param a an Operand of type TFloat16, TFloat32, TFloat64, TInt32,
+   * with a rank > 1
+   * @param b an Operand with same type and rank as a.
+   * @param transposeA If `true`, a is transposed before multiplication.
+   * @param transposeB If `true`, b is transposed before multiplication
+   * @param <T> the data type of the Operands
+   * @return an Operand of the same type as a and b where each
+   * inner-most matrix is the product of the corresponding matrices in a and
+   * b. This is the
+   * matrix product not an element-wise product.
+   * @throws java.lang.IllegalArgumentException If transposeA and
+   * adjointA, or transposeB and adjointB are both set to `true`.
+   */
+  @Endpoint(name = "matmul")
+  public static  Operand matmul(
+      Scope scope, Operand a, Operand b, boolean transposeA, boolean transposeB) {
+    return matmul(scope, a, b, transposeA, transposeB, false, false, false, false);
+  }
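+
+  // Editor's note - a minimal sketch, not part of the original patch: the transpose flags
+  // above can be exercised as below, assuming the generic signatures (e.g. Operand<TFloat32>)
+  // that the diff extraction stripped:
+  //   Operand<TFloat32> a = tf.constant(new float[][] {{1, 2, 3}, {4, 5, 6}});    // 2x3
+  //   Operand<TFloat32> b = tf.constant(new float[][] {{7, 8, 9}, {10, 11, 12}}); // 2x3
+  //   Operand<TFloat32> c = matmul(tf.scope(), a, b, true, false); // transposes a -> 3x3 result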
+
+  /**
+   * Multiplies matrix a by matrix b, producing a * b.
+   * <p>
+   * The inputs must, following any transpositions, be tensors of rank >= 2
+   * where the inner 2 dimensions specify valid matrix multiplication
+   * dimensions, and any further outer dimensions specify matching batch size.
+   * <p>
+   * Both matrices must be of the same type. The supported types are:
+   * TFloat16, TFloat32, TFloat64, TInt32.
+   * <p>
+   * Either matrix can be transposed or adjointed (conjugated and transposed)
+   * on the fly by setting one of the corresponding flags to true. These are
+   * false by default.
+   *
+   * <p>Note: This is matrix product, not element-wise product.
+   * <p>
+   * A simple 2-D tensor matrix multiplication:
+   * <pre>
+   * //TODO
+   * </pre>
+   *
+   * @param scope the Tensorflow scope
+   * @param a an Operand of type TFloat16, TFloat32, TFloat64, TInt32,
+   * with a rank > 1
+   * @param b an Operand with same type and rank as a.
+   * @param transposeA If true, a is transposed before multiplication.
+   * @param transposeB If true, b is transposed before multiplication
+   * @param adjointA If true, a is conjugated and transposed before
+   * multiplication.
+   * @param adjointB If true, b is conjugated and transposed before
+   * multiplication.
+   * @param aIsSparse If true, a is treated as a sparse matrix. Notice, this
+   *       does not support org.tensorflow.framework.utils.SparseTensor, it just makes optimizations
+   *       that assume most values in a are zero.
+   * @param bIsSparse If true, b is treated as a sparse matrix. Notice, this
+   *       does not support org.tensorflow.framework.utils.SparseTensor, it just makes optimizations
+   *       that assume most values in b are zero.
+   * @param <T> the data type of the Operands
+   * @return an Operand of the same type as a and b where each
+   * inner-most matrix is the product of the corresponding matrices in a and
+   * b. This is the
+   * matrix product not an element-wise product.
+   * @throws java.lang.IllegalArgumentException If transposeA and
+   * adjointA, or transposeB and adjointB are both set to `true`.
+   */
+  @SuppressWarnings("unchecked")
+  @Endpoint(name = "matmul")
+  public static  Operand matmul(
+      Scope scope,
+      Operand a,
+      Operand b,
+      boolean transposeA,
+      boolean transposeB,
+      boolean adjointA,
+      boolean adjointB,
+      boolean aIsSparse,
+      boolean bIsSparse) {
+    scope = scope.withSubScope("MatMul");
+    if (transposeA && adjointA)
+      throw new IllegalArgumentException("Only one of transposeA and adjointA can be true.");
+    if (transposeB && adjointB)
+      throw new IllegalArgumentException("Only one of transposeB and adjointB can be true.");
+    if (!(TFloating.class.isAssignableFrom(a.type()) || a.type().equals(TInt32.class)))
+      throw new IllegalArgumentException(
+          String.format(
+              "Operand 'a' must be of type 'TBfloat16','TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
+              a.type().getSimpleName()));
+    if (!(TFloating.class.isAssignableFrom(b.type()) || b.type().equals(TInt32.class)))
+      throw new IllegalArgumentException(
+          String.format(
+              "Operand 'b' must be of type 'TBfloat16', 'TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
+              b.type().getSimpleName()));
+
+    Shape aShape = a.shape();
+    Shape bShape = b.shape();
+    if (aShape.numDimensions() != bShape.numDimensions())
+      throw new IllegalArgumentException(
+          String.format(
+              "Parameters 'a' and 'b' must have the same rank: found a rank = %d, b rank = %d",
+              aShape.numDimensions(), bShape.numDimensions()));
+    boolean outputMayHaveNonEmptyBatchShape =
+        aShape.numDimensions() == Shape.UNKNOWN_SIZE
+            || aShape.numDimensions() > 2
+            || bShape.numDimensions() == Shape.UNKNOWN_SIZE;
+
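+    // A rank above 2, or an unknown rank, implies batch dimensions that the plain MatMul
+    // kernel cannot handle, so those cases are routed to BatchMatMul below.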
+    if ((!aIsSparse && !bIsSparse) && outputMayHaveNonEmptyBatchShape) {
+      // BatchMatmul does not support transpose, so we conjugate the matrix and
+      // use adjoint instead. Conj() is a noop for real matrices.
+      if (transposeA) {
+        a = Conj.create(scope, a);
+        adjointA = true;
+      }
+      if (transposeB) {
+        b = Conj.create(scope, b);
+        adjointB = true;
+      }
+      Operand bT = a.type().equals(b.type()) ? (Operand) b : Cast.create(scope, b, a.type());
+      return BatchMatMul.create(
+          scope, a, bT, BatchMatMul.adjX(adjointA), BatchMatMul.adjY(adjointB));
+    }
+
+    // Neither matmul nor sparse_matmul support adjoint, so we conjugate
+    // the matrix and use transpose instead. Conj() is a noop for real
+    // matrices.
+    if (adjointA) {
+      a = Conj.create(scope, a);
+      transposeA = true;
+    }
+    if (adjointB) {
+      b = Conj.create(scope, b);
+      transposeB = true;
+    }
+
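+    // SparseMatMul is only defined for TBfloat16/TFloat32 inputs, and it is also the only
+    // kernel used here that accepts mixed TBfloat16 operand types, so both cases route to it.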
+    boolean useSparseMatmul = false;
+    if (aIsSparse || bIsSparse) {
+      useSparseMatmul =
+          (a.type().equals(TBfloat16.class) || a.type().equals(TFloat32.class))
+              && (b.type().equals(TBfloat16.class) || b.type().equals(TFloat32.class));
+    }
+    if ((a.type().equals(TBfloat16.class) || b.type().equals(TBfloat16.class))
+        && !a.type().equals(b.type())) useSparseMatmul = true;
+
+    if (useSparseMatmul) {
+      Operand result =
+          SparseMatMul.create(
+              scope,
+              a,
+              b,
+              SparseMatMul.transposeA(transposeA),
+              SparseMatMul.transposeB(transposeB),
+              SparseMatMul.aIsSparse(aIsSparse),
+              SparseMatMul.bIsSparse(bIsSparse));
+      if (a.type().equals(TFloat32.class)) return (Operand) result;
+      else return Cast.create(scope, result, a.type());
+    }
+
+    // need to cast b to the same type as a, if necessary
+    Operand bT = a.type().equals(b.type()) ? (Operand) b : Cast.create(scope, b, a.type());
+
+    return org.tensorflow.op.linalg.MatMul.create(
+        scope, a, bT,
+            org.tensorflow.op.linalg.MatMul.transposeA(transposeA),
+            org.tensorflow.op.linalg.MatMul.transposeB(transposeB));
+  }
+}
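[Editor's example - a minimal sketch, not part of the patch: driving the new framework matmul
from an eager session. It assumes the generic signatures (e.g. Operand<TFloat32>) that the
extraction stripped from the diff above; the class name MatMulExample is hypothetical.]

    import org.tensorflow.EagerSession;
    import org.tensorflow.Operand;
    import org.tensorflow.framework.op.linalg.MatMul;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat32;

    public class MatMulExample {
      public static void main(String[] args) {
        try (EagerSession env = EagerSession.create()) {
          Ops tf = Ops.create(env);
          Operand<TFloat32> a = tf.constant(new float[][] {{1, 2, 3}, {4, 5, 6}});      // 2x3
          Operand<TFloat32> b = tf.constant(new float[][] {{7, 8}, {9, 10}, {11, 12}}); // 3x2
          // 2x3 times 3x2 yields the 2x2 matrix product (not an element-wise product)
          Operand<TFloat32> c = MatMul.matmul(tf.scope(), a, b);
          System.out.println(c.asTensor().getFloat(0, 0)); // 1*7 + 2*9 + 3*11 = 58.0
        }
      }
    }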
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ReduceLogSumExp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ReduceLogSumExp.java
new file mode 100644
index 00000000000..71678344976
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ReduceLogSumExp.java
@@ -0,0 +1,142 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.op.math;
+
+import org.tensorflow.Operand;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Scope;
+import org.tensorflow.op.annotation.Endpoint;
+import org.tensorflow.op.annotation.Operator;
+import org.tensorflow.op.core.*;
+import org.tensorflow.op.dtypes.Cast;
+import org.tensorflow.op.math.*;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.TString;
+import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+@Operator(group = "math")
+public class ReduceLogSumExp {
+
+
+
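+  // Editor's note (not in the original patch): the reduction below uses the standard
+  // log-sum-exp trick, logsumexp(x) = max(x) + log(sum(exp(x - max(x)))), which keeps every
+  // exp() argument non-positive so the sum can neither overflow nor lose all precision.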
+  // TODO this method is defined in tf.math.reduce_logsumexp in TF Python.
+  /**
+   * Computes log(sum(exp(elements across dimensions of a tensor))). Reduces {@code input_tensor}
+   * along the dimensions given in {@code axes}.
+   *
+   * <p>
Reduces `{@code input} along the dimensions given in {@code axes}. Unless {@code keepdims} + * is true, the rank of the tensor is reduced by 1 for each of the entries in {@code axes}, which + * must be unique. If {@code keepdims} is true, the reduced dimensions are retained with length 1. + * If {@code axes} has no entries, all dimensions are reduced, and a tensor with a single element + * is returned. This function is more numerically stable than {@code log(sum(exp(input)))}. It + * avoids overflows caused by taking the exp of large inputs and underflows caused by taking the + * log of small inputs. + * + * @param input The tensor to reduce. + * @param axes The dimensions to reduce. If null, reduces all dimensions. Must be in the range + * {@link [-rank(input_tensor), rank(input_tensor)]}. + * @param keepDims If true, retains reduced dimensions with length 1. + * @return The reduced tensor. + */ + @Endpoint(name = "reduceLogSumExp") + public static Operand reduceLogSumExp( + Scope scope, Operand input, int[] axes, boolean keepDims) { + Operand reduceDims = reductionDims(scope, input, axes); + Operand rawMax = reduceMaxWithDims(scope, input, axes, keepDims, reduceDims); + Operand myMax = + StopGradient.create( + scope, + Select.create( + scope, IsFinite.create(scope, rawMax), rawMax, ZerosLike.create(scope, rawMax))); + + Operand result = + Log.create( + scope, + reduceSumWithDims( + scope, + Exp.create(scope, Sub.create(scope, input, myMax)), + axes, + keepDims, + reduceDims)); + + if (!keepDims) { + myMax = Reshape.create(scope, myMax, org.tensorflow.op.core.Shape.create(scope, result)); + } + result = Add.create(scope, result, myMax); + return mayReduceToScalar(scope, keepDims, axes, result); + } + + private static Operand reduceSumWithDims( + Scope scope, Operand input, int[] axes, boolean keepDims, Operand dims) { + return mayReduceToScalar( + scope, keepDims, axes, ReduceSum.create(scope, input, dims, ReduceSum.keepDims(keepDims))); + } + + private static Operand reduceMaxWithDims( + Scope scope, Operand input, int[] axes, boolean keepDims, Operand dims) { + return mayReduceToScalar( + scope, keepDims, axes, ReduceMax.create(scope, input, dims, ReduceMax.keepDims(keepDims))); + } + + /** + * Sets a reduction's output shape to be a scalar if possible. + * + * @return the operand, possibly reduced to a scalar. + */ + private static Operand mayReduceToScalar( + Scope scope, boolean keepDims, int[] axes, Operand output) { + + if ((output.shape().numDimensions() == Shape.UNKNOWN_SIZE + || output.shape().hasUnknownDimension()) + && !keepDims + && axes == null) { + return Reshape.create(scope, output, Constant.tensorOf(scope, Shape.scalar())); + } else { + return output; + } + } + + /** + * Reduce dimensions based on axis + * + * @param scope the TensorFlow scope + * @param input the input + * @param axes he dimensions to reduce, may be null + * @return the dimensions to be reduced. 
+ */ + private static Operand reductionDims( + Scope scope, Operand input, int[] axes) { + if (axes != null) { + return Constant.vectorOf(scope, axes); + } + long rank = input.shape().numDimensions(); + if (rank != Shape.UNKNOWN_SIZE) { + int[] dims = new int[(int) rank]; + for (int i = 0; i < rank; i++) { + dims[i] = i; + } + return Constant.vectorOf(scope, dims); + + } else { + return Range.create( + scope, + Constant.scalarOf(scope, 0), + Rank.create(scope, input), + Constant.scalarOf(scope, 1)); + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java new file mode 100644 index 00000000000..28c430369a2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java @@ -0,0 +1,719 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op.math; + +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.framework.op.linalg.MatMul; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.*; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.linalg.Transpose; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.GreaterEqual; +import org.tensorflow.op.math.Less; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; + +// NOTE: This is implemented as tf.tensordot() in Python, and the code is in +// python/ops/math_ops.py. Is this where we want to keep this? +/** + * Implements the TensorDot (Tensor contraction of a and b along specified axes and outer product) + * operation. + */ +@Operator(group = "math") +public abstract class TensorDot { + + /** + * Transpose and reshape the input for contraction op. + * + *

This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul` using + * `array_ops.transpose` and `array_ops.reshape`. The method takes a tensor and performs the + * correct transpose and reshape operation for a given set of indices. It returns the reshaped + * tensor as well as a list of indices necessary to reshape the tensor again after matrix + * multiplication. + * + * @param the type of Operand + * @param scope the TensorFlow scope + * @param a the Tensor + * @param axis unique indices specifying valid axes of `a`. + * @param flipped whether to flip the dimensions or not + * @return A tuple (reshapedA, freeDims, freeDimsStatic) where reshapedA is a reshaped to allow + * contraction via matmul, freeDims` is a TInt32 Operand, depending on whether the shape of a + * is fully specified, and freeDimsStatic is either a list of integers and null values, or + * None, representing the inferred static shape of the free dimensions + */ + private static Object[] tensordotReshape( + Scope scope, Operand a, Operand axis, boolean flipped) { + Shape aShape = a.shape(); + + if (!aShape.hasUnknownDimension()) { // calculate using static values + long[] aShapeDims = aShape.asArray(); + if (aShapeDims == null) aShapeDims = new long[0]; + long[] aDimsIndex = new long[aShapeDims.length]; + for (int i = 0; i < aDimsIndex.length; i++) aDimsIndex[i] = i; + + // get int array from axis Operand + int[] iAxes = getIntArray(scope, axis); + // Convert negative axes to positive + for (int i = 0; i < iAxes.length; i++) + iAxes[i] = iAxes[i] >= 0 ? iAxes[i] : Math.floorMod(iAxes[i], iAxes.length); + + // convert integer axis to long axis + long[] lAxes = Arrays.stream(iAxes).mapToLong(i -> i).toArray(); + + // create list of the axes, dims, and free axes + List axesList = Arrays.stream(lAxes).boxed().collect(Collectors.toList()); + List freeList = Arrays.stream(aDimsIndex).boxed().collect(Collectors.toList()); + freeList.removeAll(axesList); + + // create array of free dims + long[] free = freeList.stream().mapToLong(i -> i).toArray(); + long[] freeDims = new long[free.length]; + for (int i = 0; i < free.length; i++) freeDims[i] = aShapeDims[(int) free[i]]; + + // Calculate the free dim by doing a reduce prod + long prodFree = 1; + for (long i : freeDims) { + prodFree *= i; + } + + // calculate the used dims by doing a reduce prod + long prodAxis = 1; + for (long i : lAxes) { + prodAxis *= aShapeDims[(int) i]; + } + + // setup the permutations array for the transpose + long[] perm = new long[freeDims.length + lAxes.length]; + Shape newShape; + if (flipped) { + System.arraycopy(lAxes, 0, perm, 0, lAxes.length); + System.arraycopy(free, 0, perm, lAxes.length, free.length); + newShape = Shape.of(prodAxis, prodFree); + } else { + System.arraycopy(free, 0, perm, 0, free.length); + System.arraycopy(lAxes, 0, perm, freeDims.length, lAxes.length); + newShape = Shape.of(prodFree, prodAxis); + } + + Operand aTrans; + long[] arrange = new long[lAxes.length]; + for (int i = 0; i < arrange.length; i++) arrange[i] = i; + + // if the permutations is not equals to the natural order of the dims, then do a transpose + if (!Arrays.equals(perm, arrange)) { + aTrans = Transpose.create(scope, a, Constant.vectorOf(scope, perm)); + } else { + aTrans = a; + } + + // reshape the final result to the new Shape, if necessary + Operand aReshaped = + aTrans.asOutput().shape().equals(newShape) + ? 
aTrans + : Reshape.create(scope, aTrans, Constant.vectorOf(scope, newShape.asArray())); + // return a tuple for the reshaped Operand, and Operand for the free dimensions, and a long + // array for the free dimensions + return new Object[] {aReshaped, Constant.vectorOf(scope, freeDims), freeDims}; + + } else { // calculate dynamically + + long[] freeDimsStatic = null; + Operand one = Constant.scalarOf(scope, 1); + Operand minusOne = Constant.scalarOf(scope, -1); + Operand zero = Constant.scalarOf(scope, 0); + org.tensorflow.op.core.Shape tShape = org.tensorflow.op.core.Shape.create(scope, a); + Operand axesT; + Operand freeT; + if (aShape.numDimensions() + != Shape.UNKNOWN_SIZE) { // we know the rank, but there are unknown dimensions + long[] aShapeDims = aShape.asArray(); + if (aShapeDims == null) aShapeDims = new long[0]; + + // get int array from axis Operand + int[] iAxes = getIntArray(scope, axis); + // Convert negative axes to positive + for (int i = 0; i < iAxes.length; i++) + iAxes[i] = iAxes[i] >= 0 ? iAxes[i] : Math.floorMod(iAxes[i], iAxes.length); + + // convert integer axis to long axis + long[] lAxes = Arrays.stream(iAxes).mapToLong(i -> i).toArray(); + + // create list of the axes, dims, and free axes + List axesList = Arrays.stream(lAxes).boxed().collect(Collectors.toList()); + List dimsList = Arrays.stream(aShapeDims).boxed().collect(Collectors.toList()); + List freeList = new ArrayList<>(axesList); + freeList.removeAll(dimsList); + + // create array of free dims + long[] freeDims = freeList.stream().mapToLong(i -> i).toArray(); + freeDimsStatic = freeDims; + + axesT = Constant.vectorOf(scope, iAxes); + freeT = Cast.create(scope, Constant.vectorOf(scope, freeDims), TInt32.class); + + } else { // we don't know the rank yet + Rank rank = Rank.create(scope, a); + + // convert axis to positive + axesT = + Select.create( + scope, + GreaterEqual.create(scope, axis, Constant.scalarOf(scope, 0)), + axis, + Add.create(scope, axis, rank)); + + SetDiff1d diff = + SetDiff1d.create( + scope, Range.create(scope, Constant.scalarOf(scope, 0), rank, one), axesT); + freeT = diff.out(); + } + Operand freeDims = Gather.create(scope, tShape, freeT, zero); + Operand axesDims = Gather.create(scope, tShape, axesT, zero); + Operand prodFreeDims = ReduceProd.create(scope, freeDims, minusOne); + Operand prodAxesDims = ReduceProd.create(scope, axesDims, minusOne); + Operand perm; + Operand newShape; + if (flipped) { + perm = Concat.create(scope, Arrays.asList(axesT, freeT), zero); + newShape = Stack.create(scope, Arrays.asList(prodAxesDims, prodFreeDims)); + } else { + perm = Concat.create(scope, Arrays.asList(freeT, axesT), zero); + newShape = Stack.create(scope, Arrays.asList(prodFreeDims, prodAxesDims)); + } + Operand aReshaped = Reshape.create(scope, Transpose.create(scope, a, perm), newShape); + return new Object[] {aReshaped, freeDims, freeDimsStatic}; + } + } + + /** + * Gets an int array from an Operand<TInt32> operand. 
+ * + * @param scope the TensorFlow scope + * @param axes the Operand to fetch the values + * @return the int array from an Operand<TInt32> + */ + private static int[] getIntArray(Scope scope, Operand axes) { + List result = new ArrayList<>(); + if (scope.env().isEager()) { + axes.asTensor().scalars().forEach(s -> result.add(s.getInt())); + } else { + try (Session session = new Session((Graph) scope.env()); + TInt32 tensor = (TInt32) session.runner().fetch(axes).run().get(0)) { + tensor.scalars().forEach(s -> result.add(s.getInt())); + } + } + return result.stream().mapToInt(i -> i).toArray(); + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. + * + * @param scope the scope + * @param a the Operand to analyze + * @param axis the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings("unchecked") + private static Operand[] tensordotAxes( + Scope scope, Operand a, int axis) { + Shape aShape = a.asOutput().shape(); + if (axis < 0) { + throw new IllegalArgumentException("'axis' must be at least 0."); + } + int rank = aShape.numDimensions(); + Operand[] result = new Operand[2]; + if (rank != Shape.UNKNOWN_SIZE) { + if (axis > rank) { + throw new IllegalArgumentException( + String.format( + "'axis' must not be larger than the number of dimensions of tensor %s.", rank)); + } + int min = rank - axis; + int postRange = rank - min; + int[] postAxis = new int[postRange]; + for (int i = 0; i < postRange; i++) postAxis[i] = i + min; + + int[] preAxis = new int[axis]; + for (int i = 0; i < axis; i++) preAxis[i] = i; + + result[0] = Constant.vectorOf(scope, postAxis); + result[1] = Constant.vectorOf(scope, preAxis); + } else { + Rank rankT = Rank.create(scope, a); + Constant axisT = Constant.scalarOf(scope, axis); + Constant one = Constant.scalarOf(scope, 1); + Constant zero = Constant.scalarOf(scope, 0); + AssertThat assertion = + AssertThat.create( + scope, + Less.create(scope, axisT, rankT), + Arrays.asList( + Constant.scalarOf( + scope, "'axes' must not be larger than the number of dimensions of tensor "), + rankT)); + Scope scope1 = scope.withControlDependencies(Collections.singletonList(assertion)); + result[0] = Range.create(scope1, Sub.create(scope, rankT, axisT), rankT, one); + result[1] = Range.create(scope1, zero, axisT, one); + } + return result; + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. + * + * @param scope the scope + * @param a the Operand to analyze + * @param axes the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings({"unchecked", "unused"}) + private static Operand[] tensordotAxes( + Scope scope, Operand a, int[] axes) { + if (axes.length != 2) + throw new IllegalArgumentException( + "'axes' must have length 1 or 2, provided with " + axes.length); + int[] aAxis = new int[] {axes[0]}; + int[] bAxis = new int[] {axes[1]}; + Operand[] result = new Operand[2]; + result[0] = Constant.vectorOf(scope, aAxis); + result[1] = Constant.vectorOf(scope, bAxis); + + return result; + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. 
+ * + * @param scope the scope + * @param a the Operand to analyze + * @param axes the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings({"unchecked", "unused"}) + private static Operand[] tensordotAxes( + Scope scope, Operand a, int[][] axes) { + if (axes.length != 2) + throw new IllegalArgumentException( + "'axes' must have length 1 or 2, provided with " + axes.length); + int[] aAxis = axes[0]; + int[] bAxis = axes[1]; + if (aAxis.length != bAxis.length) + throw new IllegalArgumentException( + String.format( + "Different number of contraction axes 'a' and 'b', %d != %d", + aAxis.length, bAxis.length)); + Operand[] result = new Operand[2]; + result[0] = Constant.vectorOf(scope, aAxis); + result[1] = Constant.vectorOf(scope, bAxis); + return result; + } + + /** + * Generates two sets of contraction axes for the two tensor arguments. + * + * @param scope the scope + * @param a the Operand to analyze + * @param axes the axes + * @param the data type for the Operand + * @return the contraction axes + */ + @SuppressWarnings({"unchecked", "unused"}) + private static Operand[] tensordotAxes( + Scope scope, Operand a, Operand axes) { + + Constant one = Constant.scalarOf(scope, 1); + Constant zero = Constant.scalarOf(scope, 0); + Operand[] result = new Operand[2]; + result[0] = + Slice.create( + scope, + axes, + Cast.create(scope, zero, TInt32.class), + Cast.create(scope, one, TInt32.class)); + result[1] = + Slice.create( + scope, + axes, + Cast.create(scope, one, TInt32.class), + Cast.create(scope, one, TInt32.class)); + return result; + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *

+ * Tensordot (also known as tensor contraction) sums the product of elements + * from a and b over the indices specified by + * a_axes and b_axes. The lists + * a_axes and b_axes specify those pairs of axes + * along which to contract the tensors. The axis a_axes[i] of + * a must have the same dimension as axis + * b_axes[i] of b for all i in + * range(0, len(a_axes)). The lists + * a_axes and b_axes must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally, the outer product is supported by passing + * axes=0. + *


+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *
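+ *
+ * For instance (an illustrative sketch; the constants are hypothetical, and {@code tf} and
+ * {@code scope} are assumed to be a valid {@code Ops} instance and its {@code Scope}):
+ * <pre>
+ *  Operand a = tf.constant(new float[][] {{1, 2}, {3, 4}});
+ *  Operand b = tf.constant(new float[][] {{5, 6}, {7, 8}});
+ *  // axis = 1 contracts the last axis of a with the first axis of b,
+ *  // so c is the ordinary matrix product a x b with shape (2, 2)
+ *  Operand c = TensorDot.tensordot(scope, a, b, 1);
+ * </pre>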

+ * Example 2: When a and b are matrices (order 2), + * the case + * axes = [[1], [0]] is equivalent to matrix multiplication. + *


+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_{ijk} and b_{lmn} + * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * c_{jklm} whose entry corresponding to the indices + * (j,k,l,m) is given by: + *


+ * c_{jklm} = Σ_i a_{ijk} + * b_{lmi}. + *


+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *
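+ *
+ * A shape sketch of that rule (illustrative; the shapes are hypothetical):
+ * <pre>
+ *  // a has shape (2, 3, 4) and b has shape (4, 5, 6); axis = 1 contracts
+ *  // a's last axis with b's first axis, so order(c) = 3 + 3 - 2 = 4
+ *  // and c has shape (2, 3, 5, 6)
+ *  Operand c = TensorDot.tensordot(scope, a, b, 1);
+ * </pre>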

+ * + * @param scope the TensorFlow Scope + * @param a `Operand` of type `float32` or `float64`. + * @param b `Operand` with the same type as `a`. + * @param axis sum over the last N axes of a and the + * first N axes of b in order. If `axis=0`, computes the outer + * product between `a` and `b`. + * @param <T> the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return An `Operand` with the same type as `a`. + * @throws IllegalArgumentException if a is not a float32 or float64 data type, or if a and b are not the same data type + */ + @Endpoint(name = "tensordot") + public static Operand tensordot( + Scope scope, Operand a, Operand b, int axis) { + + Operand[] abAxis = tensordotAxes(scope, a, axis); + Operand aAxis = abAxis[0]; + Operand bAxis = abAxis[1]; + return tensordot(scope, a, b, aAxis, bAxis); + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *


+ * Tensordot (also known as tensor contraction) sums the product of elements + * from a and b over the indices specified by + * a_axes and b_axes. The lists + * a_axes and b_axes specify those pairs of axes + * along which to contract the tensors. The axis a_axes[i] of + * a must have the same dimension as axis + * b_axes[i] of b for all i in + * range(0, len(a_axes)). The lists + * a_axes and b_axes must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally, the outer product is supported by passing + * axes=0. + *


+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *

+ * Example 2: When a and b are matrices (order 2), + * the case + * axes = [[1], [0]] is equivalent to matrix multiplication. + *


+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_{ijk} and b_{lmn} + * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * c_{jklm} whose entry corresponding to the indices + * (j,k,l,m) is given by: + *


+ * c_{jklm} = Σ_i a_{ijk} + * b_{lmi}. + *


+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *
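+ *
+ * A minimal sketch for this overload (illustrative; a and b are placeholder order-2
+ * operands, and the length-2 axes vector is an assumed layout: its first element names
+ * the contraction axis of a, its second the contraction axis of b):
+ * <pre>
+ *  // contract axis 1 of a with axis 0 of b, i.e. matrix multiplication
+ *  Operand axes = Constant.vectorOf(scope, new int[] {1, 0});
+ *  Operand c = TensorDot.tensordot(scope, a, b, axes);
+ * </pre>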

+ * + * @param scope the TensorFlow Scope + * @param a `Operand` of type `float32` or `float64`. + * @param b `Operand` with the same type as `a`. + * @param axes If axes is a scalar, sum over the last N axes of a and the + * first N axes of b in order. If axes is a list, the first and second row + * contain the set of unique integers specifying axes along which the + * contraction is computed, for `a` and `b`, respectively. The number of + * axes for `a` and `b` must be equal. If `axes=0`, computes the outer + * product between `a` and `b`. + * @param <T> the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return An `Operand` with the same type as `a`. + * @throws IllegalArgumentException if a is not a float32 or float64 data type, or if a and b are not the same data type + */ + @Endpoint(name = "tensordot") + public static Operand tensordot( + Scope scope, Operand a, Operand b, Operand axes) { + + Operand[] abAxis = tensordotAxes(scope, a, axes); + Operand aAxis = abAxis[0]; + Operand bAxis = abAxis[1]; + + return tensordot(scope, a, b, aAxis, bAxis); + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *


+ * Tensordot (also known as tensor contraction) sums the product of elements + * from a and b over the indices specified by + * a_axes and b_axes. The lists + * a_axes and b_axes specify those pairs of axes + * along which to contract the tensors. The axis a_axes[i] of + * a must have the same dimension as axis + * b_axes[i] of b for all i in + * range(0, len(a_axes)). The lists + * a_axes and b_axes must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally, the outer product is supported by passing + * axes=0. + *


+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *

+ * Example 2: When a and b are matrices (order 2), + * the case + * axes = [[1], [0]] is equivalent to matrix multiplication. + *


+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_{ijk} and b_{lmn} + * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * c_{jklm} whose entry corresponding to the indices + * (j,k,l,m) is given by: + *


+ * c_{jklm} = Σ_i a_{ijk} + * b_{lmi}. + *


+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *
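+ *
+ * A minimal sketch for this overload (illustrative; a and b are placeholder order-2 operands):
+ * <pre>
+ *  // axes[0] applies to a and axes[1] applies to b; {1, 0} contracts
+ *  // axis 1 of a with axis 0 of b, i.e. matrix multiplication
+ *  Operand c = TensorDot.tensordot(scope, a, b, new int[] {1, 0});
+ * </pre>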

+ * + * @param scope the TensorFlow Scope + * @param a `Operand` of type `float32` or `float64`. + * @param b `Operand` with the same type as `a`. + * @param axes a 2-element array; the first and second elements + * specify the axis of `a` and of `b`, respectively, along which the + * contraction is computed. + * @param <T> the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return An `Operand` with the same type as `a`. + * @throws IllegalArgumentException if a is not a float32 or float64 data type, or if a and b are not the same data type + */ + @Endpoint(name = "tensordot") + public static Operand tensordot( + Scope scope, Operand a, Operand b, int[] axes) { + + Operand[] abAxis = tensordotAxes(scope, a, axes); + Operand aAxis = abAxis[0]; + Operand bAxis = abAxis[1]; + + return tensordot(scope, a, b, aAxis, bAxis); + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *


+ * Tensordot (also known as tensor contraction) sums the product of elements + * from a and b over the indices specified by + * a_axes and b_axes. The lists + * a_axes and b_axes specify those pairs of axes + * along which to contract the tensors. The axis a_axes[i] of + * a must have the same dimension as axis + * b_axes[i] of b for all i in + * range(0, len(a_axes)). The lists + * a_axes and b_axes must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally, the outer product is supported by passing + * axes=0. + *


+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *

+ * Example 2: When a and b are matrices (order 2), + * the case + * axes = [[1], [0]] is equivalent to matrix multiplication. + *


+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_{ijk} and b_{lmn} + * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * c_{jklm} whose entry corresponding to the indices + * (j,k,l,m) is given by: + *


+ * c_{jklm} = Σ_i a_{ijk} + * b_{lmi}. + *


+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *
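+ *
+ * A minimal sketch for this overload (illustrative; a and b are placeholder order-3 operands):
+ * <pre>
+ *  // row 0 lists the contraction axes of a, row 1 those of b; this
+ *  // contracts axis 0 of a with axis 2 of b, as in Example 4 above
+ *  Operand c = TensorDot.tensordot(scope, a, b, new int[][] {{0}, {2}});
+ * </pre>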

+ * + * @param scope the TensorFlow Scope + * @param a `Operand` of type `float32` or `float64`. + * @param b `Operand` with the same type as `a`. + * @param axes the first and second rows + * contain the set of unique integers specifying axes along which the + * contraction is computed, for `a` and `b`, respectively. The number of + * axes for `a` and `b` must be equal. + * @param <T> the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return An `Operand` with the same type as `a`. + * @throws IllegalArgumentException if a is not a float32 or float64 data type, or if a and b are not the same data type + */ + @Endpoint(name = "tensordot") + public static Operand tensordot( + Scope scope, Operand a, Operand b, int[][] axes) { + + Operand[] abAxis = tensordotAxes(scope, a, axes); + Operand aAxis = abAxis[0]; + Operand bAxis = abAxis[1]; + + return tensordot(scope, a, b, aAxis, bAxis); + } + + /** + * Tensor contraction of a and b along specified axes and outer product. + *


+ * Tensordot (also known as tensor contraction) sums the product of elements + * from a and b over the indices specified by + * a_axes and b_axes. The lists + * a_axes and b_axes specify those pairs of axes + * along which to contract the tensors. The axis a_axes[i] of + * a must have the same dimension as axis + * b_axes[i] of b for all i in + * range(0, len(a_axes)). The lists + * a_axes and b_axes must have identical length + * and consist of unique integers that specify valid axes for each of the + * tensors. Additionally, the outer product is supported by passing + * axes=0. + *


+ * This operation corresponds to numpy.tensordot(a, b, axes). + *

+ * Example 1: When a and b are matrices (order 2), + * the case axes = 1 is equivalent to matrix multiplication. + *

+ * Example 2: When a and b are matrices (order 2), + * the case + * axes = [[1], [0]] is equivalent to matrix multiplication. + *


+ * Example 3: When a and b are matrices (order 2), + * the case axes=0 gives the outer product, a tensor of order + * 4. + *

+ * Example 4: Suppose that a_{ijk} and b_{lmn} + * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor + * c_{jklm} whose entry corresponding to the indices + * (j,k,l,m) is given by: + *


+ * c_{jklm} = Σ_i a_{ijk} + * b_{lmi}. + *


+ * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). + *

+ * + * @param scope the TensorFlow Scope + * @param a `Operand` of type `float32` or `float64`. + * @param b `Operand` with the same type as `a`. + * @param aAxis axes for the a Operand + * @param bAxis axes for the b Operand + * @param the datatype of the Operands, must be either TFloat32 or + * TFloat64 + * @return A `Operand` with the same type as `a`. + * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type + */ + @SuppressWarnings({"unchecked", "unused"}) + @Endpoint(name = "tensordot") + public static Operand tensordot( + Scope scope, Operand a, Operand b, Operand aAxis, Operand bAxis) { + + if (a.type().equals(TBfloat16.class) || a.type().equals(TFloat16.class)) { + throw new IllegalArgumentException( + String.format( + "Operand 'a' must be either TFloat32 or TFloat64 DataType, 'a' is a %s DataType", + a.type().getSimpleName())); + } + if (!a.type().equals(b.type())) { + throw new IllegalArgumentException( + String.format( + "Operands a and b must be the same data type, a is %s DataType, b is %s DataType", + a.type().getSimpleName(), b.type().getSimpleName())); + } + + // first result is Operand, second result is Operand, third result is long[] and it is + // ignored here. + Object[] aResult = tensordotReshape(scope, a, aAxis, false); + Operand reshapedA = (Operand) aResult[0]; + Operand aFreeDims = (Operand) aResult[1]; + long[] aFreeDimsStatic = (long[]) aResult[2]; + + // first result is Operand, second result is Operand, third result is long[] and it is + // ignored here. + Object[] bResult = tensordotReshape(scope, b, bAxis, true); + Operand reshapedB = (Operand) bResult[0]; + Operand bFreeDims = (Operand) bResult[1]; + long[] bFreeDimsStatic = (long[]) bResult[2]; + + Operand abMatmul = MatMul.matmul(scope, reshapedA, reshapedB); + long[] abDimsStatic = new long[aFreeDimsStatic.length + bFreeDimsStatic.length]; + System.arraycopy(aFreeDimsStatic, 0, abDimsStatic, 0, aFreeDimsStatic.length); + System.arraycopy( + bFreeDimsStatic, 0, abDimsStatic, aFreeDimsStatic.length, bFreeDimsStatic.length); + if (!abMatmul.shape().hasUnknownDimension() + && abMatmul.shape().equals(Shape.of(abDimsStatic))) { + return abMatmul; + } else { + return Reshape.create(scope, abMatmul, Constant.vectorOf(scope, abDimsStatic)); + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java new file mode 100644 index 00000000000..ade4cbb9166 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java @@ -0,0 +1,119 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ +package org.tensorflow.framework.op.nn; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.*; +import org.tensorflow.op.linalg.Transpose; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; + +import java.util.Arrays; + +/** + * Higher level operation for Softmax. This class will move the desired axis to the last axis, if + * necessary, before calling the low level tf.nn.softmax method. + */ + +/* TODO this is a higher level of abstraction from the low level softmax +this method is defined as tf.nn.softmax in TF Python +Should this be defined here? */ +@Operator(group = "nn") +public class Softmax { + + /** + * Calculates a Softmax operation. If the axis is not the last dimension, then the input axis is + * moved to the last axis before calling tf.nn.softmax, then restored before returning. + * + * @param scope The TensorFlow scope + * @param input the input + * @param axis the axis along which the softmax is computed + * @param <T> the data type for the input and result + * @return the softmax of the input for the specified axis. + * @throws IllegalArgumentException if axis is not in the range {@code [-rank, rank)} + */ + @Endpoint(name = "softmax") + public static Operand softmax(Scope scope, Operand input, int axis) { + Shape shape = input.shape(); + boolean isLastDim = axis == -1 || axis == shape.numDimensions() - 1; + if (isLastDim) { + return org.tensorflow.op.nn.Softmax.create(scope, input); + } + + if (axis < -shape.numDimensions() || axis >= shape.numDimensions()) { + throw new IllegalArgumentException( + String.format( + "Axis (%d) must be in the range [%d, %d) where %d is the number of dimensions in the input.", + axis, -shape.numDimensions(), shape.numDimensions(), shape.numDimensions())); + } + + int dim = Math.floorMod(axis, shape.numDimensions()); + Operand rank = Rank.create(scope, input); + Operand dimOp = Constant.scalarOf(scope, dim); + Operand one = Constant.scalarOf(scope, 1); + Operand lastIndex = Sub.create(scope, rank, one); + Operand swappedInputs = swapAxis(scope, input, dimOp, lastIndex); + Operand output = org.tensorflow.op.nn.Softmax.create(scope, swappedInputs); + return fixOutput(scope, output, shape, dimOp, lastIndex); + } + + /** + * Restores the specified axis, then reshapes the input to the provided shape. + * + * @param scope The TensorFlow scope + * @param output the output + * @param shape the desired shape + * @param dim the dimension to move + * @param lastIndex the index of the last dimension + * @return the restored output based on the dimension and shape. 

+ */ + private static Operand fixOutput( + Scope scope, Operand output, Shape shape, Operand dim, Operand lastIndex) { + + Operand result = swapAxis(scope, output, dim, lastIndex); + return Reshape.create(scope, result, Constant.tensorOf(scope, shape)); + } + + /** + * Moves the specified Axis to the last axis + * + * @param input the input + * @param dim the dimension to move + * @param lastIndex the last dimension + * @return input with the dimension swapped to the last dimension + */ + private static Operand swapAxis( + Scope scope, Operand input, Operand dim, Operand lastIndex) { + + Operand zero = Constant.scalarOf(scope, 0); + Operand one = Constant.scalarOf(scope, 1); + return Transpose.create( + scope, + input, + Concat.create( + scope, + Arrays.asList( + Range.create(scope, zero, dim, one), + Range.create(scope, Add.create(scope, dim, one), lastIndex, one), + dim), + zero)); + } +} From bdcbb211ea086ab6554e28328da2acd4e9822055 Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Sat, 1 May 2021 16:41:43 -0400 Subject: [PATCH 26/31] Initial checkin --- .../org/tensorflow/framework/layers/Dot.java | 10 +- .../org/tensorflow/framework/op/SetsOps.java | 161 ++++ .../framework/op/linalg/MatMul.java | 268 ------- .../framework/op/math/ReduceLogSumExp.java | 142 ---- .../framework/op/math/TensorDot.java | 719 ------------------ .../tensorflow/framework/op/nn/Softmax.java | 119 --- .../op/nn/SoftmaxCrossEntropyWithLogits.java | 226 ++++++ 7 files changed, 394 insertions(+), 1251 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ReduceLogSumExp.java delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java delete mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java index e5685708c30..5ad01445558 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java @@ -17,6 +17,7 @@ import org.tensorflow.Operand; import org.tensorflow.framework.layers.impl.Merge; import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Squeeze; @@ -300,8 +301,9 @@ protected Operand mergeFunction(List> inputs) { } } if (normalize) { - input1 = Losses.l2Normalize(tf, input1, new int[] {axes[0]}); - input2 = Losses.l2Normalize(tf, input2, new int[] {axes[0]}); + FrameworkOps fops = FrameworkOps.create(tf); + input1 = fops.math.l2Normalize(input1, new int[] {axes[0]}); + input2 = fops.math.l2Normalize(input2, new int[] {axes[0]}); } return batchDot(input1, input2, newAxes); } @@ -378,6 +380,7 @@ public List computeOutputShape(List inputShapes) { private Operand batchDot( Operand x, Operand y, int[] dotAxes) { Ops tf = getTF(); + FrameworkOps fops = FrameworkOps.create(tf); Operand tX = cast(tf, x, getType()); Operand tY = cast(tf, y, getType()); @@ -515,7 +518,8 @@ private 
Operand batchDot( ySquashed = true; } - Operand result = org.tensorflow.framework.op.linalg.MatMul.matmul(getTF().scope(), tX, tY); + + Operand result = fops.linalg.matmul(tX, tY); boolean doReshape = false; Operand outputShape = tf.shape(result, TInt64.class); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java new file mode 100644 index 00000000000..d7833cdbb06 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/SetsOps.java @@ -0,0 +1,161 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op; + +import org.tensorflow.Operand; +import org.tensorflow.op.Scope; +import org.tensorflow.op.SparseOps; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.sparse.DenseToDenseSetOperation; +import org.tensorflow.op.sparse.SparseToDense; +import org.tensorflow.types.family.TNumber; + +/** Implementation of set operations */ +public class SetsOps { + + private final Scope scope; + + private final FrameworkOps frameworkOps; + + /** + * Creates Framework {@code nn} Operations + * + * @param frameworkOps the TensorFLow framework Ops + */ + SetsOps(FrameworkOps frameworkOps) { + this.scope = frameworkOps.scope(); + this.frameworkOps = frameworkOps; + } + + /** + * Computes set difference of elements in last dimension of a and b with + * aMinusB set to true. + * + *

All but the last dimension of a and b must match. + * + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param <T> the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the same. Elements along the last dimension contain the results of the set + * operation. + */ + public Operand difference(Operand a, Operand b) { + return difference(a, b, true); + } + + /** + * Computes set difference of elements in last dimension of a and b. + * + *

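+ *
+ * For example (an illustrative sketch of the semantics; the values along the last
+ * dimension are treated as sets):
+ * <pre>
+ *  // a = [[1, 2], [3, 4]], b = [[1, 3], [2, 3]]
+ *  // difference(a, b, true)  keeps elements of a not in b: [[2], [4]]
+ *  // difference(a, b, false) keeps elements of b not in a: [[3], [2]]
+ * </pre>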

All but the last dimension of a and b must match. + * + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param aMinusB whether to subtract b from a, vs vice versa. + * @param <T> the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the same. Elements along the last dimension contain the results of the set + * operation. + */ + public Operand difference(Operand a, Operand b, boolean aMinusB) { + return setOperation(a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A); + } + + /** + * Computes set union of elements in last dimension of a and b. + * + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param <T> the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the same. Elements along the last dimension contain the results of the set + * operation. + */ + public Operand union(Operand a, Operand b) { + return setOperation(a, b, Operation.UNION); + } + + /** + * Computes set intersection of elements in last dimension of a and b. + * + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param <T> the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the same. Elements along the last dimension contain the results of the set + * operation. + */ + public Operand intersection(Operand a, Operand b) { + return setOperation(a, b, Operation.INTERSECTION); + } + + /** + * Computes the set operation of elements in last dimension of a and b. + * + * @param a The first set operation operand + * @param b The other set operation operand + * @param setOperation The set operation to perform, {@link Operation}. + * @param <T> the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the same. Elements along the last dimension contain the results of the set + * operation. 

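+ *
+ * <p>A usage sketch (illustrative only; assumes {@code sets} is a {@code SetsOps} instance,
+ * e.g. obtained from {@code FrameworkOps}):
+ * <pre>
+ *  // equivalent to sets.union(a, b)
+ *  Operand u = sets.setOperation(a, b, SetsOps.Operation.UNION);
+ * </pre>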
+ */ + public Operand setOperation( + Operand a, Operand b, Operation setOperation) { + + DenseToDenseSetOperation setOperationResult = + DenseToDenseSetOperation.create( + scope, + a, + b, + setOperation.getSetOperation(), + DenseToDenseSetOperation.validateIndices(true)); + + return SparseToDense.create( + scope, + setOperationResult.resultIndices(), + setOperationResult.resultShape(), + setOperationResult.resultValues(), + Cast.create(scope, Constant.scalarOf(scope, 0), a.type())); + } + + /** + * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops + * function {@link SparseOps#denseToDenseSetOperation} + */ + public enum Operation { + A_MINUS_B("a-b"), + B_MINUS_A("b-a"), + INTERSECTION("intersection"), + UNION("union"); + + private final String setOperation; + + Operation(String setOperation) { + this.setOperation = setOperation; + } + + /** + * Gets the set operation String value used to pass as the stringOperation value to {@link + * SparseOps#denseToDenseSetOperation} + * + * @return the set operation String value + */ + public String getSetOperation() { + return setOperation; + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java deleted file mode 100644 index c4843db7266..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/linalg/MatMul.java +++ /dev/null @@ -1,268 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.op.linalg; - -import org.tensorflow.Operand; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.dtypes.Cast; -import org.tensorflow.op.math.Conj; -import org.tensorflow.op.sparse.SparseMatMul; -import org.tensorflow.op.train.BatchMatMul; -import org.tensorflow.types.TBfloat16; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.family.TFloating; -import org.tensorflow.types.family.TNumber; - -/** - * Higher level operation for matMul that does logic before calling the low leve MatMul operations. - */ -/* TODO this is a higher level of abstraction from the low level matmul -it is defined as tf.matmul and tf.linalg.matmul in python. -Should this be defined here? */ -@Operator(group = "linalg") -public class MatMul { - - /** - * Multiplies matrix a by matrix b, producing a * b - * . - * - *

The inputs must, following any transpositions, be tensors of rank >= 2 where the inner 2 - * dimensions specify valid matrix multiplication dimensions, and any further outer dimensions - * specify matching batch size. - * - *

Both matrices must be of the same type. The supported types are: TFloat16, - * TFloat32, TFloat64, TInt32. - * - *

Either matrix can be transposed or adjointed (conjugated and transposed) on the fly by - * setting one of the corresponding flag to true. These are false by - * default. - * - *

A simple 2-D tensor matrix multiplication: - * - *

-   *  Operand a = tf.constant(new float[][] {{1, 2, 3}, {4, 5, 6}});
-   *  Operand b = tf.constant(new float[][] {{7, 8},{ 9, 10}, {11, 12}});
-   *  Operand c = FMWLinalgOps.matmul(tf.scope(), a, b)
-   *
-   * 
- * - *

Note: This is matrix product, not element-wise product. - * - * @param scope the Tensorflow scope - * @param a an Operand of of type TFloat16, TFloat32, TFloat64 - * , TInt32. with a rank > 1 - * @param b an Operand with same type and rank as a. - * @param the data type of the Operands - * @return A Operand of the same type as a and b where each inner-most - * matrix is the product of the corresponding matrices in a and b. - * This is the matrix product not an element-wise product. - * @throws java.lang.IllegalArgumentException If transposeA and adjointA - * , or transposeB and adjointB are both set to `true`. - */ - @Endpoint(name = "matmul") - public static Operand matmul(Scope scope, Operand a, Operand b) { - return matmul(scope, a, b, false, false, false, false, false, false); - } - - /** - * Multiplies matrix a by matrix b, producing a * b. - *

- * The inputs must, following any transpositions, be tensors of rank >= 2 - * where the inner 2 dimensions specify valid matrix multiplication - * dimensions, and any further outer dimensions specify matching batch size. - *

- * Both matrices must be of the same type. The supported types are: - * TFloat16, TFloat32, TFloat64, TInt32. - *

- * Either matrix can be transposed or adjointed (conjugated and transposed) - * on the fly by setting one of the corresponding flag to true. These are - * false by default. - *

- *

Note: This is matrix product, not element-wise product. - *

- * A simple 2-D tensor matrix multiplication: - *

-   * //TODO
-   * TFloat16, TFloat32, TFloat64, TInt32.
-   * with a rank > 1
-   * @param b an Operand with same type and rank as a.
-   * @param transposeA If `true`, a is transposed before multiplication.
-   * @param transposeB If `True`, b is transposed before multiplication
-   * @param  the data type of the Operands
-   * @return A Operand of the same type as a and b where each
-   * inner-most matrix is the product of the corresponding matrices in a and
-   * b. This is the
-   * matrix product not an element-wise product.
-   * @throws java.lang.IllegalArgumentException If transposeA and
-   * adjointA, or transposeB and adjointB are both set to `true`.
-   */
-  @Endpoint(name = "matmul")
-  public static  Operand matmul(
-      Scope scope, Operand a, Operand b, boolean transposeA, boolean transposeB) {
-    return matmul(scope, a, b, transposeA, transposeB, false, false, false, false);
-  }
-
-  /**
-   * Multiplies matrix a by matrix b, producing a * b.
-   * 

- * The inputs must, following any transpositions, be tensors of rank >= 2 - * where the inner 2 dimensions specify valid matrix multiplication - * dimensions, and any further outer dimensions specify matching batch size. - *

- * Both matrices must be of the same type. The supported types are: - * TFloat16, TFloat32, TFloat64, TInt32. - *

- * Either matrix can be transposed or adjointed (conjugated and transposed) - * on the fly by setting one of the corresponding flag to true. These are - * false by default. - * - *

Note: This is matrix product, not element-wise product. - *

- * A simple 2-D tensor matrix multiplication: - *

-   * //TODO
-   * TFloat16, TFloat32, TFloat64, TInt32.
-   * with a rank > 1
-   * @param b an Operand with same type and rank as a.
-   * @param transposeA If true, a is transposed before multiplication.
-   * @param transposeB If True, b is transposed before multiplication
-   * @param adjointA If true, a is conjugated and transposed before
-   * multiplication.
-   * @param adjointB If true, b is conjugated and transposed before
-   * multiplication.
-   * @param aIsSparse If true, a is treated as a sparse matrix. Notice, this
-   *       does not support org.tensorflow.framework.utils.SparseTensor, it just makes optimizations
-   *       that assume most values in a are zero.
-   * @param bIsSparse If true, b is treated as a sparse matrix. Notice, this
-   *       does not support org.tensorflow.framework.utils.SparseTensor, it just makes optimizations
-   *       that assume most values in b are zero.
-   * @param  the data type of the Operands
-   * @return A Operand of the same type as a and b where each
-   * inner-most matrix is the product of the corresponding matrices in a and
-   * b. This is the
-   * matrix product not an element-wise product.
-   * @throws java.lang.IllegalArgumentException If transposeA and
-   * adjointA, or transposeB and adjointB are both set to `true`.
-   */
-  @SuppressWarnings("unchecked")
-  @Endpoint(name = "matmul")
-  public static  Operand matmul(
-      Scope scope,
-      Operand a,
-      Operand b,
-      boolean transposeA,
-      boolean transposeB,
-      boolean adjointA,
-      boolean adjointB,
-      boolean aIsSparse,
-      boolean bIsSparse) {
-    scope = scope.withSubScope("MatMul");
-    if (transposeA && adjointA)
-      throw new IllegalArgumentException("Only one of transposeA and adjointA can be true.");
-    if (transposeB && adjointB)
-      throw new IllegalArgumentException("Only one of transposeB and adjointB can be true.");
-    if (!(TFloating.class.isAssignableFrom(a.type()) || a.type().equals(TInt32.class)))
-      throw new IllegalArgumentException(
-          String.format(
-              "Operand 'a' must be of type 'TBfloat16','TFloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
-              a.type().getSimpleName()));
-    if (!(TFloating.class.isAssignableFrom(a.type()) || b.type().equals(TInt32.class)))
-      throw new IllegalArgumentException(
-          String.format(
-              "Operand 'b' must be of type 'TBfloat16', 'TFloat32', 'TFloat64' or 'TInt32'. found type : %s",
-              b.type().getSimpleName()));
-
-    Shape aShape = a.shape();
-    Shape bShape = b.shape();
-    if (aShape.numDimensions() != bShape.numDimensions())
-      throw new IllegalArgumentException(
-          String.format(
-              "Parameters 'a' and 'b' must the same rank: found a rank = %d, b rank = %d",
-              aShape.numDimensions(), bShape.numDimensions()));
-    boolean outputMayHaveNonEmptyBatchShape =
-        aShape.numDimensions() == Shape.UNKNOWN_SIZE
-            || aShape.numDimensions() > 2
-            || bShape.numDimensions() == Shape.UNKNOWN_SIZE;
-
-    if ((!aIsSparse && !bIsSparse) && outputMayHaveNonEmptyBatchShape) {
-      // BatchMatmul does not support transpose, so we conjugate the matrix and
-      // use adjoint instead. Conj() is a noop for real matrices.
-      if (transposeA) {
-        a = Conj.create(scope, a);
-        adjointA = true;
-      }
-      if (transposeB) {
-        b = Conj.create(scope, b);
-        adjointB = true;
-      }
-      Operand bT = a.type().equals(b.type()) ? (Operand) b : Cast.create(scope, b, a.type());
-      return BatchMatMul.create(
-          scope, a, bT, BatchMatMul.adjX(adjointA), BatchMatMul.adjY(adjointB));
-    }
-
-    // Neither matmul nor sparse_matmul support adjoint, so we conjugate
-    // the matrix and use transpose instead. Conj() is a noop for real
-    // matrices.
-    if (adjointA) {
-      a = Conj.create(scope, a);
-      transposeA = true;
-    }
-    if (adjointB) {
-      b = Conj.create(scope, b);
-      transposeB = true;
-    }
-
-    boolean useSparseMatmul = false;
-    if (aIsSparse || bIsSparse) {
-      useSparseMatmul =
-          (a.type().equals(TBfloat16.class) || a.type().equals(TFloat32.class))
-              && (b.type().equals(TBfloat16.class) || b.type().equals(TFloat32.class));
-    }
-    if ((a.type().equals(TBfloat16.class) || b.type().equals(TBfloat16.class))
-        && !a.type().equals(b.type())) useSparseMatmul = true;
-
-    if (useSparseMatmul) {
-      Operand result =
-          SparseMatMul.create(
-              scope,
-              a,
-              b,
-              SparseMatMul.transposeA(transposeA),
-              SparseMatMul.transposeB(transposeB),
-              SparseMatMul.aIsSparse(aIsSparse),
-              SparseMatMul.bIsSparse(bIsSparse));
-      if (a.type().equals(TFloat32.class)) return (Operand) result;
-      else return Cast.create(scope, result, a.type());
-    }
-
-    // need to cast b to Operand
-    Operand bT = a.type().equals(b.type()) ? (Operand) b : Cast.create(scope, b, a.type());
-
-    return org.tensorflow.op.linalg.MatMul.create(
-        scope, a, bT,
-            org.tensorflow.op.linalg.MatMul.transposeA(transposeA),
-            org.tensorflow.op.linalg.MatMul.transposeB(transposeB));
-  }
-}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ReduceLogSumExp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ReduceLogSumExp.java
deleted file mode 100644
index 71678344976..00000000000
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/ReduceLogSumExp.java
+++ /dev/null
@@ -1,142 +0,0 @@
-/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-=======================================================================*/
-package org.tensorflow.framework.op.math;
-
-import org.tensorflow.Operand;
-import org.tensorflow.ndarray.Shape;
-import org.tensorflow.op.Scope;
-import org.tensorflow.op.annotation.Endpoint;
-import org.tensorflow.op.annotation.Operator;
-import org.tensorflow.op.core.*;
-import org.tensorflow.op.dtypes.Cast;
-import org.tensorflow.op.math.*;
-import org.tensorflow.types.TInt32;
-import org.tensorflow.types.TString;
-import org.tensorflow.types.family.TFloating;
-import org.tensorflow.types.family.TNumber;
-import org.tensorflow.types.family.TType;
-
-@Operator(group = "math")
-public class ReduceLogSumExp {
-
-
-
-  // TODO this method is defined in tf.math.reduce_logsumexp in TF Python.
-  /**
-   * Computes log(sum(exp(elements across dimensions of a tensor))). Reduces {@code input_tensor}
-   * along the dimensions given in {@code axes}.
-   *
-   * 

Reduces `{@code input} along the dimensions given in {@code axes}. Unless {@code keepdims} - * is true, the rank of the tensor is reduced by 1 for each of the entries in {@code axes}, which - * must be unique. If {@code keepdims} is true, the reduced dimensions are retained with length 1. - * If {@code axes} has no entries, all dimensions are reduced, and a tensor with a single element - * is returned. This function is more numerically stable than {@code log(sum(exp(input)))}. It - * avoids overflows caused by taking the exp of large inputs and underflows caused by taking the - * log of small inputs. - * - * @param input The tensor to reduce. - * @param axes The dimensions to reduce. If null, reduces all dimensions. Must be in the range - * {@link [-rank(input_tensor), rank(input_tensor)]}. - * @param keepDims If true, retains reduced dimensions with length 1. - * @return The reduced tensor. - */ - @Endpoint(name = "reduceLogSumExp") - public static Operand reduceLogSumExp( - Scope scope, Operand input, int[] axes, boolean keepDims) { - Operand reduceDims = reductionDims(scope, input, axes); - Operand rawMax = reduceMaxWithDims(scope, input, axes, keepDims, reduceDims); - Operand myMax = - StopGradient.create( - scope, - Select.create( - scope, IsFinite.create(scope, rawMax), rawMax, ZerosLike.create(scope, rawMax))); - - Operand result = - Log.create( - scope, - reduceSumWithDims( - scope, - Exp.create(scope, Sub.create(scope, input, myMax)), - axes, - keepDims, - reduceDims)); - - if (!keepDims) { - myMax = Reshape.create(scope, myMax, org.tensorflow.op.core.Shape.create(scope, result)); - } - result = Add.create(scope, result, myMax); - return mayReduceToScalar(scope, keepDims, axes, result); - } - - private static Operand reduceSumWithDims( - Scope scope, Operand input, int[] axes, boolean keepDims, Operand dims) { - return mayReduceToScalar( - scope, keepDims, axes, ReduceSum.create(scope, input, dims, ReduceSum.keepDims(keepDims))); - } - - private static Operand reduceMaxWithDims( - Scope scope, Operand input, int[] axes, boolean keepDims, Operand dims) { - return mayReduceToScalar( - scope, keepDims, axes, ReduceMax.create(scope, input, dims, ReduceMax.keepDims(keepDims))); - } - - /** - * Sets a reduction's output shape to be a scalar if possible. - * - * @return the operand, possibly reduced to a scalar. - */ - private static Operand mayReduceToScalar( - Scope scope, boolean keepDims, int[] axes, Operand output) { - - if ((output.shape().numDimensions() == Shape.UNKNOWN_SIZE - || output.shape().hasUnknownDimension()) - && !keepDims - && axes == null) { - return Reshape.create(scope, output, Constant.tensorOf(scope, Shape.scalar())); - } else { - return output; - } - } - - /** - * Reduce dimensions based on axis - * - * @param scope the TensorFlow scope - * @param input the input - * @param axes he dimensions to reduce, may be null - * @return the dimensions to be reduced. 
- */ - private static Operand reductionDims( - Scope scope, Operand input, int[] axes) { - if (axes != null) { - return Constant.vectorOf(scope, axes); - } - long rank = input.shape().numDimensions(); - if (rank != Shape.UNKNOWN_SIZE) { - int[] dims = new int[(int) rank]; - for (int i = 0; i < rank; i++) { - dims[i] = i; - } - return Constant.vectorOf(scope, dims); - - } else { - return Range.create( - scope, - Constant.scalarOf(scope, 0), - Rank.create(scope, input), - Constant.scalarOf(scope, 1)); - } - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java deleted file mode 100644 index 28c430369a2..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/math/TensorDot.java +++ /dev/null @@ -1,719 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -=======================================================================*/ -package org.tensorflow.framework.op.math; - -import org.tensorflow.Graph; -import org.tensorflow.Operand; -import org.tensorflow.Session; -import org.tensorflow.framework.op.linalg.MatMul; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.core.*; -import org.tensorflow.op.dtypes.Cast; -import org.tensorflow.op.linalg.Transpose; -import org.tensorflow.op.math.Add; -import org.tensorflow.op.math.GreaterEqual; -import org.tensorflow.op.math.Less; -import org.tensorflow.op.math.Sub; -import org.tensorflow.types.TBfloat16; -import org.tensorflow.types.TFloat16; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.family.TFloating; -import org.tensorflow.types.family.TNumber; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.stream.Collectors; - -// NOTE: This is implemented as tf.tensordot() in Python, and the code is in -// python/ops/math_ops.py. Is this where we want to keep this? -/** - * Implements the TensorDot (Tensor contraction of a and b along specified axes and outer product) - * operation. - */ -@Operator(group = "math") -public abstract class TensorDot { - - /** - * Transpose and reshape the input for contraction op. - * - *

This method is helpful in reducing `math_ops.tensordot` to `math_ops.matmul` using - * `array_ops.transpose` and `array_ops.reshape`. The method takes a tensor and performs the - * correct transpose and reshape operation for a given set of indices. It returns the reshaped - * tensor as well as a list of indices necessary to reshape the tensor again after matrix - * multiplication. - * - * @param the type of Operand - * @param scope the TensorFlow scope - * @param a the Tensor - * @param axis unique indices specifying valid axes of `a`. - * @param flipped whether to flip the dimensions or not - * @return A tuple (reshapedA, freeDims, freeDimsStatic) where reshapedA is a reshaped to allow - * contraction via matmul, freeDims` is a TInt32 Operand, depending on whether the shape of a - * is fully specified, and freeDimsStatic is either a list of integers and null values, or - * None, representing the inferred static shape of the free dimensions - */ - private static Object[] tensordotReshape( - Scope scope, Operand a, Operand axis, boolean flipped) { - Shape aShape = a.shape(); - - if (!aShape.hasUnknownDimension()) { // calculate using static values - long[] aShapeDims = aShape.asArray(); - if (aShapeDims == null) aShapeDims = new long[0]; - long[] aDimsIndex = new long[aShapeDims.length]; - for (int i = 0; i < aDimsIndex.length; i++) aDimsIndex[i] = i; - - // get int array from axis Operand - int[] iAxes = getIntArray(scope, axis); - // Convert negative axes to positive - for (int i = 0; i < iAxes.length; i++) - iAxes[i] = iAxes[i] >= 0 ? iAxes[i] : Math.floorMod(iAxes[i], iAxes.length); - - // convert integer axis to long axis - long[] lAxes = Arrays.stream(iAxes).mapToLong(i -> i).toArray(); - - // create list of the axes, dims, and free axes - List axesList = Arrays.stream(lAxes).boxed().collect(Collectors.toList()); - List freeList = Arrays.stream(aDimsIndex).boxed().collect(Collectors.toList()); - freeList.removeAll(axesList); - - // create array of free dims - long[] free = freeList.stream().mapToLong(i -> i).toArray(); - long[] freeDims = new long[free.length]; - for (int i = 0; i < free.length; i++) freeDims[i] = aShapeDims[(int) free[i]]; - - // Calculate the free dim by doing a reduce prod - long prodFree = 1; - for (long i : freeDims) { - prodFree *= i; - } - - // calculate the used dims by doing a reduce prod - long prodAxis = 1; - for (long i : lAxes) { - prodAxis *= aShapeDims[(int) i]; - } - - // setup the permutations array for the transpose - long[] perm = new long[freeDims.length + lAxes.length]; - Shape newShape; - if (flipped) { - System.arraycopy(lAxes, 0, perm, 0, lAxes.length); - System.arraycopy(free, 0, perm, lAxes.length, free.length); - newShape = Shape.of(prodAxis, prodFree); - } else { - System.arraycopy(free, 0, perm, 0, free.length); - System.arraycopy(lAxes, 0, perm, freeDims.length, lAxes.length); - newShape = Shape.of(prodFree, prodAxis); - } - - Operand aTrans; - long[] arrange = new long[lAxes.length]; - for (int i = 0; i < arrange.length; i++) arrange[i] = i; - - // if the permutations is not equals to the natural order of the dims, then do a transpose - if (!Arrays.equals(perm, arrange)) { - aTrans = Transpose.create(scope, a, Constant.vectorOf(scope, perm)); - } else { - aTrans = a; - } - - // reshape the final result to the new Shape, if necessary - Operand aReshaped = - aTrans.asOutput().shape().equals(newShape) - ? 
aTrans - : Reshape.create(scope, aTrans, Constant.vectorOf(scope, newShape.asArray())); - // return a tuple for the reshaped Operand, and Operand for the free dimensions, and a long - // array for the free dimensions - return new Object[] {aReshaped, Constant.vectorOf(scope, freeDims), freeDims}; - - } else { // calculate dynamically - - long[] freeDimsStatic = null; - Operand one = Constant.scalarOf(scope, 1); - Operand minusOne = Constant.scalarOf(scope, -1); - Operand zero = Constant.scalarOf(scope, 0); - org.tensorflow.op.core.Shape tShape = org.tensorflow.op.core.Shape.create(scope, a); - Operand axesT; - Operand freeT; - if (aShape.numDimensions() - != Shape.UNKNOWN_SIZE) { // we know the rank, but there are unknown dimensions - long[] aShapeDims = aShape.asArray(); - if (aShapeDims == null) aShapeDims = new long[0]; - - // get int array from axis Operand - int[] iAxes = getIntArray(scope, axis); - // Convert negative axes to positive - for (int i = 0; i < iAxes.length; i++) - iAxes[i] = iAxes[i] >= 0 ? iAxes[i] : Math.floorMod(iAxes[i], iAxes.length); - - // convert integer axis to long axis - long[] lAxes = Arrays.stream(iAxes).mapToLong(i -> i).toArray(); - - // create list of the axes, dims, and free axes - List axesList = Arrays.stream(lAxes).boxed().collect(Collectors.toList()); - List dimsList = Arrays.stream(aShapeDims).boxed().collect(Collectors.toList()); - List freeList = new ArrayList<>(axesList); - freeList.removeAll(dimsList); - - // create array of free dims - long[] freeDims = freeList.stream().mapToLong(i -> i).toArray(); - freeDimsStatic = freeDims; - - axesT = Constant.vectorOf(scope, iAxes); - freeT = Cast.create(scope, Constant.vectorOf(scope, freeDims), TInt32.class); - - } else { // we don't know the rank yet - Rank rank = Rank.create(scope, a); - - // convert axis to positive - axesT = - Select.create( - scope, - GreaterEqual.create(scope, axis, Constant.scalarOf(scope, 0)), - axis, - Add.create(scope, axis, rank)); - - SetDiff1d diff = - SetDiff1d.create( - scope, Range.create(scope, Constant.scalarOf(scope, 0), rank, one), axesT); - freeT = diff.out(); - } - Operand freeDims = Gather.create(scope, tShape, freeT, zero); - Operand axesDims = Gather.create(scope, tShape, axesT, zero); - Operand prodFreeDims = ReduceProd.create(scope, freeDims, minusOne); - Operand prodAxesDims = ReduceProd.create(scope, axesDims, minusOne); - Operand perm; - Operand newShape; - if (flipped) { - perm = Concat.create(scope, Arrays.asList(axesT, freeT), zero); - newShape = Stack.create(scope, Arrays.asList(prodAxesDims, prodFreeDims)); - } else { - perm = Concat.create(scope, Arrays.asList(freeT, axesT), zero); - newShape = Stack.create(scope, Arrays.asList(prodFreeDims, prodAxesDims)); - } - Operand aReshaped = Reshape.create(scope, Transpose.create(scope, a, perm), newShape); - return new Object[] {aReshaped, freeDims, freeDimsStatic}; - } - } - - /** - * Gets an int array from an Operand<TInt32> operand. 
- * - * @param scope the TensorFlow scope - * @param axes the Operand to fetch the values - * @return the int array from an Operand<TInt32> - */ - private static int[] getIntArray(Scope scope, Operand axes) { - List result = new ArrayList<>(); - if (scope.env().isEager()) { - axes.asTensor().scalars().forEach(s -> result.add(s.getInt())); - } else { - try (Session session = new Session((Graph) scope.env()); - TInt32 tensor = (TInt32) session.runner().fetch(axes).run().get(0)) { - tensor.scalars().forEach(s -> result.add(s.getInt())); - } - } - return result.stream().mapToInt(i -> i).toArray(); - } - - /** - * Generates two sets of contraction axes for the two tensor arguments. - * - * @param scope the scope - * @param a the Operand to analyze - * @param axis the axes - * @param the data type for the Operand - * @return the contraction axes - */ - @SuppressWarnings("unchecked") - private static Operand[] tensordotAxes( - Scope scope, Operand a, int axis) { - Shape aShape = a.asOutput().shape(); - if (axis < 0) { - throw new IllegalArgumentException("'axis' must be at least 0."); - } - int rank = aShape.numDimensions(); - Operand[] result = new Operand[2]; - if (rank != Shape.UNKNOWN_SIZE) { - if (axis > rank) { - throw new IllegalArgumentException( - String.format( - "'axis' must not be larger than the number of dimensions of tensor %s.", rank)); - } - int min = rank - axis; - int postRange = rank - min; - int[] postAxis = new int[postRange]; - for (int i = 0; i < postRange; i++) postAxis[i] = i + min; - - int[] preAxis = new int[axis]; - for (int i = 0; i < axis; i++) preAxis[i] = i; - - result[0] = Constant.vectorOf(scope, postAxis); - result[1] = Constant.vectorOf(scope, preAxis); - } else { - Rank rankT = Rank.create(scope, a); - Constant axisT = Constant.scalarOf(scope, axis); - Constant one = Constant.scalarOf(scope, 1); - Constant zero = Constant.scalarOf(scope, 0); - AssertThat assertion = - AssertThat.create( - scope, - Less.create(scope, axisT, rankT), - Arrays.asList( - Constant.scalarOf( - scope, "'axes' must not be larger than the number of dimensions of tensor "), - rankT)); - Scope scope1 = scope.withControlDependencies(Collections.singletonList(assertion)); - result[0] = Range.create(scope1, Sub.create(scope, rankT, axisT), rankT, one); - result[1] = Range.create(scope1, zero, axisT, one); - } - return result; - } - - /** - * Generates two sets of contraction axes for the two tensor arguments. - * - * @param scope the scope - * @param a the Operand to analyze - * @param axes the axes - * @param the data type for the Operand - * @return the contraction axes - */ - @SuppressWarnings({"unchecked", "unused"}) - private static Operand[] tensordotAxes( - Scope scope, Operand a, int[] axes) { - if (axes.length != 2) - throw new IllegalArgumentException( - "'axes' must have length 1 or 2, provided with " + axes.length); - int[] aAxis = new int[] {axes[0]}; - int[] bAxis = new int[] {axes[1]}; - Operand[] result = new Operand[2]; - result[0] = Constant.vectorOf(scope, aAxis); - result[1] = Constant.vectorOf(scope, bAxis); - - return result; - } - - /** - * Generates two sets of contraction axes for the two tensor arguments. 
- * - * @param scope the scope - * @param a the Operand to analyze - * @param axes the axes - * @param the data type for the Operand - * @return the contraction axes - */ - @SuppressWarnings({"unchecked", "unused"}) - private static Operand[] tensordotAxes( - Scope scope, Operand a, int[][] axes) { - if (axes.length != 2) - throw new IllegalArgumentException( - "'axes' must have length 1 or 2, provided with " + axes.length); - int[] aAxis = axes[0]; - int[] bAxis = axes[1]; - if (aAxis.length != bAxis.length) - throw new IllegalArgumentException( - String.format( - "Different number of contraction axes 'a' and 'b', %d != %d", - aAxis.length, bAxis.length)); - Operand[] result = new Operand[2]; - result[0] = Constant.vectorOf(scope, aAxis); - result[1] = Constant.vectorOf(scope, bAxis); - return result; - } - - /** - * Generates two sets of contraction axes for the two tensor arguments. - * - * @param scope the scope - * @param a the Operand to analyze - * @param axes the axes - * @param the data type for the Operand - * @return the contraction axes - */ - @SuppressWarnings({"unchecked", "unused"}) - private static Operand[] tensordotAxes( - Scope scope, Operand a, Operand axes) { - - Constant one = Constant.scalarOf(scope, 1); - Constant zero = Constant.scalarOf(scope, 0); - Operand[] result = new Operand[2]; - result[0] = - Slice.create( - scope, - axes, - Cast.create(scope, zero, TInt32.class), - Cast.create(scope, one, TInt32.class)); - result[1] = - Slice.create( - scope, - axes, - Cast.create(scope, one, TInt32.class), - Cast.create(scope, one, TInt32.class)); - return result; - } - - /** - * Tensor contraction of a and b along specified axes and outer product. - *

- * Tensordot (also known as tensor contraction) sums the product of elements - * from a and b over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length - * and consist of unique integers that specify valid axes for each of the - * tensors. Additionally, the outer product is supported by passing - * axes=0. - *

- * This operation corresponds to numpy.tensordot(a, b, axes). - *

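 As a concrete reference point, here is a small usage sketch of the static endpoint shown in this deleted file (assuming an Ops instance tf; the numeric comment is ordinary 2x2 matrix multiplication): // Illustrative sketch only: axes = 1 on two matrices is plain matmul. Operand<TFloat32> a = tf.constant(new float[][] {{1f, 2f}, {3f, 4f}}); Operand<TFloat32> b = tf.constant(new float[][] {{5f, 6f}, {7f, 8f}}); Operand<TFloat32> c = TensorDot.tensordot(tf.scope(), a, b, 1); // c == {{19f, 22f}, {43f, 50f}} 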
- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. - *

 - * Example 2: When a and b are matrices (order 2), - * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. - *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order - * 4. - *

- * Example 4: Suppose that a_ijk and b_lmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor - * c_jklm whose entry corresponding to the indices - * (j,k,l,m) is given by: - *


- * c_jklm = Σ_i a_ijk b_lmi. - *

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). - *

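 The order formula above fixes the result shape completely: the free (uncontracted) axes of a, in order, followed by the free axes of b. A small shape-only helper sketch (plain Java, illustrative only; assumes the java.util and java.util.stream imports already used in this file): // Illustrative sketch only: result shape = a's free axes ++ b's free axes. static long[] tensordotShape(long[] aShape, long[] bShape, int[] aAxes, int[] bAxes) { List<Long> out = new ArrayList<>(); for (int i = 0; i < aShape.length; i++) { final int idx = i; if (Arrays.stream(aAxes).noneMatch(ax -> ax == idx)) out.add(aShape[i]); } for (int i = 0; i < bShape.length; i++) { final int idx = i; if (Arrays.stream(bAxes).noneMatch(ax -> ax == idx)) out.add(bShape[i]); } return out.stream().mapToLong(Long::longValue).toArray(); } // tensordotShape({2, 3, 4}, {4, 5}, {2}, {0}) -> {2, 3, 5}; order 3 + 2 - 2*1 = 3. 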
- * - * @param scope the TensorFlow Scope - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. - * @param axis sum over the last N axes of a and the - * first N axes of b in order. If `axes=0`, computes the outer - * product between `a` and `b`. - * @param the datatype of the Operands, must be either TFloat32 or - * TFloat64 - * @return A `Operand` with the same type as `a`. - * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type - */ - @Endpoint(name = "tensordot") - public static Operand tensordot( - Scope scope, Operand a, Operand b, int axis) { - - Operand[] abAxis = tensordotAxes(scope, a, axis); - Operand aAxis = abAxis[0]; - Operand bAxis = abAxis[1]; - return tensordot(scope, a, b, aAxis, bAxis); - } - - /** - * Tensor contraction of a and b along specified axes and outer product. - *

 - * Tensordot (also known as tensor contraction) sums the product of elements - * from a and b over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length - * and consist of unique integers that specify valid axes for each of the - * tensors. Additionally, the outer product is supported by passing - * axes=0. - *

- * This operation corresponds to numpy.tensordot(a, b, axes). - *

- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. - *

- * Example 2: When a and b are matrices (order 2), - * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. - *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order - * 4. - *

- * Example 4: Suppose that a_ijk and b_lmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor - * c_jklm whose entry corresponding to the indices - * (j,k,l,m) is given by: - *


- * c_jklm = Σ_i a_ijk b_lmi. - *

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). - *

- * - * @param scope the TensorFlow Scope - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. - * @param axes If axes is a scalar, sum over the last N axes of a and the - * first N axes of b in order. If axes is a list, the first and second row - * contain the set of unique integers specifying axes along which the - * contraction is computed, for `a` and `b`, respectively. The number of - * axes for `a` and `b` must be equal. If `axes=0`, computes the outer - * product between `a` and `b`. - * @param the datatype of the Operands, must be either TFloat32 or - * TFloat64 - * @return A `Operand` with the same type as `a`. - * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type - */ - @Endpoint(name = "tensordot") - public static Operand tensordot( - Scope scope, Operand a, Operand b, Operand axes) { - - Operand[] abAxis = tensordotAxes(scope, a, axes); - Operand aAxis = abAxis[0]; - Operand bAxis = abAxis[1]; - - return tensordot(scope, a, b, aAxis, bAxis); - } - - /** - * Tensor contraction of a and b along specified axes and outer product. - *

- * Tensordot (also known as tensor contraction) sums the product of elements - * from a and b over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length - * and consist of unique integers that specify valid axes for each of the - * tensors. Additionally, the outer product is supported by passing - * axes=0. - *

- * This operation corresponds to numpy.tensordot(a, b, axes). - *

- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. - *

- * Example 2: When a and b are matrices (order 2), - * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. - *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order - * 4. - *

- * Example 4: Suppose that a_ijk and b_lmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor - * c_jklm whose entry corresponding to the indices - * (j,k,l,m) is given by: - *


- * c_jklm = Σ_i a_ijk b_lmi. - *

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). - *

- * - * @param scope the TensorFlow Scope - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. - * @param axes the first and second row - * contain the set of unique integers specifying axes along which the - * contraction is computed, for `a` and `b`, respectively. The number of - * axes for `a` and `b` must be equal. I - * @param the datatype of the Operands, must be either TFloat32 or - * TFloat64 - * @return A `Operand` with the same type as `a`. - * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type - */ - @Endpoint(name = "tensordot") - public static Operand tensordot( - Scope scope, Operand a, Operand b, int[] axes) { - - Operand[] abAxis = tensordotAxes(scope, a, axes); - Operand aAxis = abAxis[0]; - Operand bAxis = abAxis[1]; - - return tensordot(scope, a, b, aAxis, bAxis); - } - - /** - * Tensor contraction of a and b along specified axes and outer product. - *

- * Tensordot (also known as tensor contraction) sums the product of elements - * from a and b over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length - * and consist of unique integers that specify valid axes for each of the - * tensors. Additionally, the outer product is supported by passing - * axes=0. - *

- * This operation corresponds to numpy.tensordot(a, b, axes). - *

- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. - *

- * Example 2: When a and b are matrices (order 2), - * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. - *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order - * 4. - *

- * Example 4: Suppose that a_ijk and b_lmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor - * c_jklm whose entry corresponding to the indices - * (j,k,l,m) is given by: - *


- * c_jklm = Σ_i a_ijk b_lmi. - *

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). - *

- * - * @param scope the TensorFlow Scope - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. - * @param axes the first and second row - * contain the set of unique integers specifying axes along which the - * contraction is computed, for `a` and `b`, respectively. The number of - * axes for `a` and `b` must be equal. I - * @param the datatype of the Operands, must be either TFloat32 or - * TFloat64 - * @return A `Operand` with the same type as `a`. - * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type - */ - @Endpoint(name = "tensordot") - public static Operand tensordot( - Scope scope, Operand a, Operand b, int[][] axes) { - - Operand[] abAxis = tensordotAxes(scope, a, axes); - Operand aAxis = abAxis[0]; - Operand bAxis = abAxis[1]; - - return tensordot(scope, a, b, aAxis, bAxis); - } - - /** - * Tensor contraction of a and b along specified axes and outer product. - *

- * Tensordot (also known as tensor contraction) sums the product of elements - * from a and b over the indices specified by - * a_axes and b_axes. The lists - * a_axes and b_axes specify those pairs of axes - * along which to contract the tensors. The axis a_axes[i] of - * a must have the same dimension as axis - * b_axes[i] of b for all i in - * range(0, len(a_axes)). The lists - * a_axes and b_axes must have identical length - * and consist of unique integers that specify valid axes for each of the - * tensors. Additionally, the outer product is supported by passing - * axes=0. - *

- * This operation corresponds to numpy.tensordot(a, b, axes). - *

- * Example 1: When a and b are matrices (order 2), - * the case axes = 1 is equivalent to matrix multiplication. - *

- * Example 2: When a and b are matrices (order 2), - * the case - * axes = [[1], [0]] is equivalent to matrix multiplication. - *

- * Example 3: When a and b are matrices (order 2), - * the case axes=0 gives the outer product, a tensor of order - * 4. - *

- * Example 4: Suppose that a_ijk and b_lmn - * represent two tensors of order 3. Then, contract(a, b, [[0], [2]]) is the order 4 tensor - * c_jklm whose entry corresponding to the indices - * (j,k,l,m) is given by: - *


- * c_jklm = Σ_i a_ijk b_lmi. - *

- * In general, order(c) = order(a) + order(b) - 2*len(axes[0]). - *

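 The overload below does the real work: both operands are transposed and flattened to matrices, multiplied once, and the product is reshaped back to the free dimensions. A plain-Java sketch of that strategy for one concrete case (illustrative only; shapes and the method name are assumptions): // a: [d0][d1][k] contracted with b: [k][n] over a's last / b's first axis. static double[][][] contractLast(double[][][] a, double[][] b) { int d0 = a.length, d1 = a[0].length, k = a[0][0].length, n = b[0].length; double[][] flat = new double[d0 * d1][]; // flatten a to [d0*d1, k] for (int i = 0; i < d0; i++) for (int j = 0; j < d1; j++) flat[i * d1 + j] = a[i][j]; double[][] prod = new double[d0 * d1][n]; // single matmul: [d0*d1, k] x [k, n] for (int r = 0; r < d0 * d1; r++) for (int c = 0; c < n; c++) for (int p = 0; p < k; p++) prod[r][c] += flat[r][p] * b[p][c]; double[][][] out = new double[d0][d1][]; // reshape back to [d0, d1, n] for (int i = 0; i < d0; i++) for (int j = 0; j < d1; j++) out[i][j] = prod[i * d1 + j]; return out; } 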
- * - * @param scope the TensorFlow Scope - * @param a `Operand` of type `float32` or `float64`. - * @param b `Operand` with the same type as `a`. - * @param aAxis axes for the a Operand - * @param bAxis axes for the b Operand - * @param the datatype of the Operands, must be either TFloat32 or - * TFloat64 - * @return A `Operand` with the same type as `a`. - * @throws IllegalArgumentException if a is not a float32 or float64 data type and if a and b are not the same data type - */ - @SuppressWarnings({"unchecked", "unused"}) - @Endpoint(name = "tensordot") - public static Operand tensordot( - Scope scope, Operand a, Operand b, Operand aAxis, Operand bAxis) { - - if (a.type().equals(TBfloat16.class) || a.type().equals(TFloat16.class)) { - throw new IllegalArgumentException( - String.format( - "Operand 'a' must be either TFloat32 or TFloat64 DataType, 'a' is a %s DataType", - a.type().getSimpleName())); - } - if (!a.type().equals(b.type())) { - throw new IllegalArgumentException( - String.format( - "Operands a and b must be the same data type, a is %s DataType, b is %s DataType", - a.type().getSimpleName(), b.type().getSimpleName())); - } - - // first result is Operand, second result is Operand, third result is long[] and it is - // ignored here. - Object[] aResult = tensordotReshape(scope, a, aAxis, false); - Operand reshapedA = (Operand) aResult[0]; - Operand aFreeDims = (Operand) aResult[1]; - long[] aFreeDimsStatic = (long[]) aResult[2]; - - // first result is Operand, second result is Operand, third result is long[] and it is - // ignored here. - Object[] bResult = tensordotReshape(scope, b, bAxis, true); - Operand reshapedB = (Operand) bResult[0]; - Operand bFreeDims = (Operand) bResult[1]; - long[] bFreeDimsStatic = (long[]) bResult[2]; - - Operand abMatmul = MatMul.matmul(scope, reshapedA, reshapedB); - long[] abDimsStatic = new long[aFreeDimsStatic.length + bFreeDimsStatic.length]; - System.arraycopy(aFreeDimsStatic, 0, abDimsStatic, 0, aFreeDimsStatic.length); - System.arraycopy( - bFreeDimsStatic, 0, abDimsStatic, aFreeDimsStatic.length, bFreeDimsStatic.length); - if (!abMatmul.shape().hasUnknownDimension() - && abMatmul.shape().equals(Shape.of(abDimsStatic))) { - return abMatmul; - } else { - return Reshape.create(scope, abMatmul, Constant.vectorOf(scope, abDimsStatic)); - } - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java deleted file mode 100644 index ade4cbb9166..00000000000 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java +++ /dev/null @@ -1,119 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
 -=======================================================================*/ -package org.tensorflow.framework.op.nn; - -import org.tensorflow.Operand; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; -import org.tensorflow.op.core.*; -import org.tensorflow.op.linalg.Transpose; -import org.tensorflow.op.math.Add; -import org.tensorflow.op.math.Sub; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.family.TFloating; -import org.tensorflow.types.family.TNumber; - -import java.util.Arrays; - -/** - * Higher level operation for Softmax. This class will move the desired axis to the last axis, if - * necessary, before calling the low level tf.nn.softmax method. - */ - -/* TODO this is a higher level of abstraction from the low level softmax -this method is defined as tf.nn.softmax in TF Python -Should this be defined here? */ -@Operator(group = "nn") -public class Softmax { - - /** - * Calculates a Softmax operation. If the axis is not the last dimension, then the input axis is - * moved to the last axis before calling tf.nn.softmax, then restored before returning. - * - * @param scope The TensorFlow scope - * @param input the input - * @param axis the axis - * @return the softmax of the input for the specified axis. - * @throws IllegalArgumentException if axis is not in the range [-rank, rank], exclusive - * @param the data type for the input and result - */ - @Endpoint(name = "softmax") - public static Operand softmax(Scope scope, Operand input, int axis) { - Shape shape = input.shape(); - boolean isLastDim = axis == -1 || axis == shape.numDimensions() - 1; - if (isLastDim) { - return org.tensorflow.op.nn.Softmax.create(scope, input); - } - - if (axis <= -shape.numDimensions() || axis >= shape.numDimensions()) { - throw new IllegalArgumentException( - String.format( - "Axis (%d) must be in the range [%d, %d] where %d is the number of dimensions in the input.", - axis, -shape.numDimensions(), shape.numDimensions(), shape.numDimensions())); - } - - int dim = Math.floorMod(axis, shape.numDimensions()); - Operand rank = Rank.create(scope, input); - Operand dimOp = Constant.scalarOf(scope, dim); - Operand one = Constant.scalarOf(scope, 1); - Operand lastIndex = Sub.create(scope, rank, one); - Operand swappedInputs = swapAxis(scope, input, dimOp, lastIndex); - Operand output = org.tensorflow.op.nn.Softmax.create(scope, swappedInputs); - return fixOutput(scope, output, shape, dimOp, lastIndex); - } - - /** - * Restores the specified axis, then reshapes the input to the provided shape. - * - * @param scope The TensorFlow scope - * @param output the output - * @param shape the desired shape - * @param dim the dimension to move - * @return the restored output based on the dimension and shape. 
- */ - private static Operand fixOutput( - Scope scope, Operand output, Shape shape, Operand dim, Operand lastIndex) { - - Operand result = swapAxis(scope, output, dim, lastIndex); - return Reshape.create(scope, result, Constant.tensorOf(scope, shape)); - } - - /** - * Moves the specified Axis to the last axis - * - * @param input the input - * @param dim the dimension to move - * @param lastIndex the last dimension - * @return input with the dimension swapped to the last dimension - */ - private static Operand swapAxis( - Scope scope, Operand input, Operand dim, Operand lastIndex) { - - Operand zero = Constant.scalarOf(scope, 0); - Operand one = Constant.scalarOf(scope, 1); - return Transpose.create( - scope, - input, - Concat.create( - scope, - Arrays.asList( - Range.create(scope, zero, dim, one), - Range.create(scope, Add.create(scope, dim, one), lastIndex, one), - dim), - zero)); - } -} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java new file mode 100644 index 00000000000..7d59941f27a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/SoftmaxCrossEntropyWithLogits.java @@ -0,0 +1,226 @@ +package org.tensorflow.framework.op.nn; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.core.Concat; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Range; +import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.core.Slice; +import org.tensorflow.op.dtypes.Cast; +import org.tensorflow.op.linalg.Transpose; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TBfloat16; +import org.tensorflow.types.TFloat16; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import java.util.Arrays; +import java.util.List; + +// @Operator(group = "nn") +public class SoftmaxCrossEntropyWithLogits { + + /** + * Computes softmax cross entropy between logits and labels. + * + *

Measures the probability error in discrete classification tasks in which the classes are + * mutually exclusive (each entry is in exactly one class). For example, each CIFAR-10 image is + * labeled with one and only one label: an image can be a dog or a truck, but not both. + * + *

NOTE: + * + *

While the classes are mutually exclusive, their probabilities need not be. All that is + * required is that each row of labels is a valid probability distribution. If they + * are not, the computation of the gradient will be incorrect. + * + *

If using exclusive labels (wherein one and only one class is true at a time), + * see {@link org.tensorflow.op.NnOps#sparseSoftmaxCrossEntropyWithLogits} + * + *

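 For the exclusive-label case referenced above, a minimal sketch of the sparse variant (illustrative only; assumes the generated raw-op argument order with features first, and integer class ids as labels): // Each row of logits gets a single int class id instead of a one-hot row. Operand<TFloat32> logits = tf.constant(new float[][] {{4f, 2f, 1f}, {0f, 5f, 1f}}); Operand<TInt32> classIds = tf.constant(new int[] {0, 1}); // one id per row Operand<TFloat32> loss = tf.nn.sparseSoftmaxCrossEntropyWithLogits(logits, classIds).loss(); 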
Usage: + * + *

+   *   Operand<TFloat32> logits =
+   *       tf.constant(new float[][] {{4.0F, 2.0F, 1.0F}, {0.0F, 5.0F, 1.0F}} );
+   *   Operand<TFloat32> labels =
+   *       tf.constant(new float[][] {{1.0F, 0.0F, 0.0F}, {0.0F, 0.8F, 0.2F}} );
+   *   Operand<TFloat32> output =
+   *       tf.nn.softmaxCrossEntropyWithLogits(labels, logits, -1);
+   *   // output Shape = [2]
+   *   // dataType = FLOAT (1)
+   *   // values { 0.169846, 0.824745 }
+   * 
+ * + *

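 The quoted output values can be reproduced outside of TensorFlow in a few lines of plain Java; a sketch (illustrative only) using the same max-shift stabilization discussed for reduceLogSumExp: static double xent(double[] labels, double[] logits) { double max = Double.NEGATIVE_INFINITY; for (double l : logits) max = Math.max(max, l); // shift for numerical stability double sum = 0.0; for (double l : logits) sum += Math.exp(l - max); double loss = 0.0; for (int i = 0; i < logits.length; i++) { double logP = (logits[i] - max) - Math.log(sum); // log softmax loss -= labels[i] * logP; // cross entropy } return loss; } // xent({1, 0, 0}, {4, 2, 1}) ≈ 0.169846 // xent({0, 0.8, 0.2}, {0, 5, 1}) ≈ 0.824745 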
Backpropagation will happen into both logits and labels. To + * disallow backpropagation into labels, pass label tensors through + * tf.stopGradient before feeding it to this function. + * + * @param scope current scope + * @param labels Each vector along the class dimension should hold a valid probability + * distribution e.g. for the case in which labels are of shape [batch_size, num_classes] + * , each row of labels[i] must be a valid probability distribution. + * @param logits Per-label activations, typically a linear output. These activation energies are + * interpreted as unnormalized log probabilities. + * @param axis The class dimension. -1 is the last dimension. + * @param the data type for the logits and return operand + * @param the data type for the labels + * @return the softmax cross entropy loss. Its type is the same as logits and its + * shape is the same as labels except that it does not have the last dimension of + * labels. + */ + @SuppressWarnings("unchecked") + @Endpoint(name = "softmaxCrossEntropyWithLogits") + public static Operand softmaxCrossEntropyWithLogits( + Scope scope, Operand labels, Operand logits, int axis) { + scope = scope.withSubScope("SoftmaxCrossEntropyWithLogits"); + axis = axis % logits.shape().numDimensions(); + if (axis < 0) { + axis += logits.shape().numDimensions(); + } + + if (logits.asOutput().type() == TFloat16.class || logits.asOutput().type() == TBfloat16.class) { + Operand result = + softmaxCrossEntropyWithLogits( + scope, + Cast.create(scope, labels, TFloat32.class), + Cast.create(scope, logits, TFloat32.class), + axis); + return Cast.create(scope, result, logits.asOutput().type()); + } + + if (logits.asOutput().type() != labels.asOutput().type()) { + return softmaxCrossEntropyWithLogits( + scope, Cast.create(scope, labels, logits.asOutput().type()), logits, axis); + } + + Operand inputRank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); + Shape shape = logits.shape(); + + // Move the dim to the end if dim is not the last dimension. + if (axis != -1 && axis != logits.shape().numDimensions() - 1) { + logits = moveDimToEnd(scope, logits, axis, inputRank); + labels = moveDimToEnd(scope, labels, axis, inputRank); + } + + Operand tLabels; + if (labels.type() != logits.type()) { + tLabels = Cast.create(scope, labels, logits.type()); + } else { + // Unchecked warning checked in if statement. + tLabels = (Operand) labels; + } + + Shape inputShape = logits.shape(); + logits = flattenOuterDims(scope, logits); + tLabels = flattenOuterDims(scope, tLabels); + + org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits smax = + org.tensorflow.op.nn.SoftmaxCrossEntropyWithLogits.create(scope, logits, tLabels); + /* cannot use generic on cost, because cost may be recast later. 
*/ + Operand cost = smax.loss(); + Operand outputShape = + Slice.create( + scope, + Constant.tensorOf(scope, inputShape), + Constant.arrayOf(scope, 0L), + Constant.arrayOf(scope, inputShape.numDimensions() - 1L)); + cost = Reshape.create(scope, cost, outputShape); + if (scope.env().isGraph() && !shape.hasUnknownDimension()) { + long[] array = shape.asArray(); + if (array == null) { + array = new long[0]; + } + long[] newArray = new long[array.length - 1]; + if (axis < 0) { + axis = shape.numDimensions() + axis; + } + for (int i = 0; i < axis; i++) { + newArray[i] = shape.size(i); + } + for (int i = axis + 1; i < shape.numDimensions(); i++) { + newArray[i - 1] = shape.size(i); + } + cost = Reshape.create(scope, cost, Constant.vectorOf(scope, newArray)); + } + + return cost; + } + + /** + * Flattens logits' outer dimensions and keep its last dimension. + * + * @param scope the TensorFlow scope + * @param logits the logits + * @param the type of logits + * @return the flattened logits + */ + private static Operand flattenOuterDims(Scope scope, Operand logits) { + Operand one = Constant.scalarOf(scope, 1L); + + Shape shape = logits.shape(); + int ndims = shape.numDimensions(); + if (!shape.hasUnknownDimension()) { + long product = 1L; + boolean productValid = true; + for (int i = ndims - 2; i >= 0; i--) { + long d = shape.size(i); + if (d == Shape.UNKNOWN_SIZE) { + productValid = false; + break; + } + product *= d; + } + if (productValid) { + return Reshape.create(scope, logits, Constant.arrayOf(scope, product, shape.size(-1))); + } + } + + Operand rank = Cast.create(scope, Rank.create(scope, logits), TInt64.class); + Operand rankMinusOne = Sub.create(scope, rank, one); + + Operand lastDimSize = + Slice.create( + scope, + org.tensorflow.op.core.Shape.create(scope, logits, TInt64.class), + rankMinusOne, + one); + Operand concat = + Concat.create( + scope, + Arrays.asList(Constant.arrayOf(scope, -1L), lastDimSize), + Constant.scalarOf(scope, 0)); + return Reshape.create(scope, logits, concat); + } + + /** + * Move the dim to the end if dimIndex is not the last dimension. + * + * @param scope The TensorFlow Scope + * @param input the input to reshape + * @param dimIndex the index to move + * @param rank the number of Dimensions in the tensor + * @param the data type of the tensor. 
+ * @param the data type of the rank + * @return the reshaped input + */ + private static Operand moveDimToEnd( + Scope scope, Operand input, int dimIndex, Operand rank) { + Class rankType = rank.asOutput().type(); + Operand one = Cast.create(scope, Constant.scalarOf(scope, 1), rankType); + List> concatList = + Arrays.asList( + Range.create( + scope, Cast.create(scope, Constant.scalarOf(scope, dimIndex), rankType), one, one), + Range.create( + scope, + Cast.create(scope, Constant.scalarOf(scope, dimIndex + 1), rankType), + one, + one)); + return Transpose.create( + scope, input, Concat.create(scope, concatList, Constant.scalarOf(scope, 0))); + } +} From 056d3ecb0db4287be85d1d844c4ae1e08ece69cb Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 6 May 2021 13:44:06 -0400 Subject: [PATCH 27/31] Initial checkin --- .../annotations/org/tensorflow/op/Ops.java | 12 +- .../framework/layers/Activation.java | 12 +- .../org/tensorflow/framework/layers/Add.java | 3 +- .../framework/layers/AlphaDropout.java | 9 +- .../tensorflow/framework/layers/Average.java | 2 + .../framework/layers/Concatenate.java | 1 - .../tensorflow/framework/layers/Dense.java | 126 +- .../org/tensorflow/framework/layers/Dot.java | 57 +- .../tensorflow/framework/layers/Dropout.java | 8 +- .../org/tensorflow/framework/layers/ELU.java | 9 +- .../tensorflow/framework/layers/Flatten.java | 27 +- .../framework/layers/GaussianDropout.java | 9 +- .../framework/layers/GaussianNoise.java | 5 +- .../tensorflow/framework/layers/Input.java | 71 +- .../tensorflow/framework/layers/Lambda.java | 38 +- .../tensorflow/framework/layers/Layer.java | 133 +- .../framework/layers/LeakyReLU.java | 12 +- .../tensorflow/framework/layers/Maximum.java | 2 + .../tensorflow/framework/layers/Minimum.java | 4 +- .../tensorflow/framework/layers/Multiply.java | 2 + .../org/tensorflow/framework/layers/ReLU.java | 36 +- .../framework/layers/RepeatVector.java | 8 +- .../tensorflow/framework/layers/Softmax.java | 10 +- .../tensorflow/framework/layers/Subtract.java | 3 +- .../framework/layers/ThresholdedReLU.java | 9 +- .../framework/layers/impl/InputSpec.java | 15 + .../framework/layers/impl/Merge.java | 2 +- .../framework/layers/impl/VariableDef.java | 85 +- .../org/tensorflow/framework/op/NnOps.java | 16 + .../tensorflow/framework/op/nn/Softmax.java | 118 ++ .../tensorflow/framework/layers/AddTest.java | 15 +- .../framework/layers/AverageTest.java | 4 +- .../framework/layers/ConcatenateTest.java | 8 +- .../framework/layers/DenseTest.java | 86 +- .../tensorflow/framework/layers/DotTest.java | 35 +- .../tensorflow/framework/layers/ELUTest.java | 96 +- .../framework/layers/GaussianDropoutTest.java | 1 - .../framework/layers/GaussianNoiseTest.java | 6 +- .../framework/layers/InputTest.java | 4 +- .../framework/layers/LambdaTest.java | 2 +- .../framework/layers/LeakyReLUTest.java | 9 +- .../framework/layers/MaximumTest.java | 4 +- .../framework/layers/MinimumTest.java | 4 +- .../framework/layers/MultiplyTest.java | 9 +- .../tensorflow/framework/layers/ReLUTest.java | 8 +- .../framework/layers/RepeatVectorTest.java | 6 +- .../framework/layers/ReshapeTest.java | 215 ++-- .../framework/layers/SubtractTest.java | 15 +- .../framework/layers/impl/InputSpecTest.java | 60 +- .../framework/layers/impl/TensorDotTest.java | 186 --- .../framework/op/LinalgOpsTest.java | 2 +- .../tensorflow/framework/op/SetOpsTest.java | 3 +- .../framework/utils/EagerTestSession.java | 687 +--------- .../framework/utils/GraphTestSession.java | 1103 ++--------------- 
.../org/tensorflow/framework/utils/ND.java | 6 +- .../framework/utils/QuadConsumer.java | 42 + .../framework/utils/TestSession.java | 865 +++++++++++-- .../framework/utils/TriConsumer.java | 40 + 58 files changed, 1849 insertions(+), 2516 deletions(-) create mode 100644 tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java delete mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/TensorDotTest.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/utils/QuadConsumer.java create mode 100644 tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TriConsumer.java diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 733e7ca7051..250ea35b9fa 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -354,20 +354,20 @@ public final class Ops { public final SparseOps sparse; - public final TpuOps tpu; - public final BitwiseOps bitwise; + public final TpuOps tpu; + public final MathOps math; public final AudioOps audio; public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -385,13 +385,13 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - tpu = new TpuOps(this); bitwise = new BitwiseOps(this); + tpu = new TpuOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java index 5698e4766a2..0bde8e0889c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Activation.java @@ -42,9 +42,7 @@ public class Activation extends Layer { * @param type the data type for the weights and computation */ public Activation( - Ops tf, - org.tensorflow.framework.activations.Activation activation, - Class type) { + Ops tf, org.tensorflow.framework.activations.Activation activation, Class type) { this(tf, null, activation, type, null); } @@ -74,10 +72,10 @@ public Activation( * @param type the data type for the weights and computation */ public Activation( - Ops tf, - String name, - org.tensorflow.framework.activations.Activation activation, - Class type) { + Ops tf, + String name, + org.tensorflow.framework.activations.Activation activation, + Class type) { this(tf, name, activation, type, null); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java index 5a7c0ce65e3..02979c02942 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Add.java @@ -35,7 +35,6 @@ */ public class Add extends Merge { - /** * Creates an Add Layer using {@link Class#getSimpleName()} as the layer name. 
* @@ -51,6 +50,7 @@ public Add(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param type the data type for the weights and computation + * @param options the layer options */ public Add(Ops tf, Class type, Options options) { this(tf, null, type, options); @@ -72,6 +72,7 @@ public Add(Ops tf, String name, Class type) { * @param tf the TensorFlow Ops * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} * @param type the data type for the weights and computation + * @param options the layer options */ public Add(Ops tf, String name, Class type, Options options) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java index 3c3f723ecf7..b8f50991f43 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/AlphaDropout.java @@ -46,8 +46,7 @@ public class AlphaDropout extends Layer { * Creates a AlphaDropout layer, using a unique name will be generated based on {@link * Class#getSimpleName()} and no noiseShape. * - * @param tf the TensorFlow Ops, may be null but will need to be set before the first call to the - * {@link #call} method method is called. + * @param tf the TensorFlow Ops * @param rate A number between 0 and 1. Drop probability (as with {@link Dropout}). The * multiplicative noise will have standard deviation sqrt(rate / (1 - rate)). * @param seed the seed for random number generation. An initializer created with a given seed @@ -64,6 +63,7 @@ public AlphaDropout(Ops tf, float rate, long seed, Class type, Options option * Creates a AlphaDropout layer, using a unique name will be generated based on {@link * Class#getSimpleName()}. * + * @param tf the TensorFlow Ops * @param rate A number between 0 and 1. Drop probability (as with {@link Dropout}). The * multiplicative noise will have standard deviation sqrt(rate / (1 - rate)). * @param noiseShape Optional, 1D integer tensor representing the shape of the binary dropout mask @@ -83,8 +83,7 @@ public AlphaDropout( /** * Creates a AlphaDropout layer * - * @param tf the TensorFlow Ops, may be null but will need to be set before the first call to the - * {@link #call} method method is called. + * @param tf the TensorFlow Ops * @param name name the unique name for this layer. If null, a unique name will be generated based * on {@link Class#getSimpleName()}. * @param rate A number between 0 and 1. Drop probability (as with {@link Dropout}). 
The @@ -164,7 +163,7 @@ public List> call( outputs.add(result); } - return callPostProcess(convertTo(outputs, resultType), training); + return callPostProcess(convertTo(outputs, resultType), true); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java index 5dd116aa38e..7e31e662258 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Average.java @@ -50,6 +50,7 @@ public Average(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param type the data type for the weights and computation + * @param options the layer's options */ public Average(Ops tf, Class type, Options options) { this(tf, null, type, options); @@ -72,6 +73,7 @@ public Average(Ops tf, String name, Class type) { * @param tf the TensorFlow Ops * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} * @param type the data type for the weights and computation + * @param options the layer's options */ public Average(Ops tf, String name, Class type, Options options) { super(tf, name, type, options); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java index 3243f8f5f17..8af1f9a1c18 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Concatenate.java @@ -153,7 +153,6 @@ public Concatenate(Ops tf, Class type, Options options) { this(tf, null, DEFAULT_AXIS, type, options); } - /** * Creates a Concatenate Layer using {@link Class#getSimpleName()} as the layer name. 
* diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java index d433c7bcb86..77b0219dc45 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dense.java @@ -16,13 +16,14 @@ import org.tensorflow.Operand; import org.tensorflow.framework.activations.Activation; -import org.tensorflow.framework.constraints.Constraint; import org.tensorflow.framework.initializers.Glorot; import org.tensorflow.framework.initializers.Initializer; import org.tensorflow.framework.initializers.VarianceScaling; import org.tensorflow.framework.initializers.Zeros; import org.tensorflow.framework.layers.impl.InputSpec; -import org.tensorflow.framework.op.math.TensorDot; +import org.tensorflow.framework.layers.impl.VariableDef; +import org.tensorflow.framework.op.FrameworkOps; +import org.tensorflow.framework.regularizers.Regularizer; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Variable; @@ -32,6 +33,7 @@ import java.util.Collections; import java.util.List; +import java.util.function.UnaryOperator; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -61,8 +63,12 @@ public class Dense extends Layer { private final Activation activation; private final boolean useBias; private final long seed; - private final Constraint kernelConstraint; - private final Constraint biasConstraint; + + private final UnaryOperator> kernelConstraint; + private final UnaryOperator> biasConstraint; + private final Regularizer biasRegularizer; + private final Regularizer kernelRegularizer; + private Initializer kernelInitializer; private Initializer biasInitializer; private Variable kernel; @@ -78,7 +84,7 @@ public class Dense extends Layer { * @param type the data type for the weights and computation */ public Dense(Ops tf, Integer units, long seed, Class type) { - this(tf, null, units, null, true, null, null, null, null, seed, type, null); + this(tf, null, units, null, true, null, null, null, null, null, null, null, seed, type, null); } /** @@ -92,7 +98,8 @@ public Dense(Ops tf, Integer units, long seed, Class type) { * @param options the layer's options. */ public Dense(Ops tf, Integer units, long seed, Class type, Options options) { - this(tf, null, units, null, true, null, null, null, null, seed, type, options); + this( + tf, null, units, null, true, null, null, null, null, null, null, null, seed, type, options); } /** @@ -107,7 +114,7 @@ public Dense(Ops tf, Integer units, long seed, Class type, Options options) { * @param type the data type for the weights and computation */ public Dense(Ops tf, String name, Integer units, long seed, Class type) { - this(tf, name, units, null, true, null, null, null, null, seed, type, null); + this(tf, name, units, null, true, null, null, null, null, null, null, null, seed, type, null); } /** @@ -123,7 +130,8 @@ public Dense(Ops tf, String name, Integer units, long seed, Class type) { * @param options the layer's options. 
*/ public Dense(Ops tf, String name, Integer units, long seed, Class type, Options options) { - this(tf, name, units, null, true, null, null, null, null, seed, type, options); + this( + tf, name, units, null, true, null, null, null, null, null, null, null, seed, type, options); } /** @@ -138,24 +146,47 @@ public Dense(Ops tf, String name, Integer units, long seed, Class type, Optio * @param useBias whether the layer uses a bias vector. * @param kernelInitializer Initializer for the kernel weights matrix. * @param biasInitializer Initializer for the bias vector. + * @param kernelRegularizer Regularizer applied to the kernel weights matrix. + * @param biasRegularizer Regularizer function applied to the bias vector. + * @param activityRegularizer Regularizer function applied to the output of the layer (its + * "activation"). + * @param kernelConstraint a constraint on the kernel variable + * @param biasConstraint a constraint on the bias variable * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the weights and computation */ - @SuppressWarnings("unchecked") public Dense( - Ops tf, - String name, - Integer units, - Activation activation, - boolean useBias, - Initializer kernelInitializer, - Initializer biasInitializer, - Constraint kernelConstraint, - Constraint biasConstraint, - long seed, - Class type) { - this(tf, name, units, activation, useBias, kernelInitializer, biasInitializer, kernelConstraint, biasConstraint, seed, type, null); + Ops tf, + String name, + Integer units, + Activation activation, + boolean useBias, + Initializer kernelInitializer, + Initializer biasInitializer, + Regularizer kernelRegularizer, + Regularizer biasRegularizer, + Regularizer activityRegularizer, + UnaryOperator> kernelConstraint, + UnaryOperator> biasConstraint, + long seed, + Class type) { + this( + tf, + name, + units, + activation, + useBias, + kernelInitializer, + biasInitializer, + kernelRegularizer, + biasRegularizer, + activityRegularizer, + kernelConstraint, + biasConstraint, + seed, + type, + null); } /** * Creates a Dense layer. @@ -169,6 +200,12 @@ public Dense( * @param useBias whether the layer uses a bias vector. * @param kernelInitializer Initializer for the kernel weights matrix. * @param biasInitializer Initializer for the bias vector. + * @param kernelRegularizer Regularizer applied to the kernel weights matrix. + * @param biasRegularizer Regularizer function applied to the bias vector. + * @param activityRegularizer Regularizer function applied to the output of the layer (its + * "activation"). + * @param kernelConstraint a constraint on the kernel variable + * @param biasConstraint a constraint on the bias variable * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and data type. * @param type the data type for the weights and computation @@ -183,8 +220,11 @@ public Dense( boolean useBias, Initializer kernelInitializer, Initializer biasInitializer, - Constraint kernelConstraint, - Constraint biasConstraint, + Regularizer kernelRegularizer, + Regularizer biasRegularizer, + Regularizer activityRegularizer, + UnaryOperator> kernelConstraint, + UnaryOperator> biasConstraint, long seed, Class type, Options options) { @@ -200,6 +240,9 @@ public Dense( this.biasInitializer = biasInitializer != null ? 
 biasInitializer : new Zeros<>(tf); this.kernelConstraint = kernelConstraint; this.biasConstraint = biasConstraint; + this.biasRegularizer = biasRegularizer; + this.kernelRegularizer = kernelRegularizer; + setActivityRegularizer(activityRegularizer); this.seed = seed; addInputSpec(new InputSpec(InputSpec.Options.create().minRank(2))); setSupportsMasking(true); @@ -216,7 +259,8 @@ public Dense( * @param resultType the data type for the result * @param the data type for the result * @return the output with shape {@code (batch_size, ..., units)}. For instance, for a 2D input - with shape {@code (batch_size, input_dim)}, the output would have shape {@code (batch_size, units)}. + with shape {@code (batch_size, input_dim)}, the output would have shape {@code (batch_size, + units)}. */ @Override public List> call( @@ -228,21 +272,25 @@ public List> call( throw new IllegalArgumentException("Dense only supports 1 input."); Operand singleInput = inputs.get(0); Operand input = cast(getTF(), singleInput, getType()); - System.out.println("Dense.call: " + input.shape()); if (!isBuilt()) build(input.shape()); Shape inputShape = input.shape(); int rank = inputShape.numDimensions(); Operand tOutput; - System.out.println("Dense input: " + inputShape); if (rank == 2 || rank == Shape.UNKNOWN_SIZE) { tOutput = getTF().linalg.matMul(input, getKernel()); } else { - tOutput = TensorDot.tensordot(getTF().scope(), input, getKernel(), new int[] {rank - 1, 0}); + FrameworkOps fops = FrameworkOps.create(getTF()); + tOutput = fops.math.tensordot(input, getKernel(), new int[] {rank - 1, 0}); // Reshape the output back to the original number of dimensions of the input. Shape newShape = inputShape.take(rank - 1).append(getUnits()); tOutput = getTF().reshape(tOutput, getTF().constant(newShape)); } - if (isUseBias()) tOutput = getTF().nn.biasAdd(tOutput, getBias()); + if (isUseBias()) { + tOutput = getTF().nn.biasAdd(tOutput, getBias()); + } + if (activation != null) { + tOutput = activation.call(tOutput); + } return callPostProcess(Collections.singletonList(cast(getTF(), tOutput, resultType)), training); } @@ -263,13 +311,12 @@ public void build(List inputShapes) { if (kernelInitializer == null) { // Cast is required because Glorot is TFloating. kernelInitializer = new Glorot<>(getTF(), VarianceScaling.Distribution.UNIFORM, getSeed()); - } + } if (biasInitializer == null) { biasInitializer = new Zeros<>(getTF()); } Shape inputShape = inputShapes.get(0); - System.out.println("dense.build: " + inputShape); if (inputShape.size(-1) == Shape.UNKNOWN_SIZE) { throw new IllegalArgumentException( "The last dimension of the inputs to `Dense` should be defined. 
Found `UNKNOWN`."); @@ -279,23 +326,34 @@ public void build(List inputShapes) { kernel = addWeight( - getName() + "_kernel", + "kernel", Shape.of(lastDim, this.getUnits()), kernelInitializer, kernelConstraint, + kernelRegularizer, true, getSeed()); if (isUseBias()) bias = addWeight( - getName() + "_bias", + "bias", Shape.of(this.getUnits()), biasInitializer, biasConstraint, + biasRegularizer, true, getSeed()); } + public Operand applyConstraint(Variable variable) { + VariableDef variableDef = getVariableDef(variable); + if(variableDef != null && variableDef.getConstraint() != null) { + return variableDef.getConstraint().apply(variable); + }else { + return variable; + } + } + /** {@inheritDoc} */ @Override public List computeOutputShape(List inputShapes) { @@ -307,7 +365,7 @@ public List computeOutputShape(List inputShapes) { throw new IllegalArgumentException( String.format( "Dense layer: The innermost dimension of input_shape must be defined, but saw: %s", - singleShape)); + singleShape)); Shape headShape = singleShape.take(singleShape.numDimensions() - 1).append(getUnits()); return Collections.singletonList(headShape); @@ -352,7 +410,7 @@ public Variable getKernel() { /** * Gets the bias variable * - * @return + * @return the bias variable */ public Variable getBias() { return bias; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java index 5ad01445558..bb0353ec3e6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dot.java @@ -16,7 +16,6 @@ import org.tensorflow.Operand; import org.tensorflow.framework.layers.impl.Merge; -import org.tensorflow.framework.losses.Losses; import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; @@ -46,9 +45,6 @@ public class Dot extends Merge { private final int[] axes; private final boolean normalize; - private boolean reshapeRequired; - - /** * Creates a Layer that computes a dot product between samples in two tensors, using {@link * Class#getSimpleName()} as the layer name, and no L2 Normalization. @@ -88,7 +84,6 @@ public Dot(Ops tf, int[] axes, Class type) { this(tf, null, axes, false, type, null); } - /** * Creates a Layer that computes a dot product between samples in two tensors, using {@link * Class#getSimpleName()} as the layer name and no L2 Normalization. @@ -104,7 +99,6 @@ public Dot(Ops tf, int[] axes, Class type, Options options) { this(tf, null, axes, false, type, options); } - /** * Creates a Layer that computes a dot product between samples in two tensors with no L2 * Normalization. 
@@ -251,6 +245,7 @@ protected void build(List inputShapes) { int[] newAxes; if (axes.length == 1) { newAxes = new int[2]; + // convert negative axes if (axes[0] < 0) { newAxes[0] = Math.floorMod(axes[0], shape1.numDimensions()); newAxes[1] = Math.floorMod(axes[0], shape2.numDimensions()); @@ -303,7 +298,7 @@ protected Operand mergeFunction(List> inputs) { if (normalize) { FrameworkOps fops = FrameworkOps.create(tf); input1 = fops.math.l2Normalize(input1, new int[] {axes[0]}); - input2 = fops.math.l2Normalize(input2, new int[] {axes[0]}); + input2 = fops.math.l2Normalize(input2, new int[] {axes[1]}); } return batchDot(input1, input2, newAxes); } @@ -355,7 +350,7 @@ public List computeOutputShape(List inputShapes) { Shape outputShape = shape1.append(shape2); if (outputShape.numDimensions() == 1) { - outputShape.append(1); + outputShape = outputShape.append(1); } return Collections.singletonList(outputShape); } @@ -371,16 +366,18 @@ public List computeOutputShape(List inputShapes) { * * @param x Operand with numdimensions >= 2. * @param y Operand with numdimensions >= 2. - * @param dotAxes the axes to peform the Dot Product. + * @param axes the axes to perform the Dot Product. * @return A operand with shape equal to the concatenation of x's shape (less the * dimension that was summed over) and y's shape (less the batch dimension and * the dimension that was summed over). If the final rank is 1, the result is reshaped to * (batch_size, 1). */ private Operand batchDot( - Operand x, Operand y, int[] dotAxes) { + Operand x, Operand y, int[] axes) { Ops tf = getTF(); FrameworkOps fops = FrameworkOps.create(tf); + // make local copy for changes later + int[] dotAxes = axes; Operand tX = cast(tf, x, getType()); Operand tY = cast(tf, y, getType()); @@ -410,11 +407,10 @@ private Operand batchDot( if (dotAxes == null) { dotAxes = new int[2]; + dotAxes[0] = xRank - 1; if (yRank == 2) { - dotAxes[0] = xRank - 1; dotAxes[1] = yRank - 1; } else { - dotAxes[0] = xRank - 1; dotAxes[1] = yRank - 2; } } else if (dotAxes.length == 1) { @@ -441,9 +437,10 @@ private Operand batchDot( throw new IllegalArgumentException( String.format( "Cannot do batch_dot on inputs with shapes %s and %s with axes %s. x.shape[%d] != %d, y.shape[%d] != %d", - xShape, yShape, Arrays.toString(dotAxes), a0, d1, d2)); + xShape, yShape, Arrays.toString(dotAxes), a0, d1, a1, d2)); } + // back up the original ranks; needed later to restore them. int origXRank = xRank; int origYRank = yRank; if (xRank == 2) { @@ -459,11 +456,12 @@ private Operand batchDot( // move x's dimension to be reduced to last axis. if (a0 != xRank - 1) { int[] pattern = new int[xRank]; + // move a0 to last for (int i = 0; i < a0; i++) { pattern[i] = i; } - for (int i = a0, j = 0; i < xRank; i++) { - pattern[j++] = i; + for (int i = a0; i < xRank - 1; i++) { + pattern[i] = i + 1; } pattern[xRank - 1] = a0; tX = tf.linalg.transpose(tX, tf.constant(pattern)); @@ -471,18 +469,17 @@ // move y's dimension to be reduced to axis 1. if (a1 != 1) { int[] pattern = new int[yRank]; - - for (int i = 0, j = 0; i < xRank; i++) { - if (i == 1) { // leave dim 1 slot open - j++; - continue; - } - if (i == a1) { // skip a1 dim - continue; - } - pattern[j++] = i; + pattern[0] = 0; + // skip slot 1 + for (int i = 1; i < a1; i++) { + pattern[i + 1] = i; + } + for (int i = a1; i < pattern.length - 1; i++) { + pattern[i + 1] = i + 1; } pattern[1] = a1; + //noinspection SuspiciousNameCombination + tY = tf.linalg.transpose(tY, tf.constant(pattern)); } // normalize both inputs to rank 3. 
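To make the axis handling above concrete, here is a rough usage sketch of the Dot layer; the shapes, placeholder setup, and variable names are assumptions for illustration only (the usual imports such as org.tensorflow.op.core.Placeholder and java.util.Arrays are assumed):

    // x: (batch=32, 5, 10) and y: (batch=32, 10, 8); contract x axis 2 against y axis 1.
    Operand<TFloat32> x = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(32, 5, 10)));
    Operand<TFloat32> y = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(32, 10, 8)));
    Dot<TFloat32> dot = new Dot<>(tf, new int[] {2, 1}, TFloat32.class);
    // The batch dimension is kept; the contracted axes (10 and 10) disappear: result shape (32, 5, 8).
    List<Operand<TFloat32>> out = dot.call(Arrays.asList(x, y), null, false, TFloat32.class);

Negative axes are first folded into range with Math.floorMod, so axes = {-1} would contract the last axis of both inputs.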
@@ -490,15 +487,15 @@ private Operand batchDot( Operand xMidShape = null; if (xRank > 3) { org.tensorflow.op.core.Shape tmpShape = tf.shape(tX, TInt64.class); - xMidShape = tf.shape.take(tmpShape, tf.constant((long) (xRank)), TInt64.class); xMidShape = tf.shape.takeLast(tmpShape, tf.constant((long) (xRank - 1)), TInt64.class); Operand squashedShape = tf.stack( Arrays.asList( - tf.shape.size(tmpShape, tf.constant(0l), TInt64.class), + tf.shape.size(tmpShape, tf.constant(0L), TInt64.class), tf.constant(Shape.UNKNOWN_SIZE), tf.shape.size(tmpShape, tf.constant((long) (xRank - 1)), TInt64.class))); + tX = tf.reshape(tX, squashedShape); xSquashed = true; } @@ -515,15 +512,15 @@ private Operand batchDot( tf.shape.size(y, tf.constant(0L), TInt64.class), tf.shape.size(y, tf.constant(1L), TInt64.class), tf.constant(-1L))); + tY = tf.reshape(tY, squashedShape); ySquashed = true; } - Operand result = fops.linalg.matmul(tX, tY); boolean doReshape = false; Operand outputShape = tf.shape(result, TInt64.class); - if (xSquashed && xMidShape != null) { + if (xSquashed) { outputShape = tf.concat( Arrays.asList( @@ -534,7 +531,7 @@ private Operand batchDot( doReshape = true; } - if (ySquashed && yTrailDims != null) { + if (ySquashed) { outputShape = tf.concat( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java index 340ff278ac9..6b9e98feda8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Dropout.java @@ -138,13 +138,7 @@ public Dropout(Ops tf, float rate, Shape noiseShape, long seed, Class type, O * will always produce the same random tensor for a given shape and data type. * @param type the data type for the weights and computation */ - public Dropout( - Ops tf, - String name, - float rate, - Shape noiseShape, - long seed, - Class type) { + public Dropout(Ops tf, String name, float rate, Shape noiseShape, long seed, Class type) { this(tf, name, rate, noiseShape, seed, type, null); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java index b0d453135e6..90dfc6e2d4e 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ELU.java @@ -30,10 +30,10 @@ * *

It follows: * - *

- *     f(x) =  alpha * (exp(x) - 1.) for x < 0
- *     f(x) = x for x >= 0
- * 
+ *
{@code
+ * f(x) = alpha * (exp(x) - 1.) for x < 0
+ * f(x) = x for x >= 0
+ * }
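The piecewise rule above maps directly onto raw ops. A minimal sketch, assuming an Ops instance tf and a TFloat32 operand x (both illustrative), with alpha = 1.0:

    // alpha * (exp(x) - 1) on the negative side, identity elsewhere.
    Operand<TFloat32> elu =
        tf.select(
            tf.math.less(x, tf.constant(0f)),
            tf.math.mul(tf.constant(1.0f), tf.math.sub(tf.math.exp(x), tf.constant(1f))),
            x);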
* * @param the data type for the layer's weights and computation. */ @@ -71,6 +71,7 @@ public ELU(Ops tf, String name, Class type) { * @param tf the TensorFlow Ops. * @param alpha Negative slope coefficient. Must be >= 0. * @param type the data type for the layer's weights and computation. + * @param options the layer's options */ public ELU(Ops tf, float alpha, Class type, Options options) { this(tf, null, alpha, type, options); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java index f9f858d6953..2a07d6f623f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Flatten.java @@ -31,13 +31,12 @@ /** * Flattens the input. Does not affect the batch size. * - *

Note: If inputs are shaped (batch,) without a feature axis, then flattening adds an extra - * channel dimension and output shape is . + *

Note: If inputs are shaped {@code (batch,)} without a feature axis, then flattening + * adds an extra channel dimension and output shape is {@code (batch, 1)}. * * @param the data type for the layer's weights and computation. */ public class Flatten extends Layer { - private static final int FLATTEN_INPUT_LENGTH = 1; private final TensorFormat dataFormat; /** @@ -64,15 +63,13 @@ public Flatten(Ops tf, String name, Class type) { this(tf, name, TensorFormat.NHWC, type, null); } - /** * Creates a Flatten Layer with a unique name generated based on * {@link Class#getSimpleName()}. * * @param tf the TensorFlow Ops. * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC} - * corresponds to inputs with shape (batch, ..., channels) - * while {@link TensorFormat#NCHW} corresponds to inputs with shape - * (batch, channels, ...). + * corresponds to inputs with shape {@code (batch, ..., channels) } while {@link + * TensorFormat#NCHW} corresponds to inputs with shape {@code (batch, channels, ...)}. * @param type the data type for the layer's weights and computation. */ public Flatten(Ops tf, TensorFormat dataFormat, Class type) { @@ -84,10 +81,10 @@ public Flatten(Ops tf, TensorFormat dataFormat, Class type) { * * @param tf the TensorFlow Ops. * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC} - * corresponds to inputs with shape (batch, ..., channels) - * while {@link TensorFormat#NCHW} corresponds to inputs with shape - * (batch, channels, ...). + * corresponds to inputs with shape {@code (batch, ..., channels) } while {@link + * TensorFormat#NCHW} corresponds to inputs with shape {@code (batch, channels, ...)}. * @param type the data type for the layer's weights and computation. + * @param options the layer's options */ public Flatten(Ops tf, TensorFormat dataFormat, Class type, Options options) { this(tf, null, dataFormat, type, options); @@ -100,9 +97,8 @@ public Flatten(Ops tf, TensorFormat dataFormat, Class type, Options options) * @param name the unique name for this layer. If null, a unique name will be generated based on * {@link Class#getSimpleName()}. * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC} - * corresponds to inputs with shape (batch, ..., channels) - * while {@link TensorFormat#NCHW} corresponds to inputs with shape - * (batch, channels, ...). + * corresponds to inputs with shape {@code (batch, ..., channels) } while {@link + * TensorFormat#NCHW} corresponds to inputs with shape {@code (batch, channels, ...)}. * @param type the data type for the layer's weights and computation. */ public Flatten(Ops tf, String name, TensorFormat dataFormat, Class type) { @@ -115,9 +111,8 @@ public Flatten(Ops tf, String name, TensorFormat dataFormat, Class type) { * @param name the unique name for this layer. If null, a unique name will be generated based on * {@link Class#getSimpleName()}. * @param dataFormat The ordering of the dimensions in the inputs. {@link TensorFormat#NHWC} - * corresponds to inputs with shape (batch, ..., channels) - * while {@link TensorFormat#NCHW} corresponds to inputs with shape - * (batch, channels, ...). + * corresponds to inputs with shape {@code (batch, ..., channels) } while {@link + * TensorFormat#NCHW} corresponds to inputs with shape {@code (batch, channels, ...)}. * @param type the data type for the layer's weights and computation. * @param options the layer's options. 
*/ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java index 73707add188..9c1a15eaf74 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianDropout.java @@ -83,12 +83,10 @@ public GaussianDropout(Ops tf, float rate, long seed, Class type, Options opt * will always produce the same random tensor for a given shape and data type. * @param type the data type for the weights and computation */ - public GaussianDropout( - Ops tf, String name, float rate, long seed, Class type) { + public GaussianDropout(Ops tf, String name, float rate, long seed, Class type) { this(tf, name, rate, seed, type, null); } - /** * Creates a GaussianDropout layer * @@ -112,7 +110,6 @@ public GaussianDropout( } /** {@inheritDoc} */ - @SuppressWarnings("unchecked") @Override public List> call( List> inputs, @@ -127,10 +124,10 @@ public List> call( Operand output = cast(tf, input, getType()); + // if in training mode do dropout, otherwise don't + //noinspection IfStatementWithIdenticalBranches if (training && rate >= 0 && rate <= 1) { - Operand rateV = cast(tf, tf.constant(rate), getType()); - output = dropout(output, rateV, seed); outputs.add(output); } else { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java index ce71d7e5805..18718a68f15 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/GaussianNoise.java @@ -57,7 +57,6 @@ public GaussianNoise(Ops tf, float stddev, long seed, Class type) { this(tf, null, stddev, seed, type, null); } - /** * Creates a GaussianNoise layer, using a unique name will be generated based on {@link * Class#getSimpleName()} and no noiseShape. @@ -86,8 +85,7 @@ public GaussianNoise(Ops tf, float stddev, long seed, Class type, Options opt * will always produce the same random tensor for a given shape and data type. * @param type the data type for the weights and computation */ - public GaussianNoise( - Ops tf, String name, float stddev, long seed, Class type) { + public GaussianNoise(Ops tf, String name, float stddev, long seed, Class type) { this(tf, name, stddev, seed, type, null); } /** @@ -111,7 +109,6 @@ public GaussianNoise( } /** {@inheritDoc} */ - @SuppressWarnings("unchecked") @Override public List> call( List> inputs, diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java index adc2ec8f599..2a2c414b86b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Input.java @@ -51,7 +51,6 @@ public Input(Ops tf, Operand input, Class type) { this(tf, null, input, null, type, null); } - /** * Creates an input layer using {@link Class#getSimpleName()} for the name. * @@ -74,8 +73,7 @@ public Input(Ops tf, Operand input, Class type, Options opti * @param input The input * @param type the data type for the layer's weights and computation. 
*/ - public Input( - Ops tf, String name, Operand input, Class type) { + public Input(Ops tf, String name, Operand input, Class type) { this(tf, name, input, null, type, null); } @@ -88,6 +86,7 @@ public Input( * Class#getSimpleName()} * @param input The input * @param type the data type for the layer's weights and computation. + * @param options the layer's options */ public Input( Ops tf, String name, Operand input, Class type, Options options) { @@ -129,8 +128,7 @@ public Input(Ops tf, Class inputType, Class type, Options op * @param inputType the data type for the input and output, if null, input.type() is used * @param type the data type for the layer's weights and computation. */ - public Input( - Ops tf, String name, Class inputType, Class type) { + public Input(Ops tf, String name, Class inputType, Class type) { this(tf, name, null, inputType, type, null); } @@ -162,11 +160,11 @@ public Input( * null, and if both inputShape and input are null. */ public Input( - Ops tf, - String name, - Operand input, - Class inputType, - Class type) { + Ops tf, + String name, + Operand input, + Class inputType, + Class type) { this(tf, name, input, inputType, type, null); } /** @@ -190,24 +188,20 @@ public Input( Class type, Options options) { super(tf, name, true, type, options); - Options c = getInstanceOptions(); + Options inputOptions = getInstanceOptions(); if (inputType == null && input == null) { throw new IllegalArgumentException("both input and inputType cannot be null"); - } + } if (input != null && inputType != null && !input.type().equals(inputType)) { throw new IllegalArgumentException( String.format("input.type() differs from inputType: %s vs. %s", input.type(), inputType)); - } - - //if ((c == null || c.inputShape == null) && input == null) { - // throw new IllegalArgumentException("both input and inputShape cannot be null"); - // } + } - if (c != null) { - if ( c.inputShape != null - && (c.batchSize != null || c.batchInputShape != null)) { + if (inputOptions != null) { + if (inputOptions.inputShape != null + && (inputOptions.batchSize != null || inputOptions.batchInputShape != null)) { throw new IllegalArgumentException( "Only provide the inputShape or the batchSize or batchInputShape parameters at the size."); } @@ -215,20 +209,30 @@ public Input( Shape lShape; - if (c != null && c.batchInputShape != null) { - lShape = c.batchInputShape.takeLast(c.batchInputShape.numDimensions() - 1); - setBatchInputShape(c.batchInputShape); + if (inputOptions != null && inputOptions.batchInputShape != null) { + lShape = + inputOptions.batchInputShape.takeLast(inputOptions.batchInputShape.numDimensions() - 1); + setBatchInputShape(inputOptions.batchInputShape); if (getBatchSize() == null) { - setBatchSize(c.batchInputShape.size(0)); + setBatchSize(inputOptions.batchInputShape.size(0)); } } else { - if(input == null) { - lShape = (c == null || c.inputShape == null) ? Shape.of(Shape.UNKNOWN_SIZE) : c.inputShape; - }else { - lShape = (c == null || c.inputShape == null) ? input.shape() : c.inputShape; - } + if (input == null) { + lShape = + (inputOptions == null || inputOptions.inputShape == null) + ? Shape.of(Shape.UNKNOWN_SIZE) + : inputOptions.inputShape; + } else { + lShape = + (inputOptions == null || inputOptions.inputShape == null) + ? input.shape() + : inputOptions.inputShape; + } - setBatchSize((c == null || c.batchSize == null) ? Shape.UNKNOWN_SIZE : c.batchSize); + setBatchSize( + (inputOptions == null || inputOptions.batchSize == null) + ? 
Shape.UNKNOWN_SIZE + : inputOptions.batchSize); setBatchInputShape(Shape.of(getBatchSize()).append(lShape)); } @@ -253,8 +257,7 @@ public Input( * @param the data type for the layer's calculations. * @return the output */ - public static Operand input( - Ops tf, Class type) { + public static Operand input(Ops tf, Class type) { return input(tf, type, null); } @@ -267,8 +270,7 @@ public static Operand input( * @param the data type for the layer's calculations. * @return the output */ - public static Operand input( - Ops tf, Class type, Options options) { + public static Operand input(Ops tf, Class type, Options options) { Input layer = new Input<>(tf, type, type, options); return layer.getOutput(type); } @@ -318,6 +320,7 @@ public Operand getOutput() { * {@link #call} methods * * @param resultType the output data type + * @param the data type for the result * @return the output Operand. */ public Operand getOutput(Class resultType) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java index 53575b64d3d..cb2f1e9048c 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Lambda.java @@ -37,9 +37,9 @@ *

the Java lambda function is in the form x = function(tf, input). The first * argument is the TensorFlow Ops, the second argument is the input Operand. For example: * - *

- *        Lambda lambda = new Lambda(tf, (ops, input) -> ops.math.mul(ops.constant(2), input), TFloat32.class);
- *    
+ *
{@code
+ * Lambda lambda = new Lambda(tf, (ops, input) -> ops.math.mul(ops.constant(2), input), TFloat32.class);
+ * }
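Once constructed, the layer is invoked like any other layer. A short usage sketch continuing the example above (the input constant and the doubler name are illustrative):

    Lambda<TFloat32> doubler =
        new Lambda<>(tf, (ops, in) -> ops.math.mul(ops.constant(2f), in), TFloat32.class);
    // A float constant keeps both multiplicands the same tensor type.
    Operand<TFloat32> out = doubler.call(tf.constant(new float[] {1f, 2f, 3f}), TFloat32.class);
    // out evaluates to [2.0, 4.0, 6.0]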
* * @param the data type for the layer's weights and computation. */ @@ -47,24 +47,19 @@ public class Lambda extends Layer { private BiFunction, Operand> function; /** - * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName() + * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops - * @param function the Java lambda function in the form x = function(tf, input). - * The first argument is the TensorFlow Ops, the second argument is the input Operand. * @param type the data type for the layer's weights and computation. - * @param options the layer's options. */ public Lambda(Ops tf, Class type) { - this(tf, null, null, type, null); + this(tf, null, null, type, null); } /** - * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName() + * Creates a Lambda layer, generating a unique name based on {@link Class#getSimpleName()} * * @param tf the TensorFlow Ops - * @param function the Java lambda function in the form x = function(tf, input). - * The first argument is the TensorFlow Ops, the second argument is the input Operand. * @param type the data type for the layer's weights and computation. * @param options the layer's options. */ @@ -91,6 +86,7 @@ public Lambda(Ops tf, String name, Class type) { * @param name the unique name for this layer, if null, generates a unique name based on {@link * Class#getSimpleName()}. * @param type the data type for the layer's weights and computation. + * @param options the layer's options. */ public Lambda(Ops tf, String name, Class type, Options options) { this(tf, name, null, type, options); } @@ -101,13 +97,12 @@ public Lambda(Ops tf, String name, Class type, Options options) { * * @param tf the TensorFlow Ops * @param function The Java lambda function in the form x = function(tf, input). The - * first argument is the TensorFlow Ops, the second argument is the input Operand. If function - * is null, then the input is returned un changed. + * first argument is the TensorFlow Ops, the second argument is the input Operand. If function + * is null, then the input is returned unchanged. * @param type the data type for the layer's weights and computation. - */ - public Lambda( - Ops tf, BiFunction, Operand> function, Class type) { + */ + public Lambda(Ops tf, BiFunction, Operand> function, Class type) { this(tf, null, function, type, null); } /** @@ -116,8 +110,8 @@ public Lambda( * * @param tf the TensorFlow Ops * @param function The Java lambda function in the form x = function(tf, input). The - * first argument is the TensorFlow Ops, the second argument is the input Operand. If function - * is null, then the input is returned un changed. + * first argument is the TensorFlow Ops, the second argument is the input Operand. If function + * is null, then the input is returned unchanged. * @param type the data type for the layer's weights and computation. * @param options the layer's options. */ @@ -138,10 +132,7 @@ public Lambda( * @param type the data type for the layer's weights and computation. */ public Lambda( - Ops tf, - String name, - BiFunction, Operand> function, - Class type) { + Ops tf, String name, BiFunction, Operand> function, Class type) { this(tf, name, function, type, null); } @@ -155,6 +146,7 @@ public Lambda( * first argument is the TensorFlow Ops, the second argument is the input Operand. If function * is null, then the input is returned un changed. * @param type the data type for the layer's weights and computation. 
+ * @param options the layer's options */ public Lambda( Ops tf, @@ -173,7 +165,7 @@ public Lambda( * x = function(tf, input)
. The first argument is the TensorFlow Ops, the second * argument is the input Operand. */ - public void setLamda(BiFunction, Operand> function) { + public void setLambda(BiFunction, Operand> function) { this.function = function; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java index 8df54232e1f..60ee9c6a9e0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Layer.java @@ -15,16 +15,17 @@ package org.tensorflow.framework.layers; import org.tensorflow.Operand; -import org.tensorflow.framework.constraints.Constraint; import org.tensorflow.framework.initializers.Initializer; import org.tensorflow.framework.layers.impl.InputSpec; import org.tensorflow.framework.layers.impl.VariableDef; import org.tensorflow.framework.losses.Loss; import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.framework.regularizers.Regularizer; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Variable; import org.tensorflow.types.TBool; +import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; @@ -34,6 +35,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.UnaryOperator; import java.util.stream.Collectors; import static org.tensorflow.framework.utils.CastHelper.cast; @@ -59,6 +61,9 @@ public abstract class Layer { private final List> trainableWeights = new ArrayList<>(); private final List> nonTrainableWeights = new ArrayList<>(); private final List losses = new ArrayList<>(); + // some loss operations don't have an associated Loss class, so this property holds + // the Operands to calculate the loss, used in the model. + private final List> lossOperations = new ArrayList<>(); private final List> metrics = new ArrayList<>(); private final Map, VariableDef> variableMap = new HashMap<>(); // Note that, unlike other classes, tf may not be set in the constructor, but may be set later. @@ -66,8 +71,7 @@ public abstract class Layer { // sets the tf instance probably during the model.compile phase. private final Ops tf; private boolean trainable; - // TODO change to Regularizer class - private Object activityRegularizer; + private Regularizer activityRegularizer; private boolean built; private boolean stateful; private boolean supportsMasking; @@ -148,12 +152,12 @@ private String genName() { * Invokes the layer's algorithm using a single input, returning a single output. Training mode is * true. * - *

This is a convenience call on top of {@link {@link #call}}. + *

This is a convenience call on top of {@link #call}. * @param input the input Operand * @return the output Operand, or null if no output is generated from the layer's logic. */ - public Operand call(Operand input) { + public Operand call(Operand input) { return call(input, null, true, getType()); } @@ -162,9 +166,11 @@ public Operand call(Operand input) { * Invokes the layer's algorithm using a single input, returning a single output. Training mode is * true. * - *

This is a convenience call on top of {@link {@link #call}}. + *

This is a convenience call on top of {@link #call}. * @param input the input Operand + * @param type the data type for the result + * @param the data type for the result * @return the output Operand, or null if no output is generated from the layer's logic. */ public Operand call(Operand input, Class type) { @@ -179,6 +185,8 @@ public Operand call(Operand input, Class the data type for the result * @return the output Operand, or null if no output is generated from the layer's logic. */ public Operand call( @@ -194,6 +202,8 @@ public Operand call( * @param input the input Operand * @param mask the mask to apply to the result, may be null * @param training whether the call is in inference mode or training mode + * @param type the data type for the result + * @param the data type for the result * @return the output Operand, or null if no output is generated from the layer's logic. */ public Operand call( @@ -207,6 +217,8 @@ public Operand call( * Invokes the layer's algorithm Training mode is true. * * @param inputs the input Operands + * @param type the data type for the result + * @param the data type for the result * @return the output Operands */ public List> call( @@ -220,6 +232,8 @@ public List> call( * @param inputs the input Operands * @param masks a list of masks, one for each input, to apply to the result, may be null * @param training whether the call is in inference mode or training mode + * @param type the data type for the result + * @param the data type for the result * @return the output Operands. */ public abstract List> call( @@ -232,11 +246,26 @@ public abstract List> call( * Post processes a layer's call result * * @param inputs the input Operands + * @param training true if in training mode + * @param the data type of the inputs and result * @return the output Operands. */ protected List> callPostProcess( - List> inputs, boolean training) { - return handleActivityRegister(inputs); + List> inputs, @SuppressWarnings("unused") boolean training) { + if (activityRegularizer != null && !inputs.isEmpty()) { + boolean aTNumber = TNumber.class.isAssignableFrom(inputs.get(0).type()); + if (aTNumber) { + inputs.forEach( + input -> { + if (input.type() != TString.class) { + Operand tInput = cast(tf, input, getType()); + addLossOperation(activityRegularizer.call(tInput)); + } + }); + } + } + + return inputs; } /** @@ -252,6 +281,8 @@ protected List> convertList(List> inputs) { * Converts a list of inputs to a new list of the internal data type defined for this layer. * * @param inputs the inputs. + * @param resultType the data type of the result + * @param the data type of the result * @return the new list converted to the new type. */ protected List> convertList( @@ -276,16 +307,6 @@ protected List> convertTo( return result; } - private List> handleActivityRegister(List> inputs) { - if (this.activityRegularizer != null) { - // TODO activityRegularizer - return inputs; - - } else { - return inputs; - } - } - /** * Creates the variables of the layer (optional, for subclass implementers). 
This is a method that * implementers of subclasses of Layer or Model can override if they @@ -407,21 +428,28 @@ public List> getNonTrainableWeights() { * @param name the variable's name * @param shape the variable's shape * @param initializer the variable initializer + * @param constraint a constraint to be applied to the weight + * @param regularizer Regularizer instance * @param trainable whether the variable should be part of the layer's "trainableWeights" + * @param seed a seed value for random number generation * @throws IllegalStateException if the property {@link #tf} has not been set yet. + * @return the variable created for the weight */ public Variable addWeight( String name, Shape shape, Initializer initializer, - Constraint constraint, + UnaryOperator> constraint, + Regularizer regularizer, boolean trainable, long seed) { if (tf == null) { throw new IllegalStateException("Parameter \"tf\" has not been set"); } + VariableDef variableDef = - new VariableDef<>(tf, name, shape, initializer, constraint, trainable, seed, getType()); + new VariableDef<>( + tf, name, shape, initializer, constraint, regularizer, trainable, seed, getType()); Variable variable = variableDef.getVariable(); @@ -432,21 +460,36 @@ public Variable addWeight( return variable; } + /** + * Gets the VariableDef for the specified variable + * + * @param variable the variable + * @return the VariableDef + */ + public VariableDef getVariableDef(Variable variable) { + return variableMap.get(variable); + } + /** * Adds a weight to the layer * + * @param name the weight name * @param variable the variable to add * @param initializer the variable initializer + * @param constraint the constraint on the variable + * @param regularizer the regularizer for the variable * @param trainable whether the variable should be part of the layer's "trainableWeights" * @param seed the seed for random number generation. An initializer created with a given seed * will always produce the same random tensor for a given shape and type. * @throws IllegalStateException if the property {@link #tf} has not been set yet. 
+ * @return the variable created for the weight */ public Variable addWeight( String name, Variable variable, Initializer initializer, - Constraint constraint, + UnaryOperator> constraint, + Regularizer regularizer, boolean trainable, long seed) { if (tf == null) { @@ -456,7 +499,8 @@ public Variable addWeight( throw new IllegalStateException("Parameter \"variable\" has not been set"); } VariableDef variableDef = - new VariableDef<>(tf, name, variable, initializer, constraint, trainable, seed); + new VariableDef<>( + tf, name, variable, initializer, constraint, regularizer, trainable, seed); variableMap.put(variable, variableDef); weights.add(variable); if (trainable) trainableWeights.add(variable); @@ -489,7 +533,7 @@ public List> initializeWeights(long seed) { public Operand initializeWeight(Variable weight, long seed) { VariableDef varDef = variableMap.get(weight); if (varDef == null) { // this should not happen if addWeight was used to create/add the weight - addWeight(null, weight, null, null, true, seed); + addWeight(null, weight, null, null, null, true, seed); varDef = variableMap.get(weight); } return varDef.init(); @@ -527,6 +571,15 @@ public List getLosses() { return losses; } + /** + * Gets the Loss Operations assigned to this layer + * + * @return the Loss Operations assigned to this layer + */ + public List> getLossOperations() { + return lossOperations; + } + /** * Adds a loss to this layer * @@ -536,6 +589,15 @@ public void addLoss(Loss loss) { losses.add(loss); } + /** + * Adds a loss operation to this layer + * + * @param lossOperation the loss operation + */ + public void addLossOperation(Operand lossOperation) { + this.lossOperations.add(lossOperation); + } + /** * Adds losses to this layer * @@ -545,6 +607,15 @@ public void addLosses(List losses) { this.losses.addAll(losses); } + /** + * Adds loss operations to this layer + * + * @param lossOperations the loss operations to add + */ + public void addLossOperations(List> lossOperations) { + this.lossOperations.addAll(lossOperations); + } + /** * Gets the Losses assigned to this layer * @@ -577,6 +648,7 @@ public void addMetrics(List> metrics) { * * @return true, if the build method has been called. */ + @SuppressWarnings("BooleanMethodIsAlwaysInverted") public boolean isBuilt() { return built; } @@ -749,8 +821,7 @@ public Object getActivityRegularizer() { * * @param activityRegularizer the activity Regularizer */ - // TODO change to Regularizer class - public void setActivityRegularizer(Object activityRegularizer) { + public void setActivityRegularizer(Regularizer activityRegularizer) { this.activityRegularizer = activityRegularizer; } @@ -781,7 +852,7 @@ public void setSupportsMasking(boolean supportsMasking) { * @throws IllegalArgumentException if the variable is not known. */ public Operand assign(Variable variable, Operand value) { - VariableDef varDef = variableMap.get(variable); + VariableDef varDef = variableMap.get(variable); if (varDef == null) { throw new IllegalStateException(String.format("Variable %s was not found.", variable)); } @@ -797,7 +868,7 @@ public Operand assign(Variable variable, Operand value) { * @throws IllegalArgumentException if the variable is not known. 
*/ public Operand assignAdd(Variable variable, Operand value) { - VariableDef varDef = variableMap.get(variable); + VariableDef varDef = variableMap.get(variable); if (varDef == null) { throw new IllegalStateException(String.format("Variable %s was not found.", variable)); } @@ -813,7 +884,7 @@ public Operand assignAdd(Variable variable, Operand value) { * @throws IllegalArgumentException if the variable is not known. */ public Operand assignSub(Variable variable, Operand value) { - VariableDef varDef = variableMap.get(variable); + VariableDef varDef = variableMap.get(variable); if (varDef == null) { throw new IllegalStateException(String.format("Variable %s was not found.", variable)); } @@ -827,8 +898,7 @@ public static class Options { protected Long batchSize; protected List> metrics; protected List losses; - // TODO change to Regularizer class - protected Object activityRegularizer; + protected Regularizer activityRegularizer; public static Options create() { return new Options(); @@ -873,8 +943,7 @@ public Layer.Options batchInputShape(Shape batchInputShape) { * @param activityRegularizer the activity Regularizer * @return this Options instance */ - // TODO change to Regularizer class - public Layer.Options activityRegularizer(Object activityRegularizer) { + public Layer.Options activityRegularizer(Regularizer activityRegularizer) { this.activityRegularizer = activityRegularizer; return this; } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java index 57a5d23b291..807b2b4430f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/LeakyReLU.java @@ -29,10 +29,10 @@ * *

It allows a small gradient when the unit is not active: * - *

- *     f(x) = alpha * x if x < 0
- *     f(x) = x if x >= 0
- * 
+ *
{@code
+ * f(x) = alpha * x if x < 0
+ * f(x) = x if x >= 0
+ * }
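For alpha in [0, 1) the same rule can be computed without a branch. A tiny sketch, assuming an Ops instance tf and a TFloat32 input x (illustrative), with alpha = 0.3:

    // max(x, alpha * x) is x when x >= 0 and alpha * x when x < 0, for 0 <= alpha < 1.
    Operand<TFloat32> leaky = tf.math.maximum(x, tf.math.mul(tf.constant(0.3f), x));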
* * @param the data type for the layer's weights and computation. */ @@ -47,9 +47,10 @@ public class LeakyReLU extends Layer { * * @param tf the TensorFlow Ops. * @param type the data type for the layer's weights and computation. + * @param options the layer's options. */ public LeakyReLU(Ops tf, Class type, Options options) { - this(tf, null, DEFAULT_ALPHA, type, null); + this(tf, null, DEFAULT_ALPHA, type, options); } /** @@ -71,6 +72,7 @@ public LeakyReLU(Ops tf, String name, Class type) { * @param tf the TensorFlow Ops. * @param alpha Negative slope coefficient. Must be >= 0. * @param type the data type for the layer's weights and computation. + * @param options the layer's options. */ public LeakyReLU(Ops tf, float alpha, Class type, Options options) { this(tf, null, alpha, type, options); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java index ea09fafaf98..28f1a80821f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Maximum.java @@ -50,6 +50,7 @@ public Maximum(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param type the data type for the weights and computation + * @param options the layer's options. */ public Maximum(Ops tf, Class type, Options options) { this(tf, null, type, options); @@ -72,6 +73,7 @@ public Maximum(Ops tf, String name, Class type) { * @param tf the TensorFlow Ops * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} * @param type the data type for the weights and computation + * @param options the layer's options. */ public Maximum(Ops tf, String name, Class type, Options options) { super(tf, name, type, options); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java index a1fe012689c..bc46e7e82f3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Minimum.java @@ -35,8 +35,6 @@ */ public class Minimum extends Merge { - - /** * Creates an Add Layer using {@link Class#getSimpleName()} as the layer name. * @@ -52,6 +50,7 @@ public Minimum(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param type the data type for the weights and computation + * @param options the layer's options. */ public Minimum(Ops tf, Class type, Options options) { this(tf, null, type, options); @@ -74,6 +73,7 @@ public Minimum(Ops tf, String name, Class type) { * @param tf the TensorFlow Ops * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} * @param type the data type for the weights and computation + * @param options the layer's options. 
*/ public Minimum(Ops tf, String name, Class type, Options options) { super(tf, name, type, options); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java index fbb190cb00b..d343463f7af 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Multiply.java @@ -50,6 +50,7 @@ public Multiply(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param type the data type for the weights and computation + * @param options the layer's options. */ public Multiply(Ops tf, Class type, Options options) { this(tf, null, type, options); @@ -72,6 +73,7 @@ public Multiply(Ops tf, String name, Class type) { * @param tf the TensorFlow Ops * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} * @param type the data type for the weights and computation + * @param options the layer's options. */ public Multiply(Ops tf, String name, Class type, Options options) { super(tf, name, type, options); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java index 74307780ce0..be502aba57b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ReLU.java @@ -30,11 +30,11 @@ * *

Otherwise, it follows: * - *

- *    f(x) = max_value if x >= max_value
- *     f(x) = x if threshold <= x < max_value
- *     f(x) = negative_slope * (x - threshold) otherwise
- * 
+ *
{@code
+ * f(x) = max_value if x >= max_value
+ * f(x) = x if threshold <= x < max_value
+ * f(x) = negative_slope * (x - threshold) otherwise
+ * }
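With negative_slope == 0 and threshold == 0 the rule above collapses to a clip. A sketch of that common case (ReLU6), assuming an Ops instance tf and a TFloat32 input x:

    // Clamp below at the threshold (0) and above at max_value (6).
    Operand<TFloat32> relu6 =
        tf.math.minimum(tf.math.maximum(x, tf.constant(0f)), tf.constant(6f));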
* * @param the data type for the layer's weights and computation. */ @@ -127,12 +127,7 @@ public ReLU(Ops tf, float negativeSlope, Class type, Options options) { * @param type the data type for the layer's weights and computation. * @throws IllegalArgumentException if maxValue or negativeSlope is < 0 */ - public ReLU( - Ops tf, - float negativeSlope, - float maxValue, - float threshold, - Class type) { + public ReLU(Ops tf, float negativeSlope, float maxValue, float threshold, Class type) { this(tf, null, negativeSlope, maxValue, threshold, type, null); } @@ -149,12 +144,12 @@ public ReLU( * @throws IllegalArgumentException if maxValue or negativeSlope is < 0 */ public ReLU( - Ops tf, - float negativeSlope, - float maxValue, - float threshold, - Class type, - Options options) { + Ops tf, + float negativeSlope, + float maxValue, + float threshold, + Class type, + Options options) { this(tf, null, negativeSlope, maxValue, threshold, type, options); } @@ -171,12 +166,7 @@ public ReLU( * @throws IllegalArgumentException if maxValue or negativeSlope is < 0 */ public ReLU( - Ops tf, - String name, - float negativeSlope, - float maxValue, - float threshold, - Class type) { + Ops tf, String name, float negativeSlope, float maxValue, float threshold, Class type) { this(tf, name, negativeSlope, maxValue, threshold, type, null); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java index 32594d72f8a..3808fad8bff 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/RepeatVector.java @@ -36,7 +36,8 @@ public class RepeatVector extends Layer { private final int repeatCount; /** - * Creates a RepeatCount using a unique name will be generated based on * {@link Class#getSimpleName()}. + * Creates a RepeatVector with a unique name generated based on {@link + * Class#getSimpleName()}. * * @param tf the TensorFlow Ops * @param repeatCount the repetition factor. @@ -46,9 +47,9 @@ public RepeatVector(Ops tf, int repeatCount, Class type) { this(tf, null, repeatCount, type, null); } - /** - * Creates a RepeatCountusing a unique name will be generated based on * {@link Class#getSimpleName()}. + * Creates a RepeatVector with a unique name generated based on {@link + * Class#getSimpleName()}. * * @param tf the TensorFlow Ops * @param repeatCount the repetition factor. 
@@ -88,7 +89,6 @@ public RepeatVector(Ops tf, String name, int repeatCount, Class type, Options } /** - * * @param inputs the input Operands, 2D tensor of shape (num_samples, features) * @param masks a list of masks, one for each input, to apply to the result, may be null * @param training whether the call is in inference mode or training mode diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java index fe954ae0177..a3f32fa29cf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Softmax.java @@ -15,7 +15,7 @@ package org.tensorflow.framework.layers; import org.tensorflow.Operand; -import org.tensorflow.framework.op.math.ReduceLogSumExp; +import org.tensorflow.framework.op.FrameworkOps; import org.tensorflow.op.Ops; import org.tensorflow.types.TBool; import org.tensorflow.types.TFloat16; @@ -40,7 +40,6 @@ public class Softmax extends Layer { * {@link Class#getSimpleName()}. * @param axes axes along which the softmax normalization is applied. * @param type the data type for the layer's weights and computation. - */ public Softmax(Ops tf, String name, int[] axes, Class type) { this(tf, name, axes, type, null); @@ -69,6 +68,7 @@ public List> call( boolean training, Class resultType) { Ops tf = getTF(); + FrameworkOps fops = FrameworkOps.create(tf); // TODO mask List> results = new ArrayList<>(); @@ -91,11 +91,9 @@ public List> call( input = tf.math.add(input, adder); } if (axes.length > 1) { - result = - tf.math.exp( - tf.math.sub(input, ReduceLogSumExp.reduceLogSumExp(tf.scope(), input, axes, true))); + result = tf.math.exp(tf.math.sub(input, fops.math.reduceLogSumExp(input, axes, true))); } else { - result = org.tensorflow.framework.op.nn.Softmax.softmax(tf.scope(), input, axes[0]); + result = fops.nn.softmax(input, axes[0]); } results.add(result); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java index 7f14012a465..a114c64ae86 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/Subtract.java @@ -54,12 +54,12 @@ public Subtract(Ops tf, Class type) { * * @param tf the TensorFlow Ops * @param type the data type for the weights and computation + * @param options the layer's options. */ public Subtract(Ops tf, Class type, Options options) { this(tf, null, type, options); } - /** * Creates an Add Layer * @@ -77,6 +77,7 @@ public Subtract(Ops tf, String name, Class type) { * @param tf the TensorFlow Ops * @param name the name of the layer, if null the name is set to {@link Class#getSimpleName()} * @param type the data type for the weights and computation + * @param options the layer's options. */ public Subtract(Ops tf, String name, Class type, Options options) { super(tf, name, type, options); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java index fd9c458436a..29cf026c0a6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/ThresholdedReLU.java @@ -30,10 +30,10 @@ * *

It follows: * - *

- *     f(x) = x for x > theta
- *     f(x) = 0 otherwise`
- * 
+ *
{@code
+ * f(x) = x for x > theta
+ * f(x) = 0 otherwise
+ * }
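Equivalently, the gate can be built from a comparison cast back to the compute type. A sketch assuming an Ops instance tf, a TFloat32 input x, and theta = 1.0f (all illustrative):

    // 1.0 where x > theta, 0.0 elsewhere; multiplying applies the threshold gate.
    Operand<TFloat32> gate = tf.dtypes.cast(tf.math.greater(x, tf.constant(1f)), TFloat32.class);
    Operand<TFloat32> out = tf.math.mul(x, gate);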
* * @param the data type for the layer's weights and computation. */ @@ -73,6 +73,7 @@ public ThresholdedReLU(Ops tf, String name, Class type) { * @param tf the TensorFlow Ops * @param theta Negative slope coefficient. Must be >= 0. * @param type the data type for the layer's weights and computation. + * @param options the layer's options. */ public ThresholdedReLU(Ops tf, float theta, Class type, Options options) { this(tf, null, theta, type, options); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java index b7bedc907c6..899a528e534 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/InputSpec.java @@ -322,6 +322,11 @@ public static class Options { private Map axes; private boolean allowLastAxisSqueeze; + /** + * Creates an InputSpecs.Options instance + * + * @return the InputSpecs.Options instance + */ public static Options create() { return new Options(); } @@ -329,6 +334,7 @@ public static Options create() { /** * Sets the expected Data Type of the input. * + * @param dataType the expected Data Type of the input. * @return this Options instance. */ public Options dataType(Class dataType) { @@ -340,6 +346,7 @@ public Options dataType(Class dataType) { * Sets the expected shape of the input (may include {@link Shape#UNKNOWN_SIZE} for unchecked * axes). Includes the batch size. * + * @param shape the expected shape of the input * @return this Options instance. */ public Options shape(Shape shape) { @@ -350,6 +357,7 @@ public Options shape(Shape shape) { /** * Sets the expected rank of the input * + * @param rank the expected rank of the input * @return this Options instance. */ public Options rank(Integer rank) { @@ -360,6 +368,7 @@ public Options rank(Integer rank) { /** * Sets the maximum rank of the input. * + * @param maxRank the maximum rank of the input. * @return this Options instance. */ public Options maxRank(Integer maxRank) { @@ -370,6 +379,7 @@ public Options maxRank(Integer maxRank) { /** * Sets the minimum rank of the input. * + * @param minRank the minimum rank of the input. * @return this Options instance. */ public Options minRank(Integer minRank) { @@ -379,6 +389,7 @@ public Options minRank(Integer minRank) { /** * Sets the Dictionary mapping integer axes to a specific dimension value. * + * @param axes the Dictionary mapping integer axes to a specific dimension value. * @return this Options instance. */ public Options axesMap(Map axes) { @@ -388,6 +399,8 @@ public Options axesMap(Map axes) { /** * Sets the Dictionary mapping integer axes to a specific dimension value. * + * @param key the integer axis + * @param dim the dimension value for the specified axis * @return this Options instance. */ public Options axesMap(Integer key, Long dim) { @@ -403,6 +416,8 @@ public Options axesMap(Integer key, Long dim) { * N+1 as long as the last axis of the input is 1, as well as inputs of rank N-1 as long as the * last axis of the spec is 1. * + * @param allowLastAxisSqueeze indicator that the allow last axis squeeze indicator for the + * input * @return this Options instance. 
*/ public Options allowLastAxisSqueeze(boolean allowLastAxisSqueeze) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java index c78b3ccebdb..74d948e0262 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/Merge.java @@ -374,7 +374,7 @@ protected Shape computeElementWiseOpOutputShape(Shape shape1, Shape shape2) { String.format( "Operands could not be broadcast together with shapes %s %s", shape1, shape2)); } else { - outputShape.append(shape1.size(i)); + outputShape = outputShape.append(shape1.size(i)); } } return outputShape; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java index ef002bfb49f..370d1907a74 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/layers/impl/VariableDef.java @@ -15,23 +15,29 @@ package org.tensorflow.framework.layers.impl; import org.tensorflow.Operand; -import org.tensorflow.framework.constraints.Constraint; import org.tensorflow.framework.initializers.Glorot; import org.tensorflow.framework.initializers.Initializer; import org.tensorflow.framework.initializers.VarianceScaling; import org.tensorflow.framework.initializers.Zeros; +import org.tensorflow.framework.regularizers.Regularizer; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Assign; +import org.tensorflow.op.core.AssignAdd; +import org.tensorflow.op.core.AssignSub; import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; +import java.util.function.UnaryOperator; + public class VariableDef { private final Ops tf; private final String name; private final Shape shape; private final Initializer initializer; - private final Constraint constraint; + private final UnaryOperator> constraint; + private final Regularizer regularizer; private final boolean trainable; private final Variable variable; private final Operand initOperand; @@ -42,7 +48,8 @@ public VariableDef( String name, Shape shape, Initializer initializer, - Constraint constraint, + UnaryOperator> constraint, + Regularizer regularizer, boolean trainable, long seed, Class type) { @@ -50,12 +57,13 @@ public VariableDef( this.type = type; this.name = name; this.constraint = constraint; + this.regularizer = regularizer; this.trainable = trainable; this.shape = shape == null ? Shape.scalar() : shape; this.initializer = initializer == null ? getDefaultInitializer(seed) : initializer; initOperand = this.initializer.call(tf.constant(this.shape), type); - variable = tf.variable(initOperand); + variable = tf.withSubScope(name).variable(initOperand); } public VariableDef( @@ -63,12 +71,14 @@ public VariableDef( String name, Variable variable, Initializer initializer, - Constraint constraint, + UnaryOperator> constraint, + Regularizer regularizer, boolean trainable, long seed) { this.tf = tf.withName(name); this.name = name == null ? 
variable.toString() : name; this.constraint = constraint; + this.regularizer = regularizer; this.trainable = trainable; this.variable = variable; shape = variable.shape(); @@ -87,39 +97,65 @@ public Operand init() { } /** - * Assigns a value to the variable + * Assigns a value to the variable, without locking * * @param value the value to assign * @return the operand that assigns the value to this variable */ public Operand assign(Operand value) { - // apply constraint if it exists - Operand tValue = constraint != null ? constraint.call(value) : value; - return tf.assign(variable, tValue); + return assign(value, false); + } + /** + * Assigns a value to the variable + * + * @param value the value to assign + * @param useLocking If true, use locking during the assignment. + * @return the operand that assigns the value to this variable + */ + public Operand assign(Operand value, boolean useLocking) { + return tf.assign(variable, value, Assign.useLocking(useLocking)); } /** - * Adds a value to the variable + * Adds a value to the variable, without locking. * * @param value the value to add * @return the operand that adds the value to this variable */ public Operand assignAdd(Operand value) { - Operand add = tf.assignAdd(variable, value); - // apply constraint if it exists - return constraint != null ? tf.assign(variable, constraint.call(add)) : add; + return assignAdd(value, false); } /** - * Subtracts a value from the variable + * Adds a value to the variable + * + * @param value the value to add + * @param useLocking If true, use locking during the assignment. + * @return the operand that adds the value to this variable + */ + public Operand assignAdd(Operand value, boolean useLocking) { + return tf.assignAdd(variable, value, AssignAdd.useLocking(useLocking)); + } + + /** + * Subtracts a value from the variable, without locking. * * @param value the value to subtract + * @return the operand that subtracts the value from this variable */ public Operand assignSub(Operand value) { - Operand sub = tf.assignSub(variable, value); - // apply constraint if it exists - return constraint != null ? tf.assign(variable, constraint.call(sub)) : sub; + return assignSub(value, false); + } + + /** + * Subtracts a value from the variable + * + * @param value the value to subtract + * @param useLocking If true, use locking during the assignment. + * @return the operand that subtracts the value from this variable + */ + public Operand assignSub(Operand value, boolean useLocking) { + return tf.assignSub(variable, value, AssignSub.useLocking(useLocking)); } /** @@ -132,12 +168,14 @@ public Operand assignSub(Operand value) { @SuppressWarnings("unchecked") private Initializer getDefaultInitializer(long seed) { Initializer initializer; + if (TFloating.class.isAssignableFrom(type)) { // this creates a "Casting 'new Glorot<>(...)' to 'Initializer' is redundant" warning. // Ignored here as Glorot takes a TFloating which is a subclass of // and is checked in the if statement above. If you remove this cast, you'll get an error. - initializer = ( Initializer)new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); + //noinspection RedundantCast + initializer = (Initializer) new Glorot<>(tf, VarianceScaling.Distribution.UNIFORM, seed); } else { initializer = new Zeros<>(tf); } @@ -173,9 +211,18 @@ public Initializer getInitializer() { * * @return the variable constraint */ - public Constraint getConstraint() { + public UnaryOperator> getConstraint() { return constraint; } + + /** + * Gets the variable regularizer + * + * @return the variable regularizer + */ + public Regularizer getRegularizer() { + return regularizer; + } /** * Gets the variable trainable indicator *
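A minimal usage sketch of the locking overloads above (type parameters, which the diff rendering elides, are written out here on the assumption that VariableDef is generic in its tensor type; the values are hypothetical):

    VariableDef<TFloat32> bias =
        new VariableDef<>(
            tf, "bias", Shape.of(3),
            null,   // initializer: defaults to Glorot for floating types, Zeros otherwise
            null,   // constraint; no longer applied inside assign (callers apply it explicitly)
            null,   // regularizer
            true,   // trainable
            1001L, TFloat32.class);
    // The one-argument overloads delegate with useLocking = false ...
    Operand<TFloat32> set = bias.assign(tf.constant(new float[] {1f, 2f, 3f}));
    // ... while the two-argument forms make locking explicit.
    Operand<TFloat32> add = bias.assignAdd(tf.constant(new float[] {0.1f, 0.1f, 0.1f}), true);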
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java index 4f5120a3dbf..fa05f5a6307 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java @@ -16,10 +16,12 @@ import org.tensorflow.Operand; import org.tensorflow.framework.op.nn.SigmoidCrossEntropyWithLogits; +import org.tensorflow.framework.op.nn.Softmax; import org.tensorflow.framework.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.framework.op.nn.SparseSoftmaxCrossEntropyWithLogits; import org.tensorflow.op.Op; import org.tensorflow.op.Scope; +import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; /** @@ -191,4 +193,18 @@ public Operand sparseSoftmaxCrossEntro return SparseSoftmaxCrossEntropyWithLogits.sparseSoftmaxCrossEntropyWithLogits( scope, labels, logits); } + + /** + * Calculates a Softmax operation. If the axis is not the last dimension, then the input axis is + * moved to the last axis before calling tf.nn.softmax, then restored before returning. + * + * @param input the input + * @param axis the axis + * @return the softmax of the input for the specified axis. + * @throws IllegalArgumentException if axis is not in the range [-rank, rank), where rank is the + * number of dimensions in the input + * @param the data type for the input and result + */ + public Operand softmax(Operand input, int axis) { + return Softmax.softmax(scope, input, axis); + } }
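A usage sketch for the new endpoint, assuming the framework ops handle (fops, as used in the SetOpsTest changes later in this patch) exposes this group as fops.nn:

    Operand<TFloat32> logits = tf.constant(new float[][] {{1f, 2f, 3f}, {4f, 5f, 6f}});
    Operand<TFloat32> rows = fops.nn.softmax(logits, -1); // last axis: plain tf.nn.softmax
    Operand<TFloat32> cols = fops.nn.softmax(logits, 0);  // axis 0: swap to last, softmax, swap back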
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java new file mode 100644 index 00000000000..ea300acb0b2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java @@ -0,0 +1,118 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.op.nn; + +import org.tensorflow.Operand; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.op.core.Concat; +import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Range; +import org.tensorflow.op.core.Rank; +import org.tensorflow.op.core.Reshape; +import org.tensorflow.op.linalg.Transpose; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Sub; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TFloating; + +import java.util.Arrays; + +/** + * Higher level operation for Softmax. This class will move the desired axis to the last axis, if + * necessary, before calling the low level tf.nn.softmax method. + */ +@Operator(group = "nn") +public class Softmax { + + /** + * Calculates a Softmax operation. If the axis is not the last dimension, then the input axis is + * moved to the last axis before calling tf.nn.softmax, then restored before returning. + * + * @param scope The TensorFlow scope + * @param input the input + * @param axis the axis + * @return the softmax of the input for the specified axis. + * @throws IllegalArgumentException if axis is not in the range [-rank, rank), where rank is the + * number of dimensions in the input + * @param the data type for the input and result + */ + @Endpoint(name = "softmax") + public static Operand softmax(Scope scope, Operand input, int axis) { + Shape shape = input.shape(); + boolean isLastDim = axis == -1 || axis == shape.numDimensions() - 1; + if (isLastDim) { + return org.tensorflow.op.nn.Softmax.create(scope, input); + } + + if (axis < -shape.numDimensions() || axis >= shape.numDimensions()) { + throw new IllegalArgumentException( + String.format( + "Axis (%d) must be in the range [%d, %d), where %d is the number of dimensions in the input.", + axis, -shape.numDimensions(), shape.numDimensions(), shape.numDimensions())); + } + + int dim = Math.floorMod(axis, shape.numDimensions()); + Operand rank = Rank.create(scope, input); + Operand dimOp = Constant.scalarOf(scope, dim); + Operand one = Constant.scalarOf(scope, 1); + Operand lastIndex = Sub.create(scope, rank, one); + Operand swappedInputs = swapAxis(scope, input, dimOp, lastIndex); + Operand output = org.tensorflow.op.nn.Softmax.create(scope, swappedInputs); + return fixOutput(scope, output, shape, dimOp, lastIndex); + } + + /** + * Restores the specified axis by swapping it back from the last axis, then reshapes the output + * to the provided shape. + * + * @param scope The TensorFlow scope + * @param output the output + * @param shape the desired shape + * @param dim the dimension to move + * @param lastIndex the index of the last dimension + * @return the restored output based on the dimension and shape. + */ + private static Operand fixOutput( + Scope scope, Operand output, Shape shape, Operand dim, Operand lastIndex) { + + Operand result = swapAxis(scope, output, dim, lastIndex); + return Reshape.create(scope, result, Constant.tensorOf(scope, shape)); + } + + /** + * Swaps the specified axis with the last axis + * + * @param scope The TensorFlow scope + * @param input the input + * @param dim the dimension to move + * @param lastIndex the index of the last dimension + * @return the input with the specified dimension swapped with the last dimension + */ + private static Operand swapAxis( + Scope scope, Operand input, Operand dim, Operand lastIndex) { + + Operand zero = Constant.scalarOf(scope, 0); + Operand one = Constant.scalarOf(scope, 1); + // Build the permutation [0 .. dim-1, lastIndex, dim+1 .. lastIndex-1, dim], which swaps + // axis dim with the last axis. The scalar indices are reshaped to rank-1 tensors so they + // can be concatenated with the rank-1 ranges. + return Transpose.create( + scope, + input, + Concat.create( + scope, + Arrays.asList( + Range.create(scope, zero, dim, one), + Reshape.create(scope, lastIndex, Constant.vectorOf(scope, new int[] {1})), + Range.create(scope, Add.create(scope, dim, one), lastIndex, one), + Reshape.create(scope, dim, Constant.vectorOf(scope, new int[] {1}))), + zero)); + } +}
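To make the transpose concrete, here is the permutation swapAxis builds for a hypothetical rank-4 input:

    // rank = 4, axis = 1  =>  dim = 1, lastIndex = 3
    // perm = concat([0], [3], [2], [1]) = [0, 3, 2, 1]
    // shape [a, b, c, d] -> [a, d, c, b]: the target axis is now last, ready for
    // tf.nn.softmax; fixOutput applies the same permutation again, restoring [a, b, c, d].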
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java index 65d31e97945..cf2c3a76117 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AddTest.java @@ -15,7 +15,9 @@ import java.util.List; import java.util.Map; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; class AddTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -104,13 +106,13 @@ public void testAdd() { Layer.Options.create().inputShape(Shape.of(4, 5))); Add instance = new Add<>(tf, TFloat64.class); List> resultList = - instance.call( - Arrays.asList( - i1.getOutput(TFloat64.class), - i2.getOutput(TFloat64.class), - i3.getOutput(TFloat64.class)), - TFloat64.class); + instance.call( + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i3.getOutput(TFloat64.class)), + TFloat64.class); Operand result = resultList.get(0);
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java index 66e55a8deda..aec95201b32 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/AverageTest.java @@ -84,7 +84,9 @@ public void testAverage() { Layer.Options.create().inputShape(Shape.of(4, 5))); Average instance = new Average<>(tf, TFloat64.class); List> resultList = - instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); Operand result = resultList.get(0);
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java index fd976b74f51..bb69145182b 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ConcatenateTest.java @@ -15,7 +15,9 @@ import java.util.List; import java.util.Map; -import static org.junit.jupiter.api.Assertions.*; +import static
org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; class ConcatenateTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -92,7 +94,9 @@ public void testConcatenate() { Layer.Options.create().inputShape(Shape.of(4, 5))); Concatenate instance = new Concatenate<>(tf, 1, TFloat64.class); List> resultList = - instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); Operand result = resultList.get(0); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java index 555afc85676..78dfe67ec06 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DenseTest.java @@ -26,7 +26,6 @@ import java.util.Collections; import java.util.List; -import java.util.Random; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -34,7 +33,6 @@ public class DenseTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - private final Random random = new Random(1001L); @Test public void testShape3_2() { @@ -45,7 +43,8 @@ public void testShape3_2() { int units = 3; Dense instance = - new Dense<>(tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + new Dense<>( + tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); float[][] data = { {6.600953f, 4.659476f}, @@ -83,7 +82,8 @@ public void testShape4_2() { int units = 3; Dense instance = - new Dense<>(tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + new Dense<>( + tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); float[][] inputArray = { {6.600953f, 4.659476f}, @@ -126,7 +126,8 @@ public void testShapeN_N_2() { int units = 3; Dense instance = - new Dense<>(tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + new Dense<>( + tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); Shape fullShape = Shape.of(5, 10, 2); float[][][] data = { @@ -280,7 +281,8 @@ public void testShape3_4_5_2() { Shape inputShape = Shape.of(3, 4, 5, 2); Dense instance = - new Dense<>(tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); + new Dense<>( + tf, units, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); assertEquals("Dense", instance.getName()); session.run(tf.init()); @@ -505,8 +507,11 @@ public void testConstraintsNonNeg() { true, null, null, - nonNeg, - nonNeg, + null, + null, + null, + nonNeg::call, + nonNeg::call, 1001L, TFloat32.class, Layer.Options.create().inputShape(inputShape)); @@ -526,11 +531,12 @@ public void testConstraintsNonNeg() { {-0, 4, -0} }; - float[] biasConstraintInput = { -1, 2, 5 }; - float[] biasConstraintExpected = { -0, 2, 5 }; + float[] biasConstraintInput = {-1, 2, 5}; + float[] biasConstraintExpected = {-0, 2, 5}; Operand input = tf.constant(data); + @SuppressWarnings("unused") Operand y = instance.call(input, TFloat32.class); // initialize variables session.run(tf.init()); @@ -542,15 +548,14 @@ public void 
testConstraintsNonNeg() { Variable kernel = instance.getKernel(); Operand varUpdate = instance.assign(kernel, tf.constant(constraintInput)); session.run(varUpdate); - session.evaluate(tf.constant(constraintExpected), kernel); + session.evaluate(tf.constant(constraintExpected), instance.applyConstraint(kernel)); // test bias Variable bias = instance.getBias(); assertEquals(Shape.of(units), bias.shape()); varUpdate = instance.assignAdd(bias, tf.constant(biasConstraintInput)); session.run(varUpdate); - session.evaluate(tf.constant(biasConstraintExpected), bias); - + session.evaluate(tf.constant(biasConstraintExpected), instance.applyConstraint(bias)); } } @@ -565,42 +570,46 @@ public void testConstraintsMinMaxNorm() { MinMaxNorm minMaxNorm = new MinMaxNorm(tf); Dense instance = - new Dense<>( - tf, - "constraintTest", - units, - null, - true, - null, - null, - minMaxNorm, - minMaxNorm, - 1001L, - TFloat32.class, - Layer.Options.create().inputShape(inputShape)); + new Dense<>( + tf, + "constraintTest", + units, + null, + true, + null, + null, + null, + null, + null, + minMaxNorm::call, + minMaxNorm::call, + 1001L, + TFloat32.class, + Layer.Options.create().inputShape(inputShape)); float[][] data = { - {6.600953f, 4.659476f}, - {6.943807f, 2.113826f}, - {4.667166f, 6.931125f} + {6.600953f, 4.659476f}, + {6.943807f, 2.113826f}, + {4.667166f, 6.931125f} }; float[][] constraintInput = { - {1, 0.5f, 2}, - {-2, 0.75f, 0} + {1, 0.5f, 2}, + {-2, 0.75f, 0} }; float[][] constraintExpected = { - { 0.447214f, 0.5f, 1}, - {-0.894427f, 0.75f, 0} + {0.447214f, 0.5f, 1}, + {-0.894427f, 0.75f, 0} }; - float[] biasConstraintInput = { -1, 2, 5 }; - float[] biasConstraintExpected = { -0.182574f,0.365148f, 0.912871f }; + float[] biasConstraintInput = {-1, 2, 5}; + float[] biasConstraintExpected = {-0.182574f, 0.365148f, 0.912871f}; Operand input = tf.constant(data); + @SuppressWarnings("unused") Operand y = instance.call(input, TFloat32.class); - //initialize variables + // initialize variables session.run(tf.init()); List> weights = instance.getWeights(); @@ -610,7 +619,7 @@ public void testConstraintsMinMaxNorm() { Variable kernel = instance.getKernel(); Operand varUpdate = instance.assign(kernel, tf.constant(constraintInput)); session.run(varUpdate); - session.evaluate(tf.constant(constraintExpected), kernel); + session.evaluate(tf.constant(constraintExpected), instance.applyConstraint(kernel)); // test bias Variable bias = instance.getBias(); @@ -618,8 +627,7 @@ public void testConstraintsMinMaxNorm() { varUpdate = instance.assignAdd(bias, tf.constant(biasConstraintInput)); session.run(varUpdate); - session.evaluate(tf.constant(biasConstraintExpected), bias); - + session.evaluate(tf.constant(biasConstraintExpected), instance.applyConstraint(bias)); } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java index 27292222814..c8856dcbba2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/DotTest.java @@ -37,13 +37,23 @@ public void testDot() { Ops tf = session.getTF(); Input i1 = new Input<>( - tf, "l1", TFloat64.class, TFloat64.class, Layer.Options.create().inputShape(Shape.of(4))); + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4))); Input i2 = new Input<>( - tf, "l2", TFloat64.class, TFloat64.class, 
Layer.Options.create().inputShape(Shape.of(4))); + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4))); Dot instance = new Dot<>(tf, 1, TFloat64.class); List> resultList = - instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); Operand result = resultList.get(0); @@ -70,13 +80,23 @@ public void testDotNegativeAxis() { Ops tf = session.getTF(); Input i1 = new Input<>( - tf, "l1", TFloat64.class, TFloat64.class, Layer.Options.create().inputShape(Shape.of(4))); + tf, + "l1", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4))); Input i2 = new Input<>( - tf, "l2", TFloat64.class, TFloat64.class, Layer.Options.create().inputShape(Shape.of(4))); + tf, + "l2", + TFloat64.class, + TFloat64.class, + Layer.Options.create().inputShape(Shape.of(4))); Dot instance = new Dot<>(tf, new int[] {-1, -1}, TFloat64.class); List> resultList = - instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); Operand result = resultList.get(0); @@ -106,7 +126,8 @@ public void testDotComputeOutputShape() { List outputShapes = dot.computeOutputShape(Arrays.asList(Shape.of(4, 5), Shape.of(4, 5))); assertFalse(outputShapes.isEmpty()); - assertArrayEquals(new long[] {4}, outputShapes.get(0).asArray()); + assertArrayEquals(new long[] {4,1}, outputShapes.get(0).asArray()); + } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java index a0530e27e4a..b492d0c7c8d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ELUTest.java @@ -12,32 +12,32 @@ class ELUTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; float[][][] inputArray = { - { - {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, - {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, - {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} - }, - { - {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, - {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, - {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} - } + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} + } }; @Test public void testCallAlpha0() { float[][][] expectedArray = { - { - {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, - {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, - {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} - }, - { - {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, - {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, - {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} - } + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 
8.537634f, 6.403787f, 5.8328476f} + } }; try (TestSession session = TestSession.createTestSession(tfMode)) { @@ -45,8 +45,8 @@ public void testCallAlpha0() { float alpha = 0f; ELU instance = - new ELU<>( - tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + new ELU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); Operand result = instance.call(tf.constant(inputArray), TFloat32.class); @@ -58,16 +58,16 @@ public void testCallAlpha0() { public void testCallAlpha0Point5() { float[][][] expectedArray = { - { - {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, - {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, - {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} - }, - { - {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, - {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, - {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} - } + { + {2.7085743f, 8.254536f, 9.754793f, 1.1027353f}, + {8.698364f, 2.2781835f, 8.608563f, 1.4326588f}, + {0.7584583f, 5.604635f, 7.3599877f, 0.06365667f} + }, + { + {4.8735523f, 9.90222f, 5.390144f, 2.052634f}, + {5.9165273f, 0.9186602f, 0.9137567f, 0.5605333f}, + {2.0804656f, 8.537634f, 6.403787f, 5.8328476f} + } }; try (TestSession session = TestSession.createTestSession(tfMode)) { @@ -75,8 +75,8 @@ public void testCallAlpha0Point5() { float alpha = 0.5f; ELU instance = - new ELU<>( - tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + new ELU<>( + tf, alpha, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); Operand result = instance.call(tf.constant(inputArray), TFloat32.class); @@ -88,16 +88,16 @@ public void testCallAlpha0Point5() { public void testCallAlphaMinus1() { double[][][] expectedArray = { - { - {2.7085743, 8.254536, 9.754793, 1.1027353}, - {8.698364, 2.2781835, 8.608563, 1.4326588}, - {0.7584583, 5.604635, 7.3599877, 0.06365667} - }, - { - {4.8735523, 9.90222, 5.390144, 2.052634}, - {5.9165273, 0.9186602, 0.9137567, 0.5605333}, - {2.0804656, 8.537634, 6.403787, 5.8328476} - } + { + {2.7085743, 8.254536, 9.754793, 1.1027353}, + {8.698364, 2.2781835, 8.608563, 1.4326588}, + {0.7584583, 5.604635, 7.3599877, 0.06365667} + }, + { + {4.8735523, 9.90222, 5.390144, 2.052634}, + {5.9165273, 0.9186602, 0.9137567, 0.5605333}, + {2.0804656, 8.537634, 6.403787, 5.8328476} + } }; try (TestSession session = TestSession.createTestSession(tfMode)) { @@ -105,15 +105,13 @@ public void testCallAlphaMinus1() { float alpha = -1.f; ELU instance = - new ELU<>( - tf, alpha, TFloat64.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + new ELU<>( + tf, alpha, TFloat64.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); Operand result = - instance.call( - tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); session.evaluate(tf.constant(expectedArray), result); } } } - diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java index be74d8c736b..c031b8c3387 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianDropoutTest.java @@ -47,7 +47,6 @@ public void testShape3_2_3() { result = instance.call(input, true, TFloat64.class); assertEquals(expectedShape, result.shape()); 
session.evaluate(expected, result); - } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java index eb703c94ae2..ab7cdde0906 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/GaussianNoiseTest.java @@ -40,15 +40,15 @@ public void testShape3_2_3() { new double[][][] { {{3.679269, 1.977210, 8.807560}, {9.322395, 6.767639, 7.653679}}, {{6.874095, 3.046032, 4.328507}, {3.396144, 3.414768, 1.099349}}, - {{6.208934, 6.194471, 3.045125}, {10.126389, 7.881398, 9.125002}} + {{6.208934, 6.194471, 3.045125}, {10.126389, 7.881398, 9.125002}} }); // second pass, trainable is true, so there should be noise applied result = instance.call(input, true, TFloat64.class); assertEquals(expectedShape, result.shape()); // cannot evaluate more than once, else it doesn't match expected - // because of random number generation. - //session.print(result); + // because of random number generation. + // session.print(result); session.evaluate(expected, result); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java index 587e24da029..185262ee5dd 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/InputTest.java @@ -27,7 +27,9 @@ import java.util.List; import java.util.Map; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; public class InputTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java index 689ab1163e9..738d5daf5e0 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LambdaTest.java @@ -21,7 +21,7 @@ public void testCallLambda() { Shape shape = Shape.of(3, 2); Lambda instance = new Lambda<>(tf, TFloat32.class); - instance.setLamda((t, y) -> t.math.mul(cast(t, t.constant(2), y.type()), y)); + instance.setLambda((t, y) -> t.math.mul(cast(t, t.constant(2), y.type()), y)); double[][] array = { {0.41448207, 0.71509451}, {0.21307868, 0.76890945}, {0.37533432, 0.7761148} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java index f5baa2f9574..363894827cb 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/LeakyReLUTest.java @@ -7,9 +7,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.family.TType; - -import static org.tensorflow.framework.utils.CastHelper.cast; class LeakyReLUTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -83,7 +80,7 @@ public void 
testCallAlpha0Point5() { Operand result = instance.call(tf.constant(inputArray), TFloat32.class); - session.evaluate(tf.constant(expectedArray), result); + session.evaluate(tf.constant(expectedArray), result); } } @@ -111,9 +108,9 @@ public void testCallAlphaMinus1() { new LeakyReLU<>( tf, alpha, TFloat64.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); - Operand result = instance.call( tf.constant(inputArray) , TFloat64.class); + Operand result = instance.call(tf.constant(inputArray), TFloat64.class); - session.evaluate(tf.constant(expectedArray), result); + session.evaluate(tf.constant(expectedArray), result); } } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java index ed35ab8408d..4e7ed046c34 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MaximumTest.java @@ -83,7 +83,9 @@ public void testAverage() { Layer.Options.create().inputShape(Shape.of(4, 5))); Maximum instance = new Maximum<>(tf, TFloat64.class); List> resultList = - instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); Operand result = resultList.get(0); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java index c5b7673a6d7..c9132472d41 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MinimumTest.java @@ -83,7 +83,9 @@ public void testAverage() { Layer.Options.create().inputShape(Shape.of(4, 5))); Minimum instance = new Minimum<>(tf, TFloat64.class); List> resultList = - instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); Operand result = resultList.get(0); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java index 5e3e802b371..33bc5775f96 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/MultiplyTest.java @@ -104,10 +104,11 @@ public void testAdd() { Multiply instance = new Multiply<>(tf, TFloat64.class); List> resultList = instance.call( - Arrays.asList( - i1.getOutput(TFloat64.class), - i2.getOutput(TFloat64.class), - i3.getOutput(TFloat64.class)), TFloat64.class); + Arrays.asList( + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i3.getOutput(TFloat64.class)), + TFloat64.class); Operand result = resultList.get(0); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java index 2561eb2d2cd..d2ad1988124 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReLUTest.java @@ -81,7 +81,10 @@ public 
void testCallNegativeSlope() { float negativeSlope = 0.2f; ReLU instance = new ReLU<>( - tf, negativeSlope, TFloat32.class, Layer.Options.create().inputShape(Shape.of(2, 3, 4))); + tf, + negativeSlope, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(2, 3, 4))); Operand result = instance.call(tf.constant(inputArray), TFloat32.class); @@ -119,8 +122,7 @@ public void testCallMaxValue6() { Layer.Options.create().inputShape(Shape.of(2, 3, 4))); Operand result = - instance.call( - tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); session.evaluate(tf.constant(expectedArray), result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java index aee2e9a7dd4..77d75999a3c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/RepeatVectorTest.java @@ -8,7 +8,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; - import static org.junit.jupiter.api.Assertions.assertEquals; class RepeatVectorTest { @@ -34,11 +33,8 @@ public void testCall3_2() { {{0.37533432, 0.7761148}, {0.37533432, 0.7761148}, {0.37533432, 0.7761148}} }; - - Operand result = - instance.call( - tf.dtypes.cast(tf.constant(array), TFloat64.class), TFloat64.class); + instance.call(tf.dtypes.cast(tf.constant(array), TFloat64.class), TFloat64.class); assertEquals(expectedShape, result.shape()); session.evaluate(tf.constant(expected), result); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java index 66a5f75d8f3..a31c10973e9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/ReshapeTest.java @@ -8,114 +8,119 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; class ReshapeTest { - private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; - - float[][][] inputArray = { - { - {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, - {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, - {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} - }, - { - {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, - {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, - {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} - } - }; - - float[][][] inputArrayNN2 = { - { - {2.70857435f, 8.25453567f}, {9.75479311f, 1.10273526f}, - {8.69836437f, 2.27818352f}, {8.60856328f, 1.43265882f}, - {0.75845834f, 5.60463474f}, {7.35998787f, 0.06365667f} - }, - { - {4.87355239f, 9.90221978f}, {5.39014402f, 2.05263398f}, - {5.91652733f, 0.9186602f}, {0.91375672f, 0.56053326f}, - {2.08046551f, 8.53763374f}, {6.40378721f, 5.83284758f} - } - }; - - @Test - public void testCall43() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Shape targetShape = Shape.of(4,3); - long batchSize = 2; - Reshape instance = new Reshape<>(tf, targetShape, - TFloat32.class, Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4)) ); - - 
Operand result = - instance.call( - tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); - - assertArrayEquals(targetShape.prepend(batchSize).asArray(), result.shape().asArray()); - - - } + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + float[][][] inputArray = { + { + {2.70857435f, 8.25453567f, 9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f, 8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f, 7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f, 5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f, 0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f, 6.40378721f, 5.83284758f} } - - @Test - public void testCallUnknown_1() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Shape targetShape = Shape.of(Shape.UNKNOWN_SIZE,1); - long batchSize = 2; - Reshape instance = new Reshape<>(tf, targetShape, - TFloat32.class, Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4)) ); - - Shape expectedShape = Shape.of(batchSize, 12, 1); - - Operand result = - instance.call( - tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); - - assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); - - - } + }; + + float[][][] inputArrayNN2 = { + { + {2.70857435f, 8.25453567f}, {9.75479311f, 1.10273526f}, + {8.69836437f, 2.27818352f}, {8.60856328f, 1.43265882f}, + {0.75845834f, 5.60463474f}, {7.35998787f, 0.06365667f} + }, + { + {4.87355239f, 9.90221978f}, {5.39014402f, 2.05263398f}, + {5.91652733f, 0.9186602f}, {0.91375672f, 0.56053326f}, + {2.08046551f, 8.53763374f}, {6.40378721f, 5.83284758f} } - - @Test - public void testCall1_Unknown_1() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Shape targetShape = Shape.of(1, Shape.UNKNOWN_SIZE); - long batchSize = 2; - Reshape instance = new Reshape<>(tf, targetShape, - TFloat32.class, Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4)) ); - - Operand result = - instance.call( - tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); - - Shape expectedShape = Shape.of(batchSize, 1, 12); - assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); - - - } + }; + + @Test + public void testCall43() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(4, 3); + long batchSize = 2; + Reshape instance = + new Reshape<>( + tf, + targetShape, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4))); + + Operand result = + instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + assertArrayEquals(targetShape.prepend(batchSize).asArray(), result.shape().asArray()); } - - @Test - public void testCallUnknownUnknown2() { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Shape targetShape = Shape.of(Shape.UNKNOWN_SIZE,1); - long batchSize = 2; - Reshape instance = new Reshape<>(tf, targetShape, - TFloat32.class, Layer.Options.create().inputShape(Shape.of(Shape.UNKNOWN_SIZE, Shape.UNKNOWN_SIZE, 2)) ); - - Operand result = - instance.call( - tf.dtypes.cast(tf.constant(inputArrayNN2), TFloat64.class), TFloat64.class); - - Shape expectedShape = Shape.of(batchSize, 12, 1); - assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); - - - } + } + + @Test + public void testCallUnknown_1() { + try (TestSession session = 
TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(Shape.UNKNOWN_SIZE, 1); + long batchSize = 2; + Reshape instance = + new Reshape<>( + tf, + targetShape, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4))); + + Shape expectedShape = Shape.of(batchSize, 12, 1); + + Operand result = + instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); + } + } + + @Test + public void testCall1_Unknown_1() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(1, Shape.UNKNOWN_SIZE); + long batchSize = 2; + Reshape instance = + new Reshape<>( + tf, + targetShape, + TFloat32.class, + Layer.Options.create().inputShape(Shape.of(batchSize, 3, 4))); + + Operand result = + instance.call(tf.dtypes.cast(tf.constant(inputArray), TFloat64.class), TFloat64.class); + + Shape expectedShape = Shape.of(batchSize, 1, 12); + assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); + } + } + + @Test + public void testCallUnknownUnknown2() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape targetShape = Shape.of(Shape.UNKNOWN_SIZE, 1); + long batchSize = 2; + Reshape instance = + new Reshape<>( + tf, + targetShape, + TFloat32.class, + Layer.Options.create() + .inputShape(Shape.of(Shape.UNKNOWN_SIZE, Shape.UNKNOWN_SIZE, 2))); + + Operand result = + instance.call(tf.dtypes.cast(tf.constant(inputArrayNN2), TFloat64.class), TFloat64.class); + + Shape expectedShape = Shape.of(batchSize, 12, 1); + assertArrayEquals(expectedShape.asArray(), result.shape().asArray()); } -} \ No newline at end of file + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java index 7057951ad84..2a20e999907 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/SubtractTest.java @@ -15,7 +15,9 @@ import java.util.List; import java.util.Map; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; class SubtractTest { private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; @@ -84,7 +86,9 @@ public void testSubtract() { Layer.Options.create().inputShape(Shape.of(4, 5))); Subtract instance = new Subtract<>(tf, TFloat64.class); List> resultList = - instance.call(Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), TFloat64.class); + instance.call( + Arrays.asList(i1.getOutput(TFloat64.class), i2.getOutput(TFloat64.class)), + TFloat64.class); Operand result = resultList.get(0); @@ -132,9 +136,10 @@ public void testSubtractInvalidInputsLength() { List> resultList = instance.call( Arrays.asList( - i1.getOutput(TFloat64.class), - i2.getOutput(TFloat64.class), - i2.getOutput(TFloat64.class)), TFloat64.class); + i1.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class), + i2.getOutput(TFloat64.class)), + TFloat64.class); } }); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java index daee9a859c7..1a8336e33f5 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/InputSpecTest.java @@ -5,7 +5,9 @@ import java.util.Arrays; -import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; class InputSpecTest { @@ -28,45 +30,39 @@ public void testAxis() { .shape(Shape.of(1, Shape.UNKNOWN_SIZE, 2, 3)) .axesMap(4, 5L)); }); - } @Test public void testDefinedShape() { - Shape expected = Shape.of(1, Shape.UNKNOWN_SIZE, 2, 3); - InputSpec instance = - new InputSpec(InputSpec.Options.create().shape(expected)); - assertArrayEquals(expected.asArray(), instance.toShape().asArray()); + Shape expected = Shape.of(1, Shape.UNKNOWN_SIZE, 2, 3); + InputSpec instance = new InputSpec(InputSpec.Options.create().shape(expected)); + assertArrayEquals(expected.asArray(), instance.toShape().asArray()); } - @Test - public void testDefinedRank() { - InputSpec instance = - new InputSpec(InputSpec.Options.create().rank(5)); - long[] dims = new long[5]; - Arrays.fill(dims, Shape.UNKNOWN_SIZE); - assertArrayEquals(dims, instance.toShape().asArray()); - - instance = new InputSpec(InputSpec.Options.create().rank(0)); - dims = new long[0]; - assertArrayEquals(dims, instance.toShape().asArray()); + @Test + public void testDefinedRank() { + InputSpec instance = new InputSpec(InputSpec.Options.create().rank(5)); + long[] dims = new long[5]; + Arrays.fill(dims, Shape.UNKNOWN_SIZE); + assertArrayEquals(dims, instance.toShape().asArray()); - instance = new InputSpec(InputSpec.Options.create().rank(3).axesMap(1,3L).axesMap(-1,2L)); - dims = new long[] {Shape.UNKNOWN_SIZE, 3, 2}; - assertArrayEquals(dims, instance.toShape().asArray()); - } + instance = new InputSpec(InputSpec.Options.create().rank(0)); + dims = new long[0]; + assertArrayEquals(dims, instance.toShape().asArray()); - @Test - public void testUndefinedShapes() { - InputSpec instance = - new InputSpec(InputSpec.Options.create().maxRank(5)); - Shape genShaped = instance.toShape(); - assertTrue(genShaped.isUnknown()); + instance = new InputSpec(InputSpec.Options.create().rank(3).axesMap(1, 3L).axesMap(-1, 2L)); + dims = new long[] {Shape.UNKNOWN_SIZE, 3, 2}; + assertArrayEquals(dims, instance.toShape().asArray()); + } - instance = - new InputSpec(InputSpec.Options.create().minRank(5).maxRank(5)); - genShaped = instance.toShape(); - assertTrue(genShaped.isUnknown()); + @Test + public void testUndefinedShapes() { + InputSpec instance = new InputSpec(InputSpec.Options.create().maxRank(5)); + Shape genShaped = instance.toShape(); + assertTrue(genShaped.isUnknown()); - } + instance = new InputSpec(InputSpec.Options.create().minRank(5).maxRank(5)); + genShaped = instance.toShape(); + assertTrue(genShaped.isUnknown()); + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/TensorDotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/TensorDotTest.java deleted file mode 100644 index be5d2495c65..00000000000 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/layers/impl/TensorDotTest.java +++ /dev/null @@ -1,186 +0,0 @@ -package org.tensorflow.framework.layers.impl; - -import 
org.junit.jupiter.api.Test; -import org.tensorflow.Operand; -import org.tensorflow.Tensor; -import org.tensorflow.exceptions.TFInvalidArgumentException; -import org.tensorflow.framework.op.math.TensorDot; -import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.ndarray.Shape; -import org.tensorflow.op.Ops; -import org.tensorflow.types.TFloat32; -import org.tensorflow.types.TInt32; -import org.tensorflow.types.family.TType; - -import java.util.HashMap; -import java.util.Map; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - -class TensorDotTest { - private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - - float[][] aArray = { - {1, 1, 1}, - {1, 1, 1}, - {1, 1, 1}, - }; - - float[][][] bArray = {{{2, 3, 1}}}; - - @Test - public void testInvalidShape() { - for (TestSession.Mode tfMode : tfModes) - assertThrows( - TFInvalidArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - float[][] a = new float[][] {{1, 2}, {3, 4}}; - float[][] b = new float[][] {{1, 2}, {3, 4}, {5, 6}}; - - Operand aOp = tf.constant(a); - Operand bOp = tf.constant(b); - - TensorDot.tensordot(tf.scope(), aOp, bOp, new int[] {1, 0}); - } - }); - } - - @Test - public void testInvalidDynamicShape() { - assertThrows( - TFInvalidArgumentException.class, - () -> { - try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) { - Ops tf = session.getTF(); - - Operand aPH = tf.placeholder(TFloat32.class); - Operand bPH = tf.placeholder(TFloat32.class); - Operand axesPH = tf.placeholder(TInt32.class); - - float[][] a = new float[][] {{1, 2}, {3, 4}}; - float[][] b = new float[][] {{1, 2}, {3, 4}, {5, 6}}; - - Operand aOp = tf.constant(a); - Operand bOp = tf.constant(b); - Operand axesOp = tf.constant(new int[] {1, 0}); - - Operand output = TensorDot.tensordot(tf.scope(), aPH, bPH, axesPH); - - try (TFloat32 aTensor = - (TFloat32) session.getGraphSession().runner().fetch(aOp).run().get(0); - TFloat32 bTensor = - (TFloat32) session.getGraphSession().runner().fetch(bOp).run().get(0); - TInt32 axesTensor = - (TInt32) session.getGraphSession().runner().fetch(axesOp).run().get(0)) { - Map, Tensor> feedMap = new HashMap<>(); - feedMap.put(aPH, aTensor); - feedMap.put(bPH, bTensor); - feedMap.put(axesPH, axesTensor); - session.run(output, feedMap); - } - } - }); - } - - @Test - public void testInvalidAxes() { - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - float[][] a = new float[][] {{1, 2}, {3, 4}}; - float[][] b = new float[][] {{1, 2}, {3, 4}}; - - Operand aOp = tf.constant(a); - Operand bOp = tf.constant(b); - assertThrows( - IllegalArgumentException.class, () -> TensorDot.tensordot(tf.scope(), aOp, bOp, -1)); - assertThrows( - IllegalArgumentException.class, () -> TensorDot.tensordot(tf.scope(), aOp, bOp, 3)); - assertThrows( - IllegalArgumentException.class, - () -> TensorDot.tensordot(tf.scope(), aOp, bOp, new int[] {1, 0, 1})); - assertThrows( - Exception.class, () -> TensorDot.tensordot(tf.scope(), aOp, bOp, new int[] {0, 7})); - } - } - - @Test - public void testValidAxis1() { - Shape expectedShape = Shape.of(3, 1, 1); - - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Operand expected = tf.constant(new 
float[][][] {{{6}}, {{6}}, {{6}}}); - Operand a = tf.constant(aArray); - Operand b = tf.constant(bArray); - Operand result = TensorDot.tensordot(tf.scope(), a, b, new int[] {1, 2}); - assertEquals(expectedShape, result.shape()); - session.evaluate(expected, result); - } - } - - @Test - public void testValidAxis2() { - - Shape expectedShape = Shape.of(3, 1, 1); - - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Operand expected = tf.constant(new float[][][] {{{6}}, {{6}}, {{6}}}); - Operand a = tf.constant(aArray); - Operand b = tf.constant(bArray); - Operand result = TensorDot.tensordot(tf.scope(), a, b, new int[][] {{1}, {2}}); - assertEquals(expectedShape, result.shape()); - session.evaluate(expected, result); - } - } - - @Test - public void testValidAxis3() { - Shape expectedShape = Shape.of(3, 3, 1, 1, 3); - for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(tfMode)) { - Ops tf = session.getTF(); - Operand expected = - tf.constant( - new float[][][][][] { - {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, - {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, - {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}} - }); - - Operand a = tf.constant(aArray); - Operand b = tf.constant(bArray); - Operand result = TensorDot.tensordot(tf.scope(), a, b, 0); - assertEquals(expectedShape, result.shape()); - session.evaluate(expected, result); - } - } - - @Test - public void testValidAxis4() { - Shape expectedShape = Shape.of(3, 3, 1, 1, 3); - // for (TestSession.Mode tfMode : tfModes) - try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) { - Ops tf = session.getTF(); - Operand expected = - tf.constant( - new float[][][][][] { - {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, - {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}}, - {{{{2, 3, 1}}}, {{{2, 3, 1}}}, {{{2, 3, 1}}}} - }); - - Operand a = tf.constant(aArray); - Operand b = tf.constant(bArray); - Operand result = TensorDot.tensordot(tf.scope(), a, b, new int[][] {{}, {}}); - assertEquals(expectedShape, result.shape()); - session.evaluate(expected, result); - } - } -} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java index f2c297ce032..76d86a95e85 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/LinalgOpsTest.java @@ -53,7 +53,7 @@ public void test2D() { expected64 = tf.constant( new double[][] {{154.01892}, {231.81863}, {166.91096}, {126.92895}, {83.58413}}); - session.setEpsilon(1e-4f); + TestSession.setEpsilon(1e-4f); session.evaluate(expected64, ans64); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java index 7dee866abf2..0c4b6ab9a51 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/SetOpsTest.java @@ -34,9 +34,10 @@ public void testSetIntersectionMultirow2() { int[][] expected = new int[][] {{1, 9}, {0, 0}}; Shape expectedShape = Shape.of(2, 2); for (Class type : types) { + // Use raw type because of changing type in for loop Operand aa = cast(tf, a, type); Operand bb = cast(tf, b, type); - Operand intersection = 
fops.sets.intersection(aa, bb); + Operand intersection = fops.sets.intersection(aa, bb); session.evaluate(cast(tf, tf.constant(expected), type), intersection); session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java index 7884308c9fb..6dc43ebfe64 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/EagerTestSession.java @@ -14,22 +14,21 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; +import org.tensorflow.EagerSession; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.IntNdArray; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; +import java.util.Map; import java.util.function.Predicate; -import static org.junit.jupiter.api.Assertions.*; - /** Eager Mode Test Session */ public class EagerTestSession extends TestSession { @@ -83,676 +82,70 @@ public EagerSession getEagerSession() { /** {@inheritDoc} */ @Override - public void evaluate(double expected, Operand input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } else if (inputType == TInt64.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } else if (inputType == TUint8.class) { - Operand o = (Operand) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). 
%x\n", index.getAndIncrement(), f.getByte())); - } - index.set(0); - o.asTensor().scalars().forEach(f -> assertEquals((long) expected, f.getByte())); - } + public void evaluate( + double expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor(), input.type()); } /** {@inheritDoc} */ @Override - public void evaluate(Number[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%x). %d\n", index.getAndIncrement(), f.getByte())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].byteValue(), f.getByte())); - } + public void evaluate( + Number[] expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor(), input.type()); } /** {@inheritDoc} */ @Override - public void evaluate(FloatNdArray expected, Output input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicLong index = new AtomicLong(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals(expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). 
%f\n", index.getAndIncrement(), f.getDouble())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> - assertEquals(expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - index.set(0); - for (IntNdArray f : o.asTensor().scalars()) { - assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt()); - } - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } - index.set(0); - o.asTensor() - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); - } + public void evaluate( + FloatNdArray expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor(), input.type()); } /** {@inheritDoc} */ @Override - public void evaluateString(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %s\n", predicate.test(input.asTensor().getObject()), input.asTensor().getObject()); - } else { - input - .asTensor() - .scalars() - .forEachIndexed( - (idx, s) -> - System.out.printf( - "%d). %b <==> %s\n", - index.getAndIncrement(), predicate.test(s.getObject()), s.getObject())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(input.asTensor().getObject())); - } else { - input.asTensor().scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); - } + public void evaluateString( + Operand input, + Predicate predicate, + Map, Tensor> feedMap) { + + super.evaluateString(input.asTensor(), predicate); } /** {@inheritDoc} */ @Override - public void evaluate(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getFloat())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); - } - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). 
%b <==> %f\n", predicate.test(o.asTensor().getDouble()), o.asTensor().getDouble()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getDouble())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getDouble()))); - } - } else if (inputType == TFloat16.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", predicate.test(o.asTensor().getFloat()), o.asTensor().getFloat()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getFloat())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getFloat()))); - } - } else if (inputType == TInt32.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(o.asTensor().getInt()), o.asTensor().getInt()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getInt())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getInt()))); - } - } else if (inputType == TInt64.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(o.asTensor().getLong()), o.asTensor().getLong()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getLong())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getLong()))); - } - } else if (inputType == TUint8.class) { - Output o = (Output) input; - if (debug) { - if (isScalar) { - System.out.printf( - "0). %b <==> %x\n", predicate.test(o.asTensor().getByte()), o.asTensor().getByte()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %x\n", - index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); - } - } - index.set(0); - if (isScalar) { - assertTrue(predicate.test(o.asTensor().getByte())); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(o.asTensor().getByte()))); - } - } else { - fail("Unexpected Class: " + inputType); - } + public void evaluate( + Operand input, + Predicate predicate, + Map, Tensor> feedMap) { + super.evaluate(input.asTensor(), input.type(), predicate); } /** {@inheritDoc} */ @Override - public void evaluate(String[] expected, Output input) { - int size = input.shape().size() == 0 ? 
1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { - input - .asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } - index.set(0); - input - .asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); + public void evaluate( + String[] expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor()); } /** {@inheritDoc} */ @Override - public void evaluate(Boolean[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected size (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { - input - .asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); - } - index.set(0); - input - .asTensor() - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean())); + public void evaluate( + Boolean[] expected, Operand input, Map, Tensor> feedMap) { + super.evaluate(expected, input.asTensor()); } /** {@inheritDoc} */ @Override - public void evaluate(Output expected, Output input) { - assert input.shape().equals(expected.shape()) - : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); - Class inputType = input.asOutput().type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.asTensor().getFloat(), o.asTensor().getFloat()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), x.asTensor().getFloat(idx), f.getFloat())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getFloat(), o.asTensor().getFloat(), epsilon); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(x.asTensor().getFloat(idx), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %f <==> %f\n", x.asTensor().getDouble(), o.asTensor().getDouble()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), x.asTensor().getDouble(idx), f.getDouble())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getDouble(), o.asTensor().getDouble(), epsilon); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(x.asTensor().getDouble(idx), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.asTensor().getInt(), o.asTensor().getInt()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%d <==> %d\n", - index.getAndIncrement(), x.asTensor().getInt(idx), f.getInt())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getInt(), o.asTensor().getInt()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getInt(idx), f.getInt())); - } - } else if (inputType == TInt64.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %d <==> %d\n", x.asTensor().getLong(), o.asTensor().getLong()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), x.asTensor().getLong(idx), f.getLong())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getLong(), o.asTensor().getLong()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getLong(idx), f.getLong())); - } - } else if (inputType == TUint8.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %x <==> %x\n", x.asTensor().getByte(), o.asTensor().getByte()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %x <==> %x\n", - index.getAndIncrement(), x.asTensor().getByte(idx), f.getByte())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getByte(), o.asTensor().getByte()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getByte(idx), f.getByte())); - } - } else if (inputType == TString.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %s <==> %s\n", x.asTensor().getObject(), o.asTensor().getObject()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %s <==> %s\n", - index.getAndIncrement(), x.asTensor().getObject(idx), f.getObject())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getObject(), o.asTensor().getObject()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getObject(idx), f.getObject())); - } - } else if (inputType == TBool.class) { - Output x = (Output) expected; - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - if (debug) { - if (isScalar) { - System.out.printf("0). %b <==> %b\n", x.asTensor().getBoolean(), o.asTensor().getBoolean()); - } else { - o.asTensor() - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%b <==> %b\n", - index.getAndIncrement(), x.asTensor().getBoolean(idx), f.getBoolean())); - } - } - index.set(0); - if (isScalar) { - assertEquals(x.asTensor().getBoolean(), o.asTensor().getBoolean()); - } else { - o.asTensor() - .scalars() - .forEachIndexed((idx, f) -> assertEquals(x.asTensor().getBoolean(idx), f.getBoolean())); - } - } + public void evaluate( + Operand expected, Operand input, Map, Tensor> feedMap) { + + super.evaluate(expected.asTensor(), input.asTensor(), input.type()); } /** {@inheritDoc} */ @Override - public void print(PrintWriter writer, Output input) { - Class inputType = input.asOutput().type(); - if (inputType == TFloat32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } else if (inputType == TFloat64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } else if (inputType == TInt32.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } else if (inputType == TInt64.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } else if (inputType == TUint8.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } else if (inputType == TString.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } else if (inputType == TBool.class) { - Output o = (Output) input; - AtomicInteger index = new AtomicInteger(); - o.asTensor() - .scalars() - .forEach(f -> System.out.printf("%d). 
%b\n", index.getAndIncrement(), f.getBoolean())); - } else { - writer.println("Unexpected Class: " + inputType); - } - writer.flush(); + public void print( + PrintWriter writer, Operand input, Map, Tensor> feedMap) { + super.print(writer, input.asTensor(), input.type()); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 43c0642939e..0583913a6c6 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -14,34 +14,31 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.*; +import org.tensorflow.EagerSession; +import org.tensorflow.Graph; +import org.tensorflow.Operand; +import org.tensorflow.Session; +import org.tensorflow.Tensor; import org.tensorflow.ndarray.FloatNdArray; -import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; -import org.tensorflow.types.*; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TString; import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; import java.io.PrintWriter; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; +import java.util.Map; import java.util.function.Predicate; -import static org.junit.jupiter.api.Assertions.*; - -/** - * Graph Mode Test Session - */ +/** Graph Mode Test Session */ public class GraphTestSession extends TestSession { private final Graph graph; private final Session session; private final Ops tf; - /** - * Create a Graph mode test session. - */ + /** Create a Graph mode test session. */ public GraphTestSession() { graph = new Graph(); session = new Session(graph); @@ -49,16 +46,25 @@ public GraphTestSession() { } /** - * {@inheritDoc} + * Create a Graph mode test session. 
+ * + * @param graph the graph + * @param session the session + * @param tf the TensorFlow Ops */ + public GraphTestSession(Graph graph, Session session, Ops tf) { + this.graph = graph; + this.session = session; + this.tf = tf; + } + + /** {@inheritDoc} */ @Override public Ops getTF() { return tf; } - /** - * Get the Graph object that is represented by this Test Session - */ + /** Get the Graph object that is represented by this Test Session */ public Graph getGraph() { return graph; } @@ -72,1051 +78,154 @@ public Session getSession() { return session; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void close() { session.close(); graph.close(); } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public boolean isEager() { return false; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public Session getGraphSession() { return this.session; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public EagerSession getEagerSession() { return null; } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override public void initialize() { graph.initializers().forEach(initializer -> session.runner().addTarget(initializer).run()); } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void run(Op op) { - session.run(op); + public void run(Op op, Map, Tensor> feedMap) { + createRunner(op, feedMap).run(); } /** - * {@inheritDoc} + * Create a runner for the Operation + * + * @param feedMap the dictionary of values to use for the runner's feed operations. Required when + * placeholders are used. + * @return the runner */ - @Override - public void evaluate(double expected, Operand input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((float) expected, f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals(expected, f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((int) expected, f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getLong())); - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((long) expected, f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result.scalars().forEach(f -> assertEquals((long) expected, f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); - } + public final Session.Runner createRunner(Map, Tensor> feedMap) { + return createRunner(null, feedMap); } /** - * {@inheritDoc} + * Create a runner for the Operation + * + * @param op the operation + * @param feedMap the dictionary of values to use for the runner's feed operations. Required when + * placeholders are used. + * @return the runner */ + public final Session.Runner createRunner(Op op, Map, Tensor> feedMap) { + Session.Runner runner = session.runner(); + if (op != null) runner.addTarget(op.op()); + if (feedMap != null) feedMap.forEach((operand, tensor) -> runner.feed(operand, tensor)); + + return runner; + } + + /** {@inheritDoc} */ @Override - public void evaluate(Number[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - if (size != Shape.UNKNOWN_SIZE) { - assertEquals( - expected.length, - size, - () -> - String.format("expected length (%d) != to input length (%d)", expected.length, size)); - } - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].floatValue(), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected[index.getAndIncrement()].doubleValue(), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getInt())); - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].intValue(), f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()].longValue(), f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); + public void evaluate( + double expected, Operand input, Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor, input.type()); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(FloatNdArray expected, Output input) { - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicLong index = new AtomicLong(); - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getFloat())); - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected.getFloat(index.getAndIncrement()), f.getFloat(), epsilon)); - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> - assertEquals( - expected.getFloat(index.getAndIncrement()), f.getDouble(), epsilon)); - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). 
%d\n", index.getAndIncrement(), f.getInt())); - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((int) expected.getFloat(index.getAndIncrement()), f.getInt())); - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getLong())); - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %d\n", index.getAndIncrement(), f.getByte())); - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach( - f -> assertEquals((long) expected.getFloat(index.getAndIncrement()), f.getByte())); - } - } else { - fail("Unexpected type class: " + inputType); + public void evaluate( + Number[] expected, Operand input, Map, Tensor> feedMap) { + + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor, input.type()); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(String[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - if (size != Shape.UNKNOWN_SIZE) { - assertEquals( - expected.length, - size, - () -> - String.format("expected length (%d) != to input length (%d)", expected.length, size)); - } - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). %s\n", index.getAndIncrement(), f.getObject())); - } - } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); + public void evaluate( + FloatNdArray expected, Operand input, Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor, input.type()); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(Boolean[] expected, Output input) { - int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> System.out.printf("%d). 
%b\n", index.getAndIncrement(), f.getObject())); - } - } - index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - result - .scalars() - .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); + public void evaluateString( + Operand input, + Predicate predicate, + Map, Tensor> feedMap) { + try (TString tensor = (TString) createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluateString(tensor, predicate); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(Output expected, Output input) { - assert input.shape().equals(expected.shape()) - : String.format( - "expected shape (%s) != to input shape (%s)", - expected.shape().toString(), input.shape().toString()); - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - if (!inputType.equals(expected.type())) { - throw new IllegalArgumentException( - String.format( - "Both data type must be equal, inout = %s, expected = %s", - inputType, expected.dataType())); - } - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getFloat(idx), - f.getFloat())); - } - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat32 expectedResult = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); - } - } - } else if (inputType == TFloat64.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getDouble(), result.getDouble()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getDouble(idx), - f.getDouble())); - } - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat64 expectedResult = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getDouble(), result.getDouble(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getDouble(idx), f.getDouble(), epsilon)); - } - } - } else if (inputType == TFloat16.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %f <==> %f\n", expectedResult.getFloat(), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %f <==> %f\n", - index.getAndIncrement(), - finalExpected.asTensor().getFloat(idx), - f.getFloat())); - } - } - } - index.set(0); - try (TFloat16 result = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0); - TFloat16 expectedResult = - (TFloat16)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getFloat(), result.getFloat(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getFloat(idx), f.getFloat(), epsilon)); - } - } - } else if (inputType == TInt32.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); - TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getInt(), result.getInt()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), finalExpected.asTensor().getInt(idx), f.getInt())); - } - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0); - TInt32 expectedResult = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getInt(), result.getInt(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getInt(idx), f.getInt(), epsilon)); - } - } - } else if (inputType == TInt64.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); - TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getLong(), result.getLong()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%d <==> %d\n", - index.getAndIncrement(), - finalExpected.asTensor().getLong(idx), - f.getLong())); - } - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0); - TInt64 expectedResult = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getLong(), result.getLong(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getLong(idx), f.getLong(), epsilon)); - } - } - } else if (inputType == TUint8.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); - TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %d <==> %d\n", expectedResult.getByte(), result.getByte()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %d <==> %d\n", - index.getAndIncrement(), - finalExpected.asTensor().getByte(idx), - f.getByte())); - } - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0); - TUint8 expectedResult = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getByte(), result.getByte(), epsilon); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - assertEquals(expectedResult.getByte(idx), f.getByte(), epsilon)); - } - } - } else if (inputType == TBool.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); - TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %b\n", expectedResult.getBoolean(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %b\n", - index.getAndIncrement(), - finalExpected.asTensor().getBoolean(idx), - f.getBoolean())); - } - } - } - index.set(0); - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0); - TBool expectedResult = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getBoolean(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getBoolean(idx), f.getBoolean())); - } - } - } else if (inputType == TString.class) { - final Output finalExpected = (Output) expected; - if (debug) { - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); - TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %s <==> %s\n", expectedResult.getObject(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%s <==> %s\n", - index.getAndIncrement(), - finalExpected.asTensor().getObject(idx), - f.getObject())); - } - } - } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0); - TString expectedResult = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertEquals(expectedResult.getObject(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> assertEquals(expectedResult.getObject(idx), f.getObject())); - } - } - } else { - fail("Unexpected type class: " + inputType); + public void evaluate( + Operand input, + Predicate predicate, + Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(tensor, input.type(), predicate); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluateString(Output input, Predicate predicate) { - boolean isScalar = input.shape().equals(Shape.scalar()); - AtomicInteger index = new AtomicInteger(); - if (debug) { - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %s\n", - predicate.test(result.getObject()), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %s\n", - index.getAndIncrement(), predicate.test(f.getObject()), f.getObject())); - } - } - } - index.set(0); - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getObject())); - } else { - result - .scalars() - .forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); - } + public void evaluate( + String[] expected, Operand input, Map, Tensor> feedMap) { + try (TString tensor = (TString) createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void evaluate(Output input, Predicate predicate) { - AtomicInteger index = new AtomicInteger(); - Class inputType = input.type(); - boolean isScalar = input.shape().equals(Shape.scalar()); - if (inputType == TFloat32.class) { - if (debug) { - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getFloat()), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getFloat()), f.getFloat())); - } - } - } - index.set(0); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getFloat())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getFloat()))); - } - } - } else if (inputType == TFloat64.class) { - if (debug) { - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %f\n", - predicate.test(result.getDouble()), result.getDouble()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). 
%b <==> %f\n", - index.getAndIncrement(), predicate.test(f.getDouble()), f.getDouble())); - } - } - } - index.set(0); - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getDouble())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getDouble()))); - } - } - } else if (inputType == TInt32.class) { - if (debug) { - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", predicate.test(result.getInt()), result.getInt()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getInt()), f.getInt())); - } - } - } - index.set(0); - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getInt())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getInt()))); - } - } - } else if (inputType == TInt64.class) { - if (debug) { - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getLong()), result.getLong()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getLong()), f.getLong())); - } - } - } - index.set(0); - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getLong())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getLong()))); - } - } - } else if (inputType == TUint8.class) { - if (debug) { - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - System.out.printf( - "0). %b <==> %d\n", - predicate.test(result.getByte()), result.getByte()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> - System.out.printf( - "%d). %b <==> %d\n", - index.getAndIncrement(), predicate.test(f.getByte()), f.getByte())); - } - } - } - index.set(0); - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - assertTrue(predicate.test(result.getByte())); - } else { - result - .scalars() - .forEachIndexed((idx, f) -> assertTrue(predicate.test(result.getByte()))); - } - } - } else { - fail("Unexpected type class: " + inputType); + public void evaluate( + Boolean[] expected, Operand input, Map, Tensor> feedMap) { + try (TBool tensor = (TBool) createRunner(feedMap).fetch(input).run().get(0)) { + super.evaluate(expected, tensor); } } - /** - * {@inheritDoc} - */ + /** {@inheritDoc} */ @Override - public void print(PrintWriter writer, Output input) { - boolean isScalar = input.shape().size() == 1; - - Class inputType = input.type(); - if (inputType == TFloat32.class) { - AtomicInteger index = new AtomicInteger(); - try (TFloat32 result = - (TFloat32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf("%d). %f\n", index.getAndIncrement(), result.getFloat()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). 
%f\n", index.getAndIncrement(), f.getFloat())); - } - } - } else if (inputType == TFloat64.class) { - AtomicInteger index = new AtomicInteger(); - - try (TFloat64 result = - (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %f\n", index.getAndIncrement(), result.getDouble()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %f\n", index.getAndIncrement(), f.getDouble())); - } - } - } else if (inputType == TInt32.class) { - AtomicInteger index = new AtomicInteger(); - - try (TInt32 result = - (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(),result.getInt()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getInt())); - } - } - } else if (inputType == TInt64.class) { - AtomicInteger index = new AtomicInteger(); - - try (TInt64 result = - (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %d\n", index.getAndIncrement(), result.getLong()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %d\n", index.getAndIncrement(), f.getLong())); - } - } - } else if (inputType == TUint8.class) { - AtomicInteger index = new AtomicInteger(); - - try (TUint8 result = - (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %x\n", index.getAndIncrement(), result.getByte()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %x\n", index.getAndIncrement(), f.getByte())); - } - } - } else if (inputType == TBool.class) { - AtomicInteger index = new AtomicInteger(); - - try (TBool result = - (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %b\n", index.getAndIncrement(), result.getBoolean()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). %b\n", index.getAndIncrement(), f.getBoolean())); - } - } - } else if (inputType == TString.class) { - AtomicInteger index = new AtomicInteger(); + public void evaluate( + Operand expected, Operand input, Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0); + Tensor expectedTensor = createRunner(feedMap).fetch(expected).run().get(0)) { + super.evaluate(expectedTensor, tensor, input.type()); + } + } - try (TString result = - (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { - if (isScalar) { - writer.printf( - "%d). %s\n", index.getAndIncrement(), result.getObject()); - } else { - result - .scalars() - .forEachIndexed( - (idx, f) -> writer.printf("%d). 
%s\n", index.getAndIncrement(), f.getObject())); - } - } - } else { - writer.println("Unexpected type class: " + inputType); + /** {@inheritDoc} */ + @Override + public void print( + PrintWriter writer, Operand input, Map, Tensor> feedMap) { + try (Tensor tensor = createRunner(feedMap).fetch(input).run().get(0)) { + super.print(writer, tensor, input.type()); } - writer.flush(); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index c0c0f12fbf9..a103ef9884a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -14,7 +14,11 @@ =======================================================================*/ package org.tensorflow.framework.utils; -import org.tensorflow.ndarray.*; +import org.tensorflow.ndarray.DoubleNdArray; +import org.tensorflow.ndarray.FloatNdArray; +import org.tensorflow.ndarray.NdArray; +import org.tensorflow.ndarray.NdArrays; +import org.tensorflow.ndarray.Shape; import java.util.Arrays; import java.util.concurrent.atomic.AtomicBoolean; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/QuadConsumer.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/QuadConsumer.java new file mode 100644 index 00000000000..7f576e8b614 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/QuadConsumer.java @@ -0,0 +1,42 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.utils; + +/** + * Represents an operation that accepts four input arguments and returns no result. This is the + * quaternary specialization of {@link java.util.function.Consumer}. Unlike most other functional + * interfaces, {@code QuadConsumer} is expected to operate via side-effects. + * + *
+ * <p>This is a functional interface whose functional method is {@link #accept(Object, Object,
+ * Object, Object)}.
+ *
+ * @param <T> the type of the first argument to the operation
+ * @param <S> the type of the second argument to the operation
+ * @param <U> the type of the third argument to the operation
+ * @param <V> the type of the fourth argument to the operation
+ */
+@FunctionalInterface
+interface QuadConsumer<T, S, U, V> {
+
+ /**
+ * Performs this operation on the given arguments.
+ *
+ * @param t the first input argument
+ * @param s the second input argument
+ * @param u the third input argument
+ * @param v the fourth input argument
+ */
+ void accept(T t, S s, U u, V v);
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
index 2c252d467c7..d5578042321 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java
@@ -14,13 +14,25 @@
=======================================================================*/
package org.tensorflow.framework.utils;
-import org.tensorflow.*;
-import org.tensorflow.ndarray.DoubleNdArray;
+import org.tensorflow.EagerSession;
+import org.tensorflow.Graph;
+import org.tensorflow.Operand;
+import org.tensorflow.Session;
+import org.tensorflow.Tensor;
import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.NdArray;
+import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Op;
import org.tensorflow.op.Ops;
+import org.tensorflow.types.TBfloat16;
import org.tensorflow.types.TBool;
+import org.tensorflow.types.TFloat16;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;
+import org.tensorflow.types.TUint8;
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;
@@ -28,14 +40,229 @@
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.Writer;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
import java.util.function.Predicate;
+import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
/** Base class for Test Session */
public abstract class TestSession implements AutoCloseable {
- protected float epsilon = 1e-5F;
+ protected static final Map<
+ Class, TriConsumer>>
+ printMap =
+ new HashMap<
+ Class,
+ TriConsumer>>() {
+ {
+ put(
+ TUint8.class,
+ (writer, idx, o) ->
+ writer.printf(
+ "%s. %s\n", Arrays.toString(idx), ((Number) o.getObject()).byteValue()));
+ put(
+ TInt32.class,
+ (writer, idx, o) ->
+ writer.printf(
+ "%s. %d\n", Arrays.toString(idx), ((Number) o.getObject()).intValue()));
+ put(
+ TInt64.class,
+ (writer, idx, o) ->
+ writer.printf(
+ "%s. %d\n", Arrays.toString(idx), ((Number) o.getObject()).longValue()));
+ put(
+ TFloat32.class,
+ (writer, idx, o) ->
+ writer.printf(
+ "%s. %f\n", Arrays.toString(idx), ((Number) o.getObject()).floatValue()));
+ put(
+ TFloat64.class,
+ (writer, idx, o) ->
+ writer.printf(
+ "%s. %f\n",
+ Arrays.toString(idx), ((Number) o.getObject()).doubleValue()));
+ put(
+ TBfloat16.class,
+ (writer, idx, o) ->
+ writer.printf(
+ "%s. 
%f\n", Arrays.toString(idx), ((Number) o.getObject()).floatValue())); + put( + TFloat16.class, + (writer, idx, o) -> + writer.printf( + "%s. %f\n", Arrays.toString(idx), ((Number) o.getObject()).floatValue())); + put( + TBool.class, + (writer, idx, o) -> + writer.printf("%s. %b\n", Arrays.toString(idx), o.getObject())); + put( + TString.class, + (writer, idx, o) -> + writer.printf("%s. %s\n", Arrays.toString(idx), o.getObject())); + } + }; + protected static final Map< + Class, + QuadConsumer, NdArray>> + printPredicate = + new HashMap< + Class, + QuadConsumer, NdArray>>() { + { + put( + TUint8.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %d\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).byteValue()), + ((Number) o.getObject()).byteValue())); + put( + TInt32.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %d\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).intValue()), + ((Number) o.getObject()).intValue())); + put( + TInt64.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %d\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).longValue()), + ((Number) o.getObject()).longValue())); + put( + TFloat32.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %f\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).floatValue()), + ((Number) o.getObject()).floatValue())); + put( + TFloat64.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %f\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).doubleValue()), + ((Number) o.getObject()).doubleValue())); + put( + TBfloat16.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. %b <==> %f\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).floatValue()), + ((Number) o.getObject()).floatValue())); + put( + TFloat16.class, + (writer, idx, predicate, o) -> + writer.printf( + "%s. 
%b <==> %f\n", + Arrays.toString(idx), + predicate.test(((Number) o.getObject()).floatValue()), + ((Number) o.getObject()).floatValue())); + } + }; + protected static final Map< + Class, BiConsumer, NdArray>> + evalPredicate = + new HashMap< + Class, BiConsumer, NdArray>>() { + { + put( + TUint8.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).byteValue()))); + put( + TInt32.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).intValue()))); + put( + TInt64.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).longValue()))); + put( + TFloat32.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).floatValue()))); + put( + TFloat64.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).doubleValue()))); + put( + TBfloat16.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).floatValue()))); + put( + TFloat16.class, + (predicate, o) -> + assertTrue(predicate.test(((Number) o.getObject()).floatValue()))); + } + }; + private static final long[] ZERO_IDX = new long[0]; + private static final PrintWriter DEFAULT_WRITER = new PrintWriter(System.out); + protected static float epsilon = 1e-5F; + protected static final Map, BiConsumer>> + evalMap = + new HashMap, BiConsumer>>() { + { + put( + TUint8.class, + (expected, o) -> + assertEquals( + ((Number) expected).byteValue(), ((Number) o.getObject()).byteValue())); + put( + TInt32.class, + (expected, o) -> + assertEquals( + ((Number) expected).intValue(), ((Number) o.getObject()).intValue())); + put( + TInt64.class, + (expected, o) -> + assertEquals( + ((Number) expected).longValue(), ((Number) o.getObject()).longValue())); + put( + TFloat32.class, + (expected, o) -> + assertEquals( + ((Number) expected).floatValue(), + ((Number) o.getObject()).floatValue(), + epsilon)); + put( + TFloat64.class, + (expected, o) -> + assertEquals( + ((Number) expected).doubleValue(), + ((Number) o.getObject()).doubleValue(), + epsilon)); + put( + TBfloat16.class, + (expected, o) -> + assertEquals( + ((Number) expected).floatValue(), + ((Number) o.getObject()).floatValue(), + epsilon)); + put( + TFloat16.class, + (expected, o) -> + assertEquals( + ((Number) expected).floatValue(), + ((Number) o.getObject()).floatValue(), + epsilon)); + put(TBool.class, (expected, o) -> assertEquals(expected, o.getObject())); + put(TString.class, (expected, o) -> assertEquals(expected, o.getObject().toString())); + } + }; protected boolean debug; /** @@ -56,6 +283,18 @@ public static TestSession createGraphSession() { return new GraphTestSession(); } + /** + * Creates a Graph Test Session without creating its own graph + * + * @param graph the graph + * @param session the session + * @param tf the TensorFlow Ops + * @return the Graph Test Session + */ + public static TestSession createGraphSession(Graph graph, Session session, Ops tf) { + return new GraphTestSession(graph, session, tf); + } + /** * Creates a Test Session * @@ -66,17 +305,47 @@ public static TestSession createTestSession(Mode mode) { return mode == Mode.EAGER ? 
createEagerSession() : createGraphSession(); } + /** + * Get the epsilon value for evaluating float values + * + * @return the epsilon value for evaluating float values + */ + public static float getEpsilon() { + return epsilon; + } + + /** + * Set the epsilon value for evaluating float values + * + * @param epsilonValue the epsilon value for evaluating float values + */ + public static void setEpsilon(float epsilonValue) { + epsilon = epsilonValue; + } + /** Initializes the Test Session, default implementation is do nothing. */ public void initialize() { // empty } /** - * Runs the Operation + * Runs the Operation, in EagerMode this does nothing * - * @param op the Operation to run + * @param op the Operation to run. */ + @SuppressWarnings("unused") public void run(Op op) { + run(op, null); + } + + /** + * Runs the Operation, in EagerMode this does nothing + * + * @param op the Operation to run. + * @param feedMap a optional Map to feed to the run session when placeholders are used. + */ + @SuppressWarnings("unused") + public void run(Op op, Map, Tensor> feedMap) { // empty } @@ -98,7 +367,7 @@ public Graph getGraph() { * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(Number expected, Operand input) { - evaluate(new Number[] {expected}, input); + evaluate(new Number[] {expected}, input, null); } /** @@ -106,22 +375,25 @@ public void evaluate(Number expected, Operand input) { * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(Number expected, Op input) { - evaluate(new Number[] {expected}, input); + public void evaluate( + Number expected, Operand input, Map, Tensor> feedMap) { + evaluate(new Number[] {expected}, input, feedMap); } /** - * Evaluates the input against the expected values + * Evaluates the input against the expected value * - * @param expected the expected values + * @param expected the expected value * @param input the operand to evaluate + * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(Number[] expected, Op input) { - Output output = input.op().output(0); - evaluate(expected, output); + public void evaluate(byte expected, Operand input) { + evaluate((double) expected, input, null); } /** @@ -129,12 +401,13 @@ public void evaluate(Number[] expected, Op input) { * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. 
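+   *     May be {@code null} when the graph has no placeholders to feed.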
* @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(Number[] expected, Operand input) { - Output output = input.asOutput(); - evaluate(expected, output); + public void evaluate( + byte expected, Operand input, Map, Tensor> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -145,8 +418,8 @@ public void evaluate(Number[] expected, Operand input) { * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(byte expected, Operand input) { - evaluate((double) expected, input); + public void evaluate(int expected, Operand input) { + evaluate((double) expected, input, null); } /** @@ -154,11 +427,13 @@ public void evaluate(byte expected, Operand input) { * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(int expected, Operand input) { - evaluate((double) expected, input); + public void evaluate( + int expected, Operand input, Map, Tensor> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -170,7 +445,21 @@ public void evaluate(int expected, Operand input) { * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(long expected, Operand input) { - evaluate((double) expected, input); + evaluate((double) expected, input, null); + } + + /** + * Evaluates the input against the expected value + * + * @param expected the expected value + * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + public void evaluate( + long expected, Operand input, Map, Tensor> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -182,7 +471,21 @@ public void evaluate(long expected, Operand input) { * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(float expected, Operand input) { - evaluate((double) expected, input); + evaluate((double) expected, input, null); + } + + /** + * Evaluates the input against the expected value + * + * @param expected the expected value + * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + public void evaluate( + float expected, Operand input, Map, Tensor> feedMap) { + evaluate((double) expected, input, feedMap); } /** @@ -193,7 +496,21 @@ public void evaluate(float expected, Operand input) { * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract void evaluate(double expected, Operand input); + public void evaluate(double expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluates the input against the expected value + * + * @param expected the expected value + * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. 
+ * @param the data type of the input + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + public abstract void evaluate( + double expected, Operand input, Map, Tensor> feedMap); /** * Evaluates the input against the expected value @@ -204,9 +521,23 @@ public void evaluate(float expected, Operand input) { * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(byte[] expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluates the input against the expected value + * + * @param expected the expected value + * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + public void evaluate( + byte[] expected, Operand input, Map, Tensor> feedMap) { Byte[] iArray = new Byte[expected.length]; for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + evaluate(iArray, input, feedMap); } /** @@ -218,9 +549,23 @@ public void evaluate(byte[] expected, Operand input) { * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(int[] expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluates the input against the expected value + * + * @param expected the expected value + * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + public void evaluate( + int[] expected, Operand input, Map, Tensor> feedMap) { Integer[] iArray = new Integer[expected.length]; for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + evaluate(iArray, input, feedMap); } /** @@ -232,9 +577,23 @@ public void evaluate(int[] expected, Operand input) { * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(long[] expected, Operand input) { + evaluate(expected, input, null); + } + + /** + * Evaluates the input against the expected value + * + * @param expected the expected value + * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + public void evaluate( + long[] expected, Operand input, Map, Tensor> feedMap) { Long[] iArray = new Long[expected.length]; for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + evaluate(iArray, input, feedMap); } /** @@ -246,23 +605,23 @@ public void evaluate(long[] expected, Operand input) { * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(float[] expected, Operand input) { - Float[] iArray = new Float[expected.length]; - for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + evaluate(expected, input, null); } /** - * Evaluates the input against the expected value + * Evaluates the input against the expected values * - * @param expected the expected value + * @param expected the expected values * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. 
* @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(double[] expected, Operand input) { - Double[] iArray = new Double[expected.length]; + public void evaluate( + float[] expected, Operand input, Map, Tensor> feedMap) { + Float[] iArray = new Float[expected.length]; for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; - evaluate(iArray, input); + evaluate(iArray, input, feedMap); } /** @@ -273,17 +632,24 @@ public void evaluate(double[] expected, Operand input) { * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract void evaluate(Number[] expected, Output input); + public void evaluate(double[] expected, Operand input) { + evaluate(expected, input, null); + } /** * Evaluates the input against the expected value * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(String expected, Operand input) { - evaluate(new String[] {expected}, input); + public void evaluate( + double[] expected, Operand input, Map, Tensor> feedMap) { + Double[] iArray = new Double[expected.length]; + for (int i = 0; i < expected.length; i++) iArray[i] = expected[i]; + evaluate(iArray, input, feedMap); } /** @@ -291,10 +657,11 @@ public void evaluate(String expected, Operand input) { * * @param expected the expected value * @param input the operand to evaluate + * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(String expected, Op input) { - evaluate(new String[] {expected}, input); + public void evaluate(Number[] expected, Operand input) { + evaluate(expected, input, null); } /** @@ -302,12 +669,12 @@ public void evaluate(String expected, Op input) { * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(String[] expected, Op input) { - Output output = input.op().output(0); - evaluate(expected, output); - } + public abstract void evaluate( + Number[] expected, Operand input, Map, Tensor> feedMap); /** * Evaluates the input against the expected value @@ -316,9 +683,8 @@ public void evaluate(String[] expected, Op input) { * @param input the operand to evaluate * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(String[] expected, Operand input) { - Output output = input.asOutput(); - evaluate(expected, output); + public void evaluate(String expected, Operand input) { + evaluate(new String[] {expected}, input, null); } /** @@ -326,9 +692,13 @@ public void evaluate(String[] expected, Operand input) { * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. 
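+   *     The expected values are compared element-wise using the current epsilon value (see
+   *     {@link #setEpsilon(float)}).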
* @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract void evaluate(String[] expected, Output input); + public void evaluate( + String expected, Operand input, Map, Tensor> feedMap) { + evaluate(new String[] {expected}, input, feedMap); + } /** * Evaluates the input against the expected value @@ -337,8 +707,8 @@ public void evaluate(String[] expected, Operand input) { * @param input the operand to evaluate * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(Boolean expected, Operand input) { - evaluate(new Boolean[] {expected}, input); + public void evaluate(String[] expected, Operand input) { + evaluate(expected, input, null); } /** @@ -346,11 +716,11 @@ public void evaluate(Boolean expected, Operand input) { * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(Boolean expected, Op input) { - evaluate(new Boolean[] {expected}, input); - } + public abstract void evaluate( + String[] expected, Operand input, Map, Tensor> feedMap); /** * Evaluates the input against the expected value @@ -359,9 +729,8 @@ public void evaluate(Boolean expected, Op input) { * @param input the operand to evaluate * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(Boolean[] expected, Op input) { - Output output = input.op().output(0); - evaluate(expected, output); + public void evaluate(Boolean expected, Operand input) { + evaluate(new Boolean[] {expected}, input, null); } /** @@ -369,11 +738,12 @@ public void evaluate(Boolean[] expected, Op input) { * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(Boolean[] expected, Operand input) { - Output output = input.asOutput(); - evaluate(expected, output); + public void evaluate( + Boolean expected, Operand input, Map, Tensor> feedMap) { + evaluate(new Boolean[] {expected}, input, feedMap); } /** @@ -383,20 +753,20 @@ public void evaluate(Boolean[] expected, Operand input) { * @param input the operand to evaluate * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract void evaluate(Boolean[] expected, Output input); + public void evaluate(Boolean[] expected, Operand input) { + evaluate(expected, input, null); + } /** * Evaluates the input against the expected value * * @param expected the expected value * @param input the operand to evaluate - * @param the data type of the expected Operand + * @param feedMap a optional Map to feed to the run session when placeholders are used. 
* @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void evaluate(Operand expected, Op input) { - Output output = input.op().output(0); - evaluate(expected, output); - } + public abstract void evaluate( + Boolean[] expected, Operand input, Map, Tensor> feedMap); /** * Evaluates the input against the expected value @@ -407,7 +777,7 @@ public void evaluate(Operand expected, Op input) { * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(Operand expected, Operand input) { - evaluate(expected.asOutput(), input.asOutput()); + evaluate(expected, input, null); } /** @@ -415,10 +785,12 @@ public void evaluate(Operand expected, Operand input) { * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract void evaluate(Output expected, Output input); + public abstract void evaluate( + Operand expected, Operand input, Map, Tensor> feedMap); /** * Evaluates the input against the expected value @@ -429,7 +801,7 @@ public void evaluate(Operand expected, Operand input) { * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(FloatNdArray expected, Operand input) { - evaluate(expected, input.asOutput()); + evaluate(expected, input, null); } /** @@ -437,21 +809,23 @@ public void evaluate(FloatNdArray expected, Operand input) * * @param expected the expected value * @param input the operand to evaluate + * @param feedMap a optional Map to feed to the run session when placeholders are used. * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract void evaluate(FloatNdArray expected, Output input); + public abstract void evaluate( + FloatNdArray expected, Operand input, Map, Tensor> feedMap); /** * Evaluates the input against the expected value * * @param input the operand to evaluate - * @param predicate the Predicate + * @param predicate The Predicate that evaluates the each value from input * @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ public void evaluate(Operand input, Predicate predicate) { - evaluate(input.asOutput(), predicate); + evaluate(input, predicate, null); } /** @@ -459,10 +833,12 @@ public void evaluate(Operand input, Predicate predi * * @param input the operand to evaluate * @param predicate The Predicate that evaluates the each value from input + * @param feedMap a optional Map to feed to the run session when placeholders are used. 
* @param the data type of the input * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract void evaluate(Output input, Predicate predicate); + public abstract void evaluate( + Operand input, Predicate predicate, Map, Tensor> feedMap); /** * Evaluates the input against the expected string value @@ -471,7 +847,7 @@ public void evaluate(Operand input, Predicate predi * @param predicate The Predicate that evaluates the each value from input */ public void evaluateString(Operand input, Predicate predicate) { - evaluateString(input.asOutput(), predicate); + evaluateString(input, predicate, null); } /** @@ -479,8 +855,13 @@ public void evaluateString(Operand input, Predicate predicate) * * @param input the operand to evaluate * @param predicate The Predicate that evaluates the each value from input + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract void evaluateString(Output input, Predicate predicate); + public abstract void evaluateString( + Operand input, + Predicate predicate, + Map, Tensor> feedMap); /** * Evaluates the input against the expected value @@ -494,106 +875,110 @@ public void evaluate(FloatNdArray input, Predicate predicate) { } /** - * Evaluates the input against the expected value - * - * @param input the operand to evaluate - * @param predicate The Predicate that evaluates the each value from input - * @throws org.opentest4j.AssertionFailedError if the evaluation fails - */ - public void evaluate(DoubleNdArray input, Predicate predicate) { - input.scalars().forEach(f -> assertTrue(predicate.test(f.getDouble()))); - } - - /** - * Print the input + * Prints the input's values to standard out * - * @param out the output stream * @param input the operand to print * @param the data type of the input + * @throws IllegalArgumentException if the data type for the input does not have a print function + * registered. */ - public void print(OutputStream out, Operand input) { - print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput()); + public void print(Operand input) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input, null); } /** - * Print the input to standard out + * Prints the input's values to standard out * - * @param input the op to print + * @param input the operand to print + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input + * @throws IllegalArgumentException if the data type for the input does not have a print function + * registered. */ - public void print(Op input) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0)); + public void print( + Operand input, Map, Tensor> feedMap) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input, feedMap); } /** - * Print the input + * Prints the input's values to the output stream * * @param out the output stream - * @param input the op to print - */ - public void print(OutputStream out, Op input) { - print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0)); - } - - /** - * Print the input to standard out - * - * @param input the op to print + * @param input the operand to print * @param the data type of the input + * @throws IllegalArgumentException if the data type for the input does not have a print function + * registered. 
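+   *     <p>For example, a minimal usage sketch (assuming {@code session} is this test session and
+   *     {@code tf} is {@code session.getTF()}):
+   *     {@code session.print(tf.constant(new float[] {1f, 2f}));}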
*/ - public void print(Output input) { - print(new PrintWriter(new OutputStreamWriter(System.out)), input); + public void print(OutputStream out, Operand input) { + print(new PrintWriter(new OutputStreamWriter(out)), input, null); } /** - * Print the input + * Prints the input's values to the output stream * * @param out the output stream - * @param input the op to print + * @param input the operand to print + * @param feedMap a optional Map to feed to the run session when placeholders are used. * @param the data type of the input + * @throws IllegalArgumentException if the data type for the input does not have a print function + * registered. */ - public void print(OutputStream out, Output input) { - print(new PrintWriter(new OutputStreamWriter(out)), input); + public void print( + OutputStream out, Operand input, Map, Tensor> feedMap) { + print(new PrintWriter(new OutputStreamWriter(out)), input, feedMap); } /** - * Print the input + * Prints the input's values to the PrintWriter * * @param writer the output writer - * @param input the operand to print + * @param input the op to print * @param the data type of the input + * @throws IllegalArgumentException if the data type for the input does not have a print function + * registered. */ public void print(Writer writer, Operand input) { - print(new PrintWriter(writer), input.asOutput()); + print(new PrintWriter(writer), input, null); } /** - * Print the input + * Prints the input's values to the PrintWriter * * @param writer the output writer * @param input the op to print + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @param the data type of the input + * @throws IllegalArgumentException if the data type for the input does not have a print function + * registered. */ - public void print(Writer writer, Op input) { - print(new PrintWriter(writer), input.op().output(0)); + public void print( + Writer writer, Operand input, Map, Tensor> feedMap) { + print(new PrintWriter(writer), input, feedMap); } /** - * Print the input + * Prints the input's values to the PrintWriter * * @param writer the output writer * @param input the op to print - * @param the data type of the input + * @throws IllegalArgumentException if the data type for the input does not have a print function + * registered. */ - public void print(Writer writer, Output input) { - print(new PrintWriter(writer), input); + public void print(PrintWriter writer, Operand input) { + print(writer, input, null); } /** - * Print the input + * Prints the input's values to the PrintWriter * * @param writer the output writer * @param input the op to print + * @param feedMap a optional Map to feed to the run session when placeholders are used. + * @throws IllegalArgumentException if the data type for the input does not have a print function + * registered. 
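+   *     <p>Concrete sessions resolve the operand to a tensor (feeding {@code feedMap} to the
+   *     runner when supplied) and then delegate to the per-type print functions.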
*/ - public abstract void print(PrintWriter writer, Output input); + public abstract void print( + PrintWriter writer, Operand input, Map, Tensor> feedMap); /** * Get the TensorFlow Ops @@ -619,36 +1004,246 @@ public boolean isGraph() { } /** - * Get the epsilon value for evaluating float values + * Get the TensorFlow session object associated with this Test Session * - * @return the epsilon value for evaluating float values + * @return a TensorFlow session if this is a Graph session, otherwise null + */ + public abstract Session getGraphSession(); + + /** + * Get the TensorFlow eager session object associated with this Test Session + * + * @return a TensorFlow session if this is an eager session, otherwise null + */ + public abstract EagerSession getEagerSession(); + + // The following methods are called by the subclasses, + // after resolving the tensor for the Operands + + /** + * Evaluates the tensor's values against the expected value + * + * @param expected the expected value + * @param tensor the tensor whose values are compared to the expected values. + * @param type the data type of the tensor + * @param the data type of the tensor + * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public float getEpsilon() { - return this.epsilon; + @SuppressWarnings("unchecked") + protected void evaluate(double expected, Tensor tensor, Class type) { + boolean isScalar = tensor.shape().equals(Shape.scalar()); + if (debug) { + print(DEFAULT_WRITER, tensor, type); + } + BiConsumer> evaluateFunc = evalMap.get(type); + if (evaluateFunc == null) fail("Unexpected Type Class: " + type); + if (isScalar) evaluateFunc.accept(expected, (NdArray) tensor); + else ((NdArray) tensor).scalars().forEach(f -> evaluateFunc.accept(expected, f)); } /** - * Set the epsilon value for evaluating float values + * Evaluates the tensor's values against the expected values * - * @param epsilon the epsilon value for evaluating float values + * @param expected the expected values + * @param tensor the tensor whose values are compared to the expected values. + * @param type the data type of the tensor + * @param the data type of the tensor + * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public void setEpsilon(float epsilon) { - this.epsilon = epsilon; + @SuppressWarnings("unchecked") + protected void evaluate(Number[] expected, Tensor tensor, Class type) { + int size = tensor.shape().size() == 0 ? 1 : (int) tensor.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + + boolean isScalar = tensor.shape().equals(Shape.scalar()); + AtomicInteger index = new AtomicInteger(); + if (debug) { + print(DEFAULT_WRITER, tensor, type); + } + BiConsumer> evaluateFunc = evalMap.get(type); + if (evaluateFunc == null) fail("Unexpected Type Class: " + type); + if (isScalar) evaluateFunc.accept(expected[0], (NdArray) tensor); + else + ((NdArray) tensor) + .scalars() + .forEach(f -> evaluateFunc.accept(expected[index.getAndIncrement()], f)); } /** - * Get the TensorFlow session object associated with this Test Session + * Evaluates the tensor's values against the expected values * - * @return a TensorFlow session if this is a Graph session, otherwise null + * @param expected the expected values + * @param tensor the tensor whose values are compared to the expected values. 
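+   *     A scalar tensor is compared as a single value; otherwise every element of the tensor is
+   *     compared against {@code expected}.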
+ * @param type the data type of the tensor + * @param the data type of the tensor + * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract Session getGraphSession(); + @SuppressWarnings("unchecked") + protected void evaluate(FloatNdArray expected, Tensor tensor, Class type) { + boolean isScalar = tensor.shape().equals(Shape.scalar()); + AtomicInteger index = new AtomicInteger(); + if (debug) { + print(DEFAULT_WRITER, tensor, type); + } + BiConsumer> evaluateFunc = evalMap.get(type); + if (evaluateFunc == null) fail("Unexpected Type Class: " + type); + if (isScalar) + evaluateFunc.accept(expected.getObject(index.getAndIncrement()), (NdArray) tensor); + else + ((NdArray) tensor) + .scalars() + .forEach(f -> evaluateFunc.accept(expected.getObject(index.getAndIncrement()), f)); + } /** - * Get the TensorFlow eager session object associated with this Test Session + * Evaluates the tensor's values against the predicate test * - * @return a TensorFlow session if this is an eager session, otherwise null + * @param tensor the tensor to evaluate + * @param predicate the predicate to test the value of the tensor values + * @throws org.opentest4j.AssertionFailedError if the evaluation fails */ - public abstract EagerSession getEagerSession(); + protected void evaluateString(TString tensor, Predicate predicate) { + AtomicInteger index = new AtomicInteger(); + boolean isScalar = tensor.shape().equals(Shape.scalar()); + if (debug) { + print(DEFAULT_WRITER, tensor, TString.class); + } + index.set(0); + if (isScalar) { + assertTrue(predicate.test(tensor.getObject())); + } else { + tensor.scalars().forEachIndexed((idx, s) -> assertTrue(predicate.test(s.getObject()))); + } + } + + /** + * Evaluates the tensor's values against the predicate test + * + * @param tensor the tensor to evaluate + * @param type the data type of the tensor + * @param predicate the predicate to test the value of the tensor values + * @param the data type of the tensor + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + @SuppressWarnings("unchecked") + protected void evaluate( + Tensor tensor, Class type, Predicate predicate) { + boolean isScalar = tensor.shape().equals(Shape.scalar()); + if (debug) { + print(DEFAULT_WRITER, tensor, type); + } + BiConsumer, NdArray> evalFunc = evalPredicate.get(type); + if (evalFunc == null) fail("Unexpected Type Class: " + type); + if (isScalar) { + evalFunc.accept(predicate, (NdArray) tensor); + } else { + ((NdArray) tensor).scalars().forEach(f -> evalFunc.accept(predicate, f)); + } + } + + /** + * Evaluates the tensor's values against the expected values + * + * @param expected the expected values + * @param tensor the tensor whose values are compared to the expected values + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + protected void evaluate(String[] expected, TString tensor) { + int size = tensor.shape().size() == 0 ? 
1 : (int) tensor.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + AtomicInteger index = new AtomicInteger(); + boolean isScalar = tensor.shape().equals(Shape.scalar()); + if (debug) { + print(DEFAULT_WRITER, tensor, TString.class); + } + if (isScalar) assertEquals(expected[0], tensor.getObject()); + else + tensor.scalars().forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getObject())); + } + + /** + * Evaluates the tensor's values against the expected values + * + * @param expected the expected value + * @param tensor the tensor whose values are compared to the expected values + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + protected void evaluate(Boolean[] expected, TBool tensor) { + int size = tensor.shape().size() == 0 ? 1 : (int) tensor.shape().size(); + assertEquals( + expected.length, + size, + () -> String.format("expected size (%d) != to input length (%d)", expected.length, size)); + AtomicInteger index = new AtomicInteger(); + boolean isScalar = tensor.shape().equals(Shape.scalar()); + if (debug) { + print(DEFAULT_WRITER, tensor, TBool.class); + } + if (isScalar) assertEquals(expected[index.getAndIncrement()], tensor.getBoolean()); + else + tensor + .scalars() + .forEach(f -> assertEquals(expected[index.getAndIncrement()], f.getBoolean())); + } + + /** + * Evaluates the tensor's values against the expected tensor's values + * + * @param expected the tensor whose values are expected + * @param tensor the tensor whose values are compared to the expected values. + * @param type the data type of the tensor + * @param the data type of the tensor + * @throws org.opentest4j.AssertionFailedError if the evaluation fails + */ + @SuppressWarnings("unchecked") + protected void evaluate(Tensor expected, Tensor tensor, Class type) { + assert tensor.shape().equals(expected.shape()) + : String.format( + "expected shape (%s) != to input shape (%s)", + expected.shape().toString(), tensor.shape().toString()); + boolean isScalar = tensor.shape().equals(Shape.scalar()); + + AtomicInteger index = new AtomicInteger(); + if (debug) { + print(DEFAULT_WRITER, tensor, type); + } + index.set(0); + BiConsumer> evaluateFunc = evalMap.get(type); + if (evaluateFunc == null) fail("Unexpected Type Class: " + type); + NdArray expectedArray = (NdArray) expected; + if (isScalar) evaluateFunc.accept(expectedArray.getObject(), (NdArray) tensor); + else + ((NdArray) tensor) + .scalars() + .forEachIndexed((idx, f) -> evaluateFunc.accept(expectedArray.getObject(idx), f)); + } + + /** + * Prints the tensor's values to the print writer + * + * @param writer the output writer + * @param tensor teh tensor to print + * @param type the data type of the tensor + * @param the data type of the tensor + * @throws IllegalArgumentException if the data type for the tensor does not have a print function + * registered. 
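+   *     <p>Each element is printed on its own line, prefixed by its coordinates (e.g.
+   *     {@code [0, 1]. 0.500000}); a scalar is printed with an empty index, {@code []}.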
+ */ + @SuppressWarnings("unchecked") + protected void print(PrintWriter writer, Tensor tensor, Class type) { + boolean isScalar = tensor.shape().equals(Shape.scalar()); + TriConsumer> printFunc = printMap.get(type); + if (printFunc == null) throw new IllegalArgumentException("Unexpected Type Class: " + type); + if (isScalar) printFunc.accept(writer, ZERO_IDX, (NdArray) tensor); + else + ((NdArray) tensor).scalars().forEachIndexed((idx, f) -> printFunc.accept(writer, idx, f)); + writer.flush(); + } /** {@inheritDoc} */ @Override diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TriConsumer.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TriConsumer.java new file mode 100644 index 00000000000..e67829eca92 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TriConsumer.java @@ -0,0 +1,40 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.utils; + +/** + * Represents an operation that accepts three input arguments and returns no result. This is the + * tertiary specialization of {@link java.util.function.Consumer}. Unlike most other functional + * interfaces, {@code TriConsumer} is expected to operate via side-effects. + * + *

This is a functional interface whose functional method is {@link #accept(Object, Object, + * Object)}. + * + * @param the type of the first argument to the operation + * @param the type of the second argument to the operation + * @param the type of the third argument to the operation + */ +@FunctionalInterface +interface TriConsumer { + + /** + * Performs this operation on the given arguments. + * + * @param t the first input argument + * @param u the second input argument + * @param v the third input argument + */ + void accept(T t, U u, V v); +} From f621a88eec9c623247da9304f5173abb979952ca Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Thu, 6 May 2021 19:45:20 -0400 Subject: [PATCH 28/31] changed tf.nn.raw to tf.nn based on Framework Ops change Add softmax test, fixed bugs in framework softmax --- .../tensorflow/framework/op/nn/Softmax.java | 34 ++-- .../tensorflow/framework/op/NnOpsTest.java | 164 ++++++++++++++++++ .../optimizers/GradientDescentTest.java | 42 +++-- 3 files changed, 208 insertions(+), 32 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java index ea300acb0b2..8f768425dc9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java @@ -21,6 +21,7 @@ import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.Concat; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.ExpandDims; import org.tensorflow.op.core.Range; import org.tensorflow.op.core.Rank; import org.tensorflow.op.core.Reshape; @@ -40,8 +41,8 @@ public class Softmax { /** - * Calculates a Softmax operation. If the exis is not the last dimension, then the input axis is - * moved to the last axis berfore calling tf.nn.softmax, then restored before returning. + * Calculates a Softmax operation. If the axis is not the last dimension, then the input axis is + * moved to the last axis before calling tf.nn.softmax, then restored before returning. * * @param scope The TensorFlow scope * @param input the input @@ -58,25 +59,30 @@ public static Operand softmax(Scope scope, Operand i return org.tensorflow.op.nn.Softmax.create(scope, input); } - if (axis <= -shape.numDimensions() || axis >= shape.numDimensions()) { + // validate axis + if (!(-shape.numDimensions() <= axis && axis < shape.numDimensions())) { throw new IllegalArgumentException( String.format( "Axis (%d) must be in the range [%d, %d] where %d is the number of dimensions in the input.", axis, -shape.numDimensions(), shape.numDimensions(), shape.numDimensions())); } + // If dim is not the last dimension, we have to do a transpose so that we can + // still perform the op on its last dimension. 
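+    // For example, with a rank-4 input and dim == 1, swapAxis applies the
+    // permutation [0, 3, 2, 1]; applying the same permutation to the result
+    // restores the original layout.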
+ + // In case dim is negative (and is not last dimension -1), convert to positive int dim = Math.floorMod(axis, shape.numDimensions()); - Operand rank = Rank.create(scope, input); + Operand inputRank = Rank.create(scope, input); Operand dimOp = Constant.scalarOf(scope, dim); Operand one = Constant.scalarOf(scope, 1); - Operand lastIndex = Sub.create(scope, rank, one); + Operand lastIndex = Sub.create(scope, inputRank, one); Operand swappedInputs = swapAxis(scope, input, dimOp, lastIndex); Operand output = org.tensorflow.op.nn.Softmax.create(scope, swappedInputs); return fixOutput(scope, output, shape, dimOp, lastIndex); } /** - * Restores the specified axis, then reshapes the input to the provided shaoe. + * Restores the specified axis, then reshapes the input to the provided shape. * * @param scope The TensorFlow scope * @param output the output @@ -104,15 +110,13 @@ private static Operand swapAxis( Operand zero = Constant.scalarOf(scope, 0); Operand one = Constant.scalarOf(scope, 1); + Operand minus1 = Constant.scalarOf(scope, -1); + Operand range1 = Range.create(scope, zero, dim, one); + Operand range2 = Range.create(scope, Add.create(scope, dim, one), lastIndex, one); + Operand xDim = ExpandDims.create(scope, dim, minus1); + Operand xLastIndex = ExpandDims.create(scope, lastIndex, minus1); + return Transpose.create( - scope, - input, - Concat.create( - scope, - Arrays.asList( - Range.create(scope, zero, dim, one), - Range.create(scope, Add.create(scope, dim, one), lastIndex, one), - dim), - zero)); + scope, input, Concat.create(scope, Arrays.asList(range1, xLastIndex, range2, xDim), zero)); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java index 0436fdd57cf..ccfd344535a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java @@ -5,6 +5,7 @@ import org.tensorflow.framework.utils.TestSession; import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; import org.tensorflow.types.TInt32; class NnOpsTest { @@ -65,4 +66,167 @@ public void testSparseSoftmaxCrossEntropyWithLogits() { session.evaluate(0.69314718f, loss); } } + + @Test + public void testSoftmax() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + + double[][] x = { + { + 1.53975978e-01, + -5.55871308e-01, + 1.06272554e+00, + -7.75577792e-05, + -1.07574403e+00, + -1.70856595e+00, + 6.31895363e-01, + 2.69239008e-01, + 5.44192731e-01, + 6.31500483e-01 + }, + { + -1.15359895e-01, + -2.49849468e-01, + 8.04671764e-01, + 6.24943256e-01, + -4.80956525e-01, + 5.99363089e-01, + 7.44674265e-01, + 1.03888428e+00, + -2.00478077e-01, + 5.33391297e-01 + }, + { + 3.77073050e-01, + -4.92661327e-01, + -8.23421478e-01, + -6.53828621e-01, + 2.00867987e+00, + 2.94002771e-01, + 1.70056212e+00, + 5.50198834e-04, + 6.12767756e-01, + -2.29190066e-01 + }, + { + -1.60519981e+00, + -3.48692238e-01, + -3.25094163e-03, + 4.39969897e-01, + 1.50762582e+00, + 9.69331264e-01, + -1.18115306e+00, + 1.34852254e+00, + -1.24402285e+00, + -3.12961072e-01 + }, + { + -1.40357280e+00, + -1.08287978e+00, + -3.79449308e-01, + 1.51061141e+00, + 7.71783948e-01, + 5.29040515e-01, + 8.77655566e-01, + -1.53738844e+00, + 9.32778895e-01, + 3.69026303e-01 
+ } + }; + + double[][] expected = { + { + 0.09007322, + 0.04429074, + 0.2234913, + 0.07721311, + 0.0263351, + 0.01398634, + 0.14526248, + 0.10107733, + 0.13306525, + 0.14520513 + }, + { + 0.05687012, + 0.04971369, + 0.14270815, + 0.11923224, + 0.0394555, + 0.11622094, + 0.13439782, + 0.1803707, + 0.05222973, + 0.10880107 + }, + { + 0.06962293, + 0.02917639, + 0.02095966, + 0.02483347, + 0.35591814, + 0.06407304, + 0.26153892, + 0.04777828, + 0.08812784, + 0.03797131 + }, + { + 0.01272309, + 0.0446979, + 0.06314083, + 0.0983555, + 0.28607225, + 0.16699266, + 0.01944258, + 0.24399339, + 0.01825786, + 0.04632387 + }, + { + 0.0151052, + 0.02081621, + 0.04206274, + 0.2784457, + 0.13300617, + 0.10433973, + 0.1478602, + 0.01321329, + 0.15623957, + 0.08891118 + } + }; + + Operand input = tf.constant(x); + Operand expectedResult = tf.constant(expected); + Operand result = fops.nn.softmax(input, 1); + session.evaluate(expectedResult, result); + } + } + + @Test + public void testSoftmaxAxes() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + FrameworkOps fops = FrameworkOps.create(tf); + + double[][] arr = { + {0., 0.09090909, 0.18181818, 0.27272727}, + {0.36363636, 0.45454545, 0.54545455, 0.63636364}, + {0.72727273, 0.81818182, 0.90909091, 1.} + }; + + Operand arrTF = tf.constant(arr); + + Operand xNegAxisResult = fops.nn.softmax(arrTF, -2); + Operand yPosAxisResult = fops.nn.softmax(arrTF, 0); + Operand zGtAxisResult = fops.nn.softmax(arrTF, 0); + session.evaluate(xNegAxisResult, yPosAxisResult); + session.evaluate(yPosAxisResult, zGtAxisResult); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index d6786b71972..acaf6bcaa67 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -1,13 +1,17 @@ package org.tensorflow.framework.optimizers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; import org.tensorflow.Graph; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.framework.initializers.Glorot; import org.tensorflow.framework.initializers.VarianceScaling; import org.tensorflow.framework.utils.TestSession; -import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Op; @@ -26,10 +30,8 @@ import org.tensorflow.types.family.TType; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** Test cases for GradientDescent Optimizer */ @@ -162,7 +164,7 @@ public void testDeterminism() { tf.withName("output").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 2))); Mean loss = tf.math.mean( - tf.nn.raw.softmaxCrossEntropyWithLogits(output, placeholder).loss(), tf.constant(0)); + tf.nn.softmaxCrossEntropyWithLogits(output, placeholder).loss(), tf.constant(0)); lossName = loss.op().name(); 
GradientDescent gd = new GradientDescent(g, 10.0f); @@ -205,12 +207,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - TFloat32 lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + TFloat32 lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); initialLoss[i] = lossVal.getFloat(); lossVal.close(); @@ -222,12 +227,15 @@ public void testDeterminism() { .fetch(outputBiasName) .run()); - lossVal = (TFloat32) s.runner() - .addTarget(trainName) - .feed("input", dataTensor) - .feed("output", targetTensor) - .fetch(lossName) - .run().get(0); + lossVal = + (TFloat32) + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run() + .get(0); postTrainingLoss[i] = lossVal.getFloat(); lossVal.close(); } From adb5a89c1aacbc57e0762ea5e3a8bd062101a3cc Mon Sep 17 00:00:00 2001 From: Jim Clarke Date: Mon, 10 May 2021 10:20:37 -0400 Subject: [PATCH 29/31] Refactor the 2d transpose that is common to softmax and logSoftmax. Added NNhelper for the transpose code, change softmax to use the transpose code and added logSoftmax to do the same thing. Added test cases for logSoftmax. --- .../org/tensorflow/framework/op/NnOps.java | 52 +++++- .../op/nn/{Softmax.java => NNhelper.java} | 79 ++++----- .../tensorflow/framework/op/NnOpsTest.java | 166 +++++++++++++++++- 3 files changed, 251 insertions(+), 46 deletions(-) rename tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/{Softmax.java => NNhelper.java} (55%) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java index fa05f5a6307..89f5ee159fe 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java @@ -16,7 +16,6 @@ import org.tensorflow.Operand; import org.tensorflow.framework.op.nn.SigmoidCrossEntropyWithLogits; -import org.tensorflow.framework.op.nn.Softmax; import org.tensorflow.framework.op.nn.SoftmaxCrossEntropyWithLogits; import org.tensorflow.framework.op.nn.SparseSoftmaxCrossEntropyWithLogits; import org.tensorflow.op.Op; @@ -24,6 +23,8 @@ import org.tensorflow.types.family.TFloating; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.op.nn.NNhelper.wrap2DFunction; + /** * An API for building {@code nn} operations as {@link Op Op}s * @@ -194,9 +195,48 @@ public Operand sparseSoftmaxCrossEntro scope, labels, logits); } + + /** + * Calculates a Softmax operation operation on the last dimension + * + * @param input the input + * @return the softmax of the input for the specified axis. + * @throws IllegalArgumentException if axis is not in the range [-rank - rank], exclusive + * @param the data type for the input and result + */ + public Operand softmax(Operand input) { + return wrap2DFunction( scope, input, org.tensorflow.op.nn.Softmax::create, -1 ); + } + + /** + * Calculates a Softmax operation. If the axis is not the last dimension, then the input axis is + * moved to the last axis before calling tf.nn.softmax, then restored before returning. + * + * @param input the input + * @param axis the axis + * @return the softmax of the input for the specified axis. 
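+   *     For example, assuming {@code fops} is a {@code FrameworkOps} instance,
+   *     {@code fops.nn.softmax(logits, 1)} normalizes over dimension 1 of {@code logits}.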
+ * @throws IllegalArgumentException if axis is not in the range [-rank - rank], exclusive + * @param the data type for the input and result + */ + public Operand softmax(Operand input, int axis) { + return wrap2DFunction( scope, input, org.tensorflow.op.nn.Softmax::create, axis ); + } + + /** + * Calculates a Log Softmax operation on the last dimension + * + * @param input the input + * @return the softmax of the input for the specified axis. + * @throws IllegalArgumentException if axis is not in the range [-rank - rank], exclusive + * @param the data type for the input and result + */ + public Operand logSoftmax(Operand input) { + return wrap2DFunction( scope, input, org.tensorflow.op.nn.LogSoftmax::create, -1 ); + } + /** - * Calculates a Softmax operation. If the exis is not the last dimension, then the input axis is - * moved to the last axis berfore calling tf.nn.softmax, then restored before returning. + * Calculates a Log Softmax operation. If the axis is not the last dimension, then the input axis is + * moved to the last axis before calling tf.nn.softmax, then restored before returning. * * @param input the input * @param axis the axis @@ -204,7 +244,9 @@ public Operand sparseSoftmaxCrossEntro * @throws IllegalArgumentException if axis is not in the range [-rank - rank], exclusive * @param the data type for the input and result */ - public Operand softmax(Operand input, int axis){ - return Softmax.softmax(scope, input, axis); + public Operand logSoftmax(Operand input, int axis) { + return wrap2DFunction( scope, input, org.tensorflow.op.nn.LogSoftmax::create, axis ); } + + } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java similarity index 55% rename from tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java rename to tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java index 8f768425dc9..d4c174a5389 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/Softmax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java @@ -1,4 +1,4 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -17,8 +17,6 @@ import org.tensorflow.Operand; import org.tensorflow.ndarray.Shape; import org.tensorflow.op.Scope; -import org.tensorflow.op.annotation.Endpoint; -import org.tensorflow.op.annotation.Operator; import org.tensorflow.op.core.Concat; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.ExpandDims; @@ -30,55 +28,57 @@ import org.tensorflow.op.math.Sub; import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; import java.util.Arrays; +import java.util.function.BiFunction; -/** - * Higher level operation for Softmax. This class will move the desired axis to the last axis, if - * necessary, before calling the low level tf.nn.softmax method. - */ -@Operator(group = "nn") -public class Softmax { - +/** package private Helper class for nn functions */ +public class NNhelper { /** - * Calculates a Softmax operation. If the axis is not the last dimension, then the input axis is - * moved to the last axis before calling tf.nn.softmax, then restored before returning. 
+   * Helper function for ops that accept and return 2d inputs of the same shape.
    *
-   * @param scope The TensorFlow scope
+   *
+   * <p>It reshapes and transposes the inputs into a 2-D Tensor and then invokes the given function.
+   * The output is transposed and reshaped back.
+   *
+   * @param scope the TensorFlow Scope
+   * @param input the input
-   * @param axis the axis
-   * @return the softmax of the input for the specified axis.
+   * @param computeOp The function to wrap. Must accept the scope as the first argument, and the input as the second argument.
+   * (e.g. {@code org.tensorflow.op.nn.Softmax::create})
+   * @param axis The dimension the operation should operate on. {@code -1} indicates the last
+   * dimension.
+   * @param <T> the data type for the input and the result.
+   * @return the result of the operation, the same shape as the input
    * @throws IllegalArgumentException if axis is not in the range [-rank, rank)
-   * @param <T> the data type for the input and result
    */
-  @Endpoint(name = "softmax")
-  public static <T extends TFloating> Operand<T> softmax(Scope scope, Operand<T> input, int axis) {
+  public static <T extends TFloating> Operand<T> wrap2DFunction(
+      Scope scope, Operand<T> input, BiFunction<Scope, Operand<T>, Operand<T>> computeOp, int axis) {
     Shape shape = input.shape();
     boolean isLastDim = axis == -1 || axis == shape.numDimensions() - 1;
     if (isLastDim) {
-      return org.tensorflow.op.nn.Softmax.create(scope, input);
+      return computeOp.apply(scope, input);
     }

     // validate axis
     if (!(-shape.numDimensions() <= axis && axis < shape.numDimensions())) {
       throw new IllegalArgumentException(
           String.format(
               "Axis (%d) must be in the range [%d, %d] where %d is the number of dimensions in the input.",
               axis, -shape.numDimensions(), shape.numDimensions(), shape.numDimensions()));
     }

-    // If dim is not the last dimension, we have to do a transpose so that we can
-    // still perform the op on its last dimension.
+    // If axis is not the last dimension, we have to do a transpose so that we can
+    // still perform the op on its last dimension.

-    // In case dim is negative (and is not last dimension -1), convert to positive
-    int dim = Math.floorMod(axis, shape.numDimensions());
+    // In case axis is negative (and is not last dimension -1), convert to positive
+    int lAxis = Math.floorMod(axis, shape.numDimensions());
     Operand<TInt32> inputRank = Rank.create(scope, input);
-    Operand<TInt32> dimOp = Constant.scalarOf(scope, dim);
+    Operand<TInt32> axisOp = Constant.scalarOf(scope, lAxis);
     Operand<TInt32> one = Constant.scalarOf(scope, 1);
     Operand<TInt32> lastIndex = Sub.create(scope, inputRank, one);
-    Operand<T> swappedInputs = swapAxis(scope, input, dimOp, lastIndex);
-    Operand<T> output = org.tensorflow.op.nn.Softmax.create(scope, swappedInputs);
-    return fixOutput(scope, output, shape, dimOp, lastIndex);
+    Operand<T> swappedInputs = swapAxis(scope, input, axisOp, lastIndex);
+    Operand<T> output = computeOp.apply(scope, swappedInputs);
+    return fixOutput(scope, output, shape, axisOp, lastIndex);
   }

   /**
@@ -87,13 +87,12 @@ public static Operand softmax(Scope scope, Operand i
    * @param scope The TensorFlow scope
    * @param output the output
    * @param shape the desired shape
-   * @param dim the dimension to move
+   * @param axis the dimension to move
    * @return the restored output based on the dimension and shape.
    */
   private static <T extends TFloating> Operand<T> fixOutput(
-      Scope scope, Operand<T> output, Shape shape, Operand<TInt32> dim, Operand<TInt32> lastIndex) {
-
-    Operand<T> result = swapAxis(scope, output, dim, lastIndex);
+      Scope scope, Operand<T> output, Shape shape, Operand<TInt32> axis, Operand<TInt32> lastIndex) {
+    Operand<T> result = swapAxis(scope, output, axis, lastIndex);
     return Reshape.create(scope, result, Constant.tensorOf(scope, shape));
   }

@@ -101,19 +100,19 @@ private static Operand fixOutput(
    * Moves the specified Axis to the last axis
    *
    * @param input the input
-   * @param dim the dimension to move
+   * @param axis the dimension to move
    * @param lastIndex the last dimension
    * @return input with the dimension swapped to the last dimension
    */
   private static <T extends TFloating> Operand<T> swapAxis(
-      Scope scope, Operand<T> input, Operand<TInt32> dim, Operand<TInt32> lastIndex) {
+      Scope scope, Operand<T> input, Operand<TInt32> axis, Operand<TInt32> lastIndex) {

     Operand<TInt32> zero = Constant.scalarOf(scope, 0);
     Operand<TInt32> one = Constant.scalarOf(scope, 1);
     Operand<TInt32> minus1 = Constant.scalarOf(scope, -1);
-    Operand<TInt32> range1 = Range.create(scope, zero, dim, one);
-    Operand<TInt32> range2 = Range.create(scope, Add.create(scope, dim, one), lastIndex, one);
-    Operand<TInt32> xDim = ExpandDims.create(scope, dim, minus1);
+    Operand<TInt32> range1 = Range.create(scope, zero, axis, one);
+    Operand<TInt32> range2 = Range.create(scope, Add.create(scope, axis, one), lastIndex, one);
+    Operand<TInt32> xDim = ExpandDims.create(scope, axis, minus1);
     Operand<TInt32> xLastIndex = ExpandDims.create(scope, lastIndex, minus1);

     return Transpose.create(
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java
index ccfd344535a..1f538493d91 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/op/NnOpsTest.java
@@ -202,7 +202,7 @@ public void testSoftmax() {
         Operand input = tf.constant(x);
         Operand expectedResult = tf.constant(expected);

-        Operand result = fops.nn.softmax(input, 1);
+        Operand result = fops.nn.softmax(input);
         session.evaluate(expectedResult, result);
       }
   }
@@ -229,4 +229,168 @@ public void testSoftmaxAxes() {
       session.evaluate(yPosAxisResult, zGtAxisResult);
     }
   }
+
+  @Test
+  public void testLogSoftmax() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        FrameworkOps fops = FrameworkOps.create(tf);
+
+        double[][] array = {
+          {
+            -0.37646714,
+            -0.5311734,
+            0.82353556,
+            -1.0500005,
+            -1.2197578,
+            -1.0560939,
+            0.7242568,
+            -0.93800896,
+            0.47922453,
+            0.96604276
+          },
+          {
+            1.8431032,
+            0.63521856,
+            -0.82236594,
+            -1.8610067,
+            0.890422,
+            -1.8440033,
+            -1.5645103,
+            -0.31505722,
+            1.7022362,
+            0.5422927
+          },
+          {
+            -2.3798232,
+            0.56610274,
+            -0.28281465,
+            1.37052,
+            -0.08637848,
+            0.3824045,
+            -0.7390341,
+            0.38309613,
+            -0.05741333,
+            0.41976207
+          },
+          {
+            1.4530851,
+            -0.8334874,
+            0.14740701,
+            0.00373064,
+            -0.86982375,
+            -0.6652942,
+            0.665558,
+            1.1553634,
+            1.5083209,
+            0.04152437
+          },
+          {
+            -0.3040565,
+            -0.86586237,
+            1.0949674,
+            -0.4449086,
+            -0.48374927,
+            0.6941735,
+            0.21010222,
+            -0.20612952,
+            -0.32806364,
+            1.6194562
+          }
+        };
+
+        double[][] expectedArray = {
+          {
+            -2.7961895,
+            -2.9508958,
+            -1.5961869,
+            -3.4697227,
+            -3.63948,
+            -3.4758162,
+            -1.6954656,
+            -3.3577313,
+            -1.9404979,
+            -1.4536797
+          },
+          {
+            -1.1292517,
+            -2.3371363,
+            -3.794721,
+            -4.8333616,
+            -2.081933,
+            -4.8163586,
+            -4.536865,
+            -3.2874122,
+            -1.2701187,
+            -2.4300623
+          },
+          {
+            -4.97046,
+            -2.024534,
+            -2.8734512,
+            -1.2201167,
+            -2.6770153,
+            -2.2082322,
+            -3.329671,
+            -2.2075405,
+            -2.64805,
+            -2.1708746
+          },
+          {
+            -1.464081,
+            -3.7506535,
+            -2.7697592,
+            -2.9134355,
+            -3.78699,
+            -3.5824602,
+            -2.2516081,
+            -1.7618027,
+            -1.4088452,
+            -2.8756418
+          },
+          {
+            -3.0270078,
+            -3.5888138,
+            -1.6279839,
+            -3.1678598,
+            -3.2067006,
+            -2.0287778,
+            -2.512849,
+            -2.929081,
+            -3.051015,
+            -1.1034951
+          }
+        };
+
+        Operand<TFloat64> input = tf.constant(array);
+
+        Operand<TFloat64> result = fops.nn.logSoftmax(input);
+        Operand<TFloat64> expected = tf.constant(expectedArray);
+        session.evaluate(expected, result);
+      }
+  }
+
+  @Test
+  public void testLogSoftmaxAxes() {
+    for (TestSession.Mode tfMode : tfModes)
+      try (TestSession session = TestSession.createTestSession(tfMode)) {
+        Ops tf = session.getTF();
+        FrameworkOps fops = FrameworkOps.create(tf);
+
+        double[][] arr = {
+          {0., 0.09090909, 0.18181818, 0.27272727},
+          {0.36363636, 0.45454545, 0.54545455, 0.63636364},
+          {0.72727273, 0.81818182, 0.90909091, 1.}
+        };
+
+        Operand<TFloat64> arrTF = tf.constant(arr);
+
+        Operand<TFloat64> xNegAxisResult = fops.nn.logSoftmax(arrTF, -2);
+        Operand<TFloat64> yPosAxisResult = fops.nn.logSoftmax(arrTF, 0);
+        Operand<TFloat64> zGtAxisResult = fops.nn.logSoftmax(arrTF, 0);
+        session.evaluate(xNegAxisResult, yPosAxisResult);
+        session.evaluate(yPosAxisResult, zGtAxisResult);
+      }
+  }
 }

From 1ebe917056b4dfe77ef54a5b653cac55f172a55e Mon Sep 17 00:00:00 2001
From: Jim Clarke
Date: Mon, 10 May 2021 10:22:48 -0400
Subject: [PATCH 30/31] Reformat code

---
 .../org/tensorflow/framework/op/nn/NNhelper.java | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java
index d4c174a5389..3af3c606b4a 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/nn/NNhelper.java
@@ -28,7 +28,6 @@
 import org.tensorflow.op.math.Sub;
 import org.tensorflow.types.TInt32;
 import org.tensorflow.types.family.TFloating;
-import org.tensorflow.types.family.TNumber;

 import java.util.Arrays;
 import java.util.function.BiFunction;
@@ -43,8 +42,8 @@ public class NNhelper {
    *
    * @param scope the TensorFlow Scope
    * @param input the input
-   * @param computeOp The function to wrap. Must accept the scope as the first argument, and the input as the second argument.
-   * (e.g. {@code org.tensorflow.op.nn.Softmax::create})
+   * @param computeOp The function to wrap. Must accept the scope as the first argument, and the
+   * input as the second argument. (e.g. {@code org.tensorflow.op.nn.Softmax::create})
    * @param axis The dimension the operation should operate on. {@code -1} indicates the last
    * dimension.
    * @param <T> the data type for the input and the result.
@@ -52,7 +51,10 @@ public class NNhelper {
    * @return the result of the operation, the same shape as the input
    * @throws IllegalArgumentException if axis is not in the range [-rank, rank)
    */
   public static <T extends TFloating> Operand<T> wrap2DFunction(
-      Scope scope, Operand<T> input, BiFunction<Scope, Operand<T>, Operand<T>> computeOp, int axis) {
+      Scope scope,
+      Operand<T> input,
+      BiFunction<Scope, Operand<T>, Operand<T>> computeOp,
+      int axis) {
     Shape shape = input.shape();
     boolean isLastDim = axis == -1 || axis == shape.numDimensions() - 1;
     if (isLastDim) {
@@ -91,7 +93,11 @@ public static Operand wrap2DFunction(
    * @return the restored output based on the dimension and shape.
    */
   private static <T extends TFloating> Operand<T> fixOutput(
-      Scope scope, Operand<T> output, Shape shape, Operand<TInt32> axis, Operand<TInt32> lastIndex) {
+      Scope scope,
+      Operand<T> output,
+      Shape shape,
+      Operand<TInt32> axis,
+      Operand<TInt32> lastIndex) {
     Operand<T> result = swapAxis(scope, output, axis, lastIndex);
     return Reshape.create(scope, result, Constant.tensorOf(scope, shape));
   }

From 61673dff90d7df94d746df3d61d79e2583bf6859 Mon Sep 17 00:00:00 2001
From: Jim Clarke
Date: Mon, 10 May 2021 10:22:56 -0400
Subject: [PATCH 31/31] Reformat code

---
 .../org/tensorflow/framework/op/NnOps.java | 23 ++++++++-----------
 1 file changed, 10 insertions(+), 13 deletions(-)

diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java
index 89f5ee159fe..a1f85544f95 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/op/NnOps.java
@@ -195,7 +195,6 @@ public Operand sparseSoftmaxCrossEntro
         scope, labels, logits);
   }
-
   /**
    * Calculates a Softmax operation on the last dimension.
    *
    * @param input the input
    * @return the softmax of the input along the last dimension.
    * @throws IllegalArgumentException if axis is not in the range [-rank, rank)
    * @param <T> the data type for the input and result
    */
   public <T extends TFloating> Operand<T> softmax(Operand<T> input) {
-    return wrap2DFunction( scope, input, org.tensorflow.op.nn.Softmax::create, -1 );
+    return wrap2DFunction(scope, input, org.tensorflow.op.nn.Softmax::create, -1);
   }

   /**
@@ -218,8 +217,8 @@ public Operand sparseSoftmaxCrossEntro
    * @throws IllegalArgumentException if axis is not in the range [-rank, rank)
    * @param <T> the data type for the input and result
    */
   public <T extends TFloating> Operand<T> softmax(Operand<T> input, int axis) {
-    return wrap2DFunction( scope, input, org.tensorflow.op.nn.Softmax::create, axis );
+    return wrap2DFunction(scope, input, org.tensorflow.op.nn.Softmax::create, axis);
   }

   /**
@@ -230,13 +229,13 @@ public Operand softmax(Operand input, int axis) {
    * @throws IllegalArgumentException if axis is not in the range [-rank, rank)
    * @param <T> the data type for the input and result
    */
   public <T extends TFloating> Operand<T> logSoftmax(Operand<T> input) {
-    return wrap2DFunction( scope, input, org.tensorflow.op.nn.LogSoftmax::create, -1 );
+    return wrap2DFunction(scope, input, org.tensorflow.op.nn.LogSoftmax::create, -1);
   }

   /**
-   * Calculates a Log Softmax operation. If the axis is not the last dimension, then the input axis is
-   * moved to the last axis before calling tf.nn.logSoftmax, then restored before returning.
+   * Calculates a Log Softmax operation. If the axis is not the last dimension, then the input axis
+   * is moved to the last axis before calling tf.nn.logSoftmax, then restored before returning.
    *
    * @param input the input
    * @param axis the axis
    * @return the log softmax of the input for the specified axis.
    * @throws IllegalArgumentException if axis is not in the range [-rank, rank)
    * @param <T> the data type for the input and result
    */
   public <T extends TFloating> Operand<T> logSoftmax(Operand<T> input, int axis) {
-    return wrap2DFunction( scope, input, org.tensorflow.op.nn.LogSoftmax::create, axis );
+    return wrap2DFunction(scope, input, org.tensorflow.op.nn.LogSoftmax::create, axis);
   }
-
-
 }
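
The framework entry points added by this series can be exercised end to end as follows. This is a minimal usage sketch, not part of the patches above: the class name SoftmaxAxisExample and the constant values are illustrative only, and the sketch assumes just the API introduced here (FrameworkOps.create, fops.nn.softmax, fops.nn.logSoftmax) plus the core TensorFlow Java Graph and Ops classes.

    import org.tensorflow.Graph;
    import org.tensorflow.Operand;
    import org.tensorflow.framework.op.FrameworkOps;
    import org.tensorflow.op.Ops;
    import org.tensorflow.types.TFloat64;

    /** Usage sketch for the framework softmax/logSoftmax ops; the class name is hypothetical. */
    public class SoftmaxAxisExample {
      public static void main(String[] args) {
        try (Graph g = new Graph()) {
          Ops tf = Ops.create(g);
          FrameworkOps fops = FrameworkOps.create(tf);

          // A 2 x 3 constant; tf.constant(double[][]) yields a TFloat64 operand.
          Operand<TFloat64> input = tf.constant(new double[][] {{1., 2., 3.}, {4., 5., 6.}});

          // The single-argument overload operates on the last dimension, so
          // wrap2DFunction delegates directly to the core op with no transpose.
          Operand<TFloat64> lastAxis = fops.nn.softmax(input);

          // Axis 0 is not the last dimension: wrap2DFunction transposes axis 0
          // to the end, applies the core LogSoftmax op, then transposes and
          // reshapes the result back to the original shape.
          Operand<TFloat64> firstAxis = fops.nn.logSoftmax(input, 0);
        }
      }
    }

Because both methods route through NNhelper.wrap2DFunction, supporting another op that operates on the last dimension only requires passing a different method reference, such as org.tensorflow.op.nn.Softmax::create, as the computeOp argument.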