diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java index 77c6ab2bf87..363291fa5cc 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/CategoricalCrossentropy.java @@ -69,7 +69,7 @@ public class CategoricalCrossentropy extends Loss { public static final boolean FROM_LOGITS_DEFAULT = false; public static final float LABEL_SMOOTHING_DEFAULT = 0.0f; - public static final int DEFAULT_AXIS = -1; + public static final int DEFAULT_AXIS = Losses.CHANNELS_LAST; private final boolean fromLogits; private final float labelSmoothing; @@ -203,8 +203,9 @@ public CategoricalCrossentropy( * confidence on label values are relaxed. e.g. labelSmoothing=0.2 means that we will use a * value of 0.1 for label 0 and 0.9 for label 1 * @param reduction Type of Reduction to apply to loss. - * @param axis The channels axis. axis=-1 corresponds to data format `Channels Last' - * and axis=1 corresponds to data format 'Channels First'. + * @param axis The channels axis. axis=-1 corresponds to data format "Channels Last" + * and axis=1 corresponds to data format "Channels First". + * {@link Losses#CHANNELS_LAST} and {@link Losses#CHANNELS_FIRST} * @throws IllegalArgumentException if labelSmoothing is not in the inclusive range of 0. - 1. */ public CategoricalCrossentropy( diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java index 81d9e13c8a9..0d25bd5e7e2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/losses/Losses.java @@ -36,6 +36,9 @@ public class Losses { /** Default Fuzz factor. */ public static final float EPSILON = 1e-7f; + public static final int CHANNELS_LAST = -1; + public static final int CHANNELS_FIRST = 1; + /** * Calculates the mean absolute error between labels and predictions. * @@ -239,7 +242,7 @@ public static Operand categoricalCross tLabels = smoothCategoricalLabels(tf, tLabels, labelSmoothing); } if (fromLogits) { - return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1); + return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis); } /* TODO if (!(predictions instanceof Variable) && (!tf.scope().env().isEager())) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java new file mode 100644 index 00000000000..651a6fac0b0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java @@ -0,0 +1,66 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A Metric that computes the binary cross-entropy loss between true labels and predicted labels. + * + *

This is the crossentropy metric class to be used when there are only two label classes (0 and + * 1). + * + * @param the data type for the predictions. + * @param The data type for the metric result + */ +public class BinaryCrossentropy + extends MeanMetricWrapper implements LossMetric { + + private final boolean fromLogits; + private final float labelSmoothing; + + /** + * Creates a BinaryCrossentropy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0, + * compute the loss between the predicted labels and a smoothed version of the true labels, + * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing + * correspond to heavier smoothing. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public BinaryCrossentropy( + Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + this.fromLogits = fromLogits; + this.labelSmoothing = labelSmoothing; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java new file mode 100644 index 00000000000..c330ea88eaa --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java @@ -0,0 +1,105 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A Metric that computes the categorical cross-entropy loss between true labels and predicted + * labels. + * + *

This is the crossentropy metric class to be used when there are multiple label classes (2 or + * more). The labels should be given as a one_hot representation. eg., When labels values are + * [2, 0, 1], the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]] + * . + * + * @param the data type for the predictions. + * @param The data type for the metric result + */ +public class CategoricalCrossentropy + extends MeanMetricWrapper implements LossMetric { + + private final boolean fromLogits; + private final float labelSmoothing; + private final int axis; + + /** + * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the + * labels and predictions. + * + *

Uses a {@link Losses#CHANNELS_LAST} for the channel axis. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution. + * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, + * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and 0.9 + * for label 1 + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CategoricalCrossentropy( + Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) { + this(tf, name, fromLogits, labelSmoothing, Losses.CHANNELS_LAST, seed, type); + } + + /** + * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the + * labels and predictions. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed, + * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2 + * means that we will use a value of 0.1 for label 0 and 0.9 + * for label 1 + * @param axis Int specifying the channels axis. axis={@link Losses#CHANNELS_LAST} + * corresponds to data format channels_last, and + * axis={@link Losses#CHANNELS_FIRST} corresponds to data format + * channels_first. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CategoricalCrossentropy( + Ops tf, + String name, + boolean fromLogits, + float labelSmoothing, + int axis, + long seed, + Class type) { + super(tf, name, seed, type); + setLoss(this); + this.fromLogits = fromLogits; + this.labelSmoothing = labelSmoothing; + this.axis = axis; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.categoricalCrossentropy( + getTF(), labels, predictions, fromLogits, labelSmoothing, axis); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java new file mode 100644 index 00000000000..2741a36edb6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A Metric that computes the categorical hinge loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result + */ +public class CategoricalHinge extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a CategoricalHinge metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CategoricalHinge(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.categoricalHinge(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java new file mode 100644 index 00000000000..458de092bec --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the cosine similarity metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class CosineSimilarity extends MeanMetricWrapper + implements LossMetric { + public static final int DEFAULT_AXIS = -1; + private final int[] axis; + + /** + * Creates a metric that computes the cosine similarity metric between labels and predictions with + * a default axis, {@link #DEFAULT_AXIS} + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CosineSimilarity(Ops tf, String name, long seed, Class type) { + this(tf, name, DEFAULT_AXIS, seed, type); + } + + /** + * Creates a metric that computes the cosine similarity metric between labels and predictions. + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param axis The dimension along which the cosine similarity is computed. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) { + this(tf, name, new int[] {axis}, seed, type); + } + /** + * Creates a CosineSimilarity metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param axis The dimension along which the cosine similarity is computed. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class type) { + super(tf, name, seed, type); + this.axis = axis; + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity + return Metrics.cosineProximity(getTF(), labels, predictions, axis); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java new file mode 100644 index 00000000000..baf9ad8ab7d --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the hinge loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class Hinge extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a Hinge metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Hinge(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.hinge(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java new file mode 100644 index 00000000000..efcbbcbb7f0 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the Kullback-Leibler divergence loss metric between labels and + * predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class KLDivergence extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a KLDivergence metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public KLDivergence(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java new file mode 100644 index 00000000000..3df8505d54b --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric + * between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class LogCoshError extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a LogCoshError metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public LogCoshError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.logCosh(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java new file mode 100644 index 00000000000..de1f5a5629e --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java @@ -0,0 +1,41 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.framework.metrics.impl.Reduce; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN } + * + * @param The data type for the metric values + * @param The data type for the metric result + */ +public class Mean extends Reduce { + + /** + * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} + * + * @param tf the TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected Mean(Ops tf, String name, long seed, Class type) { + super(tf, name, MetricReduction.WEIGHTED_MEAN, seed, type); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java new file mode 100644 index 00000000000..e27676932ff --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class MeanAbsoluteError extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a Mean Absolute Error metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public MeanAbsoluteError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.meanAbsoluteError(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java new file mode 100644 index 00000000000..84fa9b627b2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class MeanAbsolutePercentageError + extends MeanMetricWrapper implements LossMetric { + + /** + * Creates a Mean Absolute Error metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.meanAbsolutePercentageError(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java new file mode 100644 index 00000000000..c7edd6ebe93 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class MeanSquaredError extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a Mean Absolute Error metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public MeanSquaredError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.meanSquaredError(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java new file mode 100644 index 00000000000..199b6e0e114 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the mean of absolute difference between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class MeanSquaredLogarithmicError + extends MeanMetricWrapper implements LossMetric { + + /** + * Creates a Mean Absolute Error metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java new file mode 100644 index 00000000000..bbb2aa73da2 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java @@ -0,0 +1,193 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import java.util.Collections; +import java.util.List; + +/** + * Base class for Metrics + * + * @param The data type for the metric values + * @param The data type for the metric result + */ +public abstract class Metric { + + /** The TensorFlow Ops */ + private final Ops tf; + + /** The seed for random number generation */ + private final long seed; + + /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */ + private final String name; + + /** + * Creates a Metric with a name of {@link Class#getSimpleName()} + * + * @param tf the TensorFlow Ops + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + protected Metric(Ops tf, long seed) { + this(tf, null, seed); + } + + /** + * Creates a Metric + * + * @param tf the TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + */ + protected Metric(Ops tf, String name, long seed) { + if (!tf.scope().env().isGraph()) { + throw new IllegalArgumentException("Metrics are required to execute in Graph mode."); + } + this.seed = seed; + this.name = name != null ? name : this.getClass().getSimpleName(); + this.tf = tf.withName(this.getClass().getSimpleName()); + } + + /** + * Creates a List of Operations to update the metric state based on input values. + * + *

This is an empty implementation that should be overridden in a subclass, if needed. + * + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to values, may be null. + * @return a List of Operations to update the metric state + * @param the data type for sampleWeights + */ + @SuppressWarnings({"unchecked", "unused"}) + public List updateStateList(Operand values, Operand sampleWeights) { + return Collections.EMPTY_LIST; + } + + /** + * Creates a List of Operations to update the metric state based on labels and predictions. + * + *

This is an empty implementation that should be overridden in a sub class, if needed. + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights sample weights to be applied to values, may be null. + * @param the data type for the labels + * @param the data type for the sampleWeights + * @return a List of Operations to update the metric state + */ + @SuppressWarnings({"unchecked", "unused"}) + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + return Collections.EMPTY_LIST; + } + + /** + * Creates a NoOp Operation with control dependencies to update the metric state + * + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to values, may be null. + * @param the data type for sampleWeights + * @return the Operation to update the metric state + */ + public final Op updateState(Operand values, Operand sampleWeights) { + List controlOps = updateStateList(values, sampleWeights); + return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); + } + + /** + * Creates a NoOp Operation with control dependencies to update the metric state + * + * @param labels the labels + * @param predictions the predictions + * @param sampleWeights sample weights to be applied to values, may be null. + * @param the data type for the labels + * @param the data type for the sampleWeights + * @return the Operation to update the metric state + */ + public final Op updateState( + Operand labels, Operand predictions, Operand sampleWeights) { + List controlOps = updateStateList(labels, predictions, sampleWeights); + return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp(); + } + + /** + * Gets the current result of the metric + * + * @return the result, possibly with control dependencies + */ + public abstract Operand result(); + + /** + * Resets any state variables to their initial values + * + * @return the control operation for doing the reset + */ + public abstract Op resetStates(); + + /** + * Calls update state once, followed by a call to get the result + * + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to values, may be null. + * @return the result, possibly with control dependencies + * @param the data type for the sampleWeights. + */ + public final Operand callOnce( + Operand values, Operand sampleWeights) { + List controlOps = updateStateList(values, sampleWeights); + Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps); + return ltf.identity(result()); + } + + /** + * Gets a formatted name for a variable, in the form {@link #name} + "_" + varName. + * + * @param varName the base name for the variable + * @return the formatted variable name + */ + protected String getVariableName(String varName) { + return String.format("%s_%s", this.name, varName); + } + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + public Ops getTF() { + return tf; + } + + /** + * Gets the name of this metric. + * + * @return the name of this metric + */ + public String getName() { + return name; + } + + /** The random number generator seed value */ + public long getSeed() { + return seed; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java new file mode 100644 index 00000000000..d837ff626b3 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java @@ -0,0 +1,26 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +/** Defines the different types of metric reductions */ +public enum MetricReduction { + + /** Scalar sum of weighted values. */ + SUM, + /** Scalar sum of weighted values divided by number of elements. */ + SUM_OVER_BATCH_SIZE, + /** Scalar sum of weighted values divided by sum of weights. */ + WEIGHTED_MEAN +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java new file mode 100644 index 00000000000..0169bc6b8bc --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java @@ -0,0 +1,134 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.ReduceSum; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.family.TNumber; + +/** Helper class with built-in metrics functions. */ +public class Metrics { + + public static final float L2_NORM_EPSILON = 1e-12f; + + /** + * Computes how often targets are in the top K predictions. + * + *

Standalone usage: + * + *

+   *     Operand<TInt32> labels = tf.constant(new int[][]
+   *                                    {{0, 0, 1}, {0, 1, 0}});
+   *     Operand<TFloat32> predictions = tf.constant(new float[][]
+   *                                    {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
+   *     Operand<TFloat32> m = Metrics.topKCategoricalAccuracy(
+   *                                    labels, predictions, 3)
+   *     //m.shape().toString == "[2]"
+   * 
+ * + * @param tf the TensorFlow Ops. + * @param labels the ground truth values. + * @param predictions The prediction values. + * @param k Number of top elements to look at for computing accuracy. + * @param the data type for the predictions and results + * @param the data type ofr the labels. + * @return the Operand for the Top K categorical accuracy value. + */ + public static Operand topKCategoricalAccuracy( + Ops tf, Operand labels, Operand predictions, long k) { + Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class); + return CastHelper.cast( + tf, + tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)), + predictions.type()); + } + + /** + * Computes the cosine similarity between labels and predictions. + * + * @param tf the TensorFlow Ops + * @param labels The ground truth values. + * @param predictions The prediction values. + * @param axes The dimensions along which the cosine similarity is computed. + * @param the data type for the labels + * @param the data type for the predictions and result + * @return Cosine similarity value. + */ + public static Operand cosineProximity( + Ops tf, Operand labels, Operand predictions, int[] axes) { + Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type()); + labelsNorm = l2Normalize(tf, labelsNorm, axes); + + Operand predictionsNorm = l2Normalize(tf, predictions, axes); + Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm); + return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE)); + } + + /** + * Normalizes along dimension axis using an L2 norm with an epsilon of {@link + * #L2_NORM_EPSILON}. + * + *

For a 1-D tensor with axis = 0, computes + * + *

+   *       output = x / sqrt(max(sum(x**2), epsilon))
+   * 
+ * + *

For x with more dimensions, independently normalizes each 1-D slice along + * dimension axis. + * + * @param tf The TensorFlow ops + * @param x The operand to normalize + * @param axes Dimension(s) along which to normalize. + * @param The data type for x. + * @return the normalized values of x. + */ + public static Operand l2Normalize(Ops tf, Operand x, int[] axes) { + return l2Normalize(tf, x, axes, L2_NORM_EPSILON); + } + + /** + * Normalizes along dimension axis using an L2 norm. + * + *

For a 1-D tensor with axis = 0, computes + * + *

+   *       output = x / sqrt(max(sum(x**2), epsilon))
+   * 
+ * + *

For x with more dimensions, independently normalizes each 1-D slice along + * dimension axis. + * + * @param tf The TensorFlow ops + * @param x The operand to normalize + * @param axes Dimension(s) along which to normalize. + * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon) as the + * divisor if norm < sqrt(epsilon). + * @param The data type for the values. + * @return the normalized values of x. + */ + public static Operand l2Normalize( + Ops tf, Operand x, int[] axes, float epsilon) { + Operand squareSum = + tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE)); + Operand y = + tf.math.rsqrt( + tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type()))); + return tf.math.mul(x, y); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java new file mode 100644 index 00000000000..75a2031fbb5 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the poisson loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class Poisson extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a Poisson metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public Poisson(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.poisson(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java new file mode 100644 index 00000000000..2e01f722de6 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java @@ -0,0 +1,61 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the sparse categorical cross-entropy loss between true labels and + * predicted labels. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class SparseCategoricalCrossentropy + extends MeanMetricWrapper implements LossMetric { + + private final boolean fromLogits; + private final int axis; + + /** + * Creates a SparseCategoricalCrossentropy metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution. + * @param axis The dimension along which the entropy is computed. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public SparseCategoricalCrossentropy( + Ops tf, String name, boolean fromLogits, int axis, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + this.fromLogits = fromLogits; + this.axis = axis; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java new file mode 100644 index 00000000000..430dbbcc229 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Losses; +import org.tensorflow.framework.metrics.impl.LossMetric; +import org.tensorflow.framework.metrics.impl.MeanMetricWrapper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A metric that computes the squared hinge loss metric between labels and predictions. + * + * @param the data type for the predictions. + * @param The data type for the metric result. + */ +public class SquaredHinge extends MeanMetricWrapper + implements LossMetric { + + /** + * Creates a SquaredHinge metric + * + * @param tf the TensorFlow Ops + * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + public SquaredHinge(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + setLoss(this); + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand labels, Operand predictions) { + return Losses.squaredHinge(getTF(), labels, predictions); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java new file mode 100644 index 00000000000..66640e72f50 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java @@ -0,0 +1,52 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.exceptions; + +import org.tensorflow.ndarray.Shape; + +/** + * Exception that indicates that static shapes are not able to broadcast among each other during + * arithmetic operations. Static shapes do not have unknown rank or any unknown dimensions {@link + * Shape#hasUnknownDimension()}. The term broadcasting describes how TensorFlow treats arrays with + * different shapes during arithmetic operations. + * + *

Broadcasting is the process of making arrays to have compatible shapes for arithmetic + * operations. Two shapes are compatible if for each dimension pair they are either equal or one of + * them is one. When trying to broadcast a Tensor to a shape, it starts with the trailing + * dimensions, and works its way forward. + * + * @see Numpy Broadcasting + */ +public class NotBroadcastableException extends IllegalArgumentException { + + /** + * Creates a new NotBroadcastableException exception with the specified detail message + * + * @param message the detail message. + */ + public NotBroadcastableException(String message) { + super(message); + } + + /** + * Creates a new NotBroadcastableException exception with the specified detail message + * + * @param message the detail message. + * @param cause the cause + */ + public NotBroadcastableException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java new file mode 100644 index 00000000000..b7b87d313aa --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java @@ -0,0 +1,36 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.types.family.TNumber; + +/** + * Interface for Metrics that wrap Loss functions. + * + * @param The data type of the predictions. + */ +public interface LossMetric { + + /** + * Calculates the weighted loss between labels and predictions + * + * @param labels the truth values or labels + * @param predictions the predictions + * @param The data type of the labels. + * @return the loss + */ + Operand call(Operand labels, Operand predictions); +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java new file mode 100644 index 00000000000..17c209a8fed --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java @@ -0,0 +1,106 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.Mean; +import org.tensorflow.framework.metrics.MetricReduction; +import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +/** + * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of + * {@link MetricReduction#WEIGHTED_MEAN}. + * + *

The loss function calculates the loss between the labels and predictions + * then passes this loss to the {@link Mean} metric to calculate the weighted mean of the + * loss over many iterations or epochs + * + * @param the data type for the predictions. + * @param The data type for the metric result + */ +public class MeanMetricWrapper extends Mean { + + /** The loss function interface */ + protected LossMetric loss; + + /** + * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#WEIGHTED_MEAN} + * + * @param tf the TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param type the type for the variables and result + */ + protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) { + super(tf, name, seed, type); + } + + /** + * Gets the loss function. + * + * @return the loss function. + */ + public LossMetric getLoss() { + return loss; + } + + /** + * Sets the Loss function for this wrapper. + * + * @param loss the loss function. + */ + protected void setLoss(LossMetric loss) { + this.loss = loss; + } + + /** + * Creates Operations that update the state of the mean metric, by calling the loss function and + * passing the loss to the Mean metric to calculate the weighted mean of the loss over many + * iterations. + * + * @param labels the truth values or labels + * @param predictions the predictions + * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is + * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor + * of size [batch_size], then the total loss for each sample of the batch is rescaled by the + * corresponding element in the sampleWeights vector. If the shape of sampleWeights is + * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of + * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss + * functions reduce by 1 dimension, usually axis=-1.) + * @param the datatype of the labels + * @param the data type for sampleWeights + * @return a List of control operations that updates the Mean state variables. + */ + public List updateStateList( + Operand labels, Operand predictions, Operand sampleWeights) { + if (labels == null || predictions == null) { + throw new IllegalArgumentException("missing required inputs for labels and predictions"); + } + + Operand tLabels = CastHelper.cast(getTF(), labels, getResultType()); + Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType()); + + Operand losses = loss.call(tLabels, tPredictions); + + return super.updateStateList( + CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java new file mode 100644 index 00000000000..ad8ff58e417 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java @@ -0,0 +1,348 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.math.Mean; +import org.tensorflow.types.TBool; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TIntegral; +import org.tensorflow.types.family.TNumber; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes; +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * These are helper methods for Metrics and will be module private when Java modularity is applied + * to TensorFlow Java. These methods should not be used outside of the metrics packages. + */ +public class MetricsHelper { + public static final float NEG_INF = -1e10f; + private static final String ASSERT_BROADCAST_ERROR_PREFIX = + "weights can not be broadcast to values."; + + /** + * Asserts that the sampleWeights can be broadcast to the same shape as values + * + * + *

In losses and metrics, limited weight broadcasting is supported. Weights must be either + * scalar, or the same rank as the target values, with each dimension either 1, or the same as the + * corresponding values dimension. + * + * @param tf the TensorFlow Ops + * @param sampleWeights the sample weights. + * @param values the values to which weights are applied. + * @return Operation with control dependencies to ensure sampleWeight + * can be broadcast to values + * @param the type of Operand + * @throws NotBroadcastableException If static checks determine sampleWeights has an + * incorrect shape that prohibit broadcasting to values + */ + @SuppressWarnings("unchecked") + public static Op assertBroadcastable( + Ops tf, Operand sampleWeights, Operand values) { + + // try static check for exact match + + Shape weightsShapeStatic = sampleWeights.shape(); + int weightsRankStatic = weightsShapeStatic.numDimensions(); + + Shape valuesShapeStatic = values.shape(); + int valuesRankStatic = valuesShapeStatic.numDimensions(); + + // if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) { + if (!weightsShapeStatic.isUnknown() + && !valuesShapeStatic.isUnknown() + && !weightsShapeStatic.hasUnknownDimension() + && !valuesShapeStatic.hasUnknownDimension()) { + if (weightsRankStatic == 0) { + return tf.withSubScope("staticScalarCheckSuccess") + .withControlDependencies(Collections.EMPTY_LIST) + .noOp(); + } + if (weightsRankStatic != valuesRankStatic) { + throw new NotBroadcastableException( + String.format( + "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.", + ASSERT_BROADCAST_ERROR_PREFIX, + valuesRankStatic, + weightsRankStatic, + valuesShapeStatic.toString(), + weightsShapeStatic.toString())); + } + + for (int i = 0; i < valuesRankStatic; i++) { + if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i) + && weightsShapeStatic.size(i) != 1) { + throw new NotBroadcastableException( + String.format( + "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.", + ASSERT_BROADCAST_ERROR_PREFIX, + i, + valuesShapeStatic.toString(), + weightsShapeStatic.toString())); + } + } + return tf.withSubScope("staticDimsCheckSuccess") + .withControlDependencies(Collections.EMPTY_LIST) + .noOp(); + } + // Dynamic checks. + Operand weightsShape = tf.shape(sampleWeights); + Operand weightsRank = tf.rank(sampleWeights); + Operand valuesShape = tf.shape(values); + Operand valuesRank = tf.rank(values); + + Operand isScalar = tf.math.equal(weightsRank, tf.constant(0)); + List> data = + Arrays.asList( + tf.constant(ASSERT_BROADCAST_ERROR_PREFIX), + tf.constant("weights.shape="), + weightsShape, + tf.constant("values.shape="), + valuesShape, + tf.constant("isScalar="), + isScalar); + + // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a + // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true. + Operand reshapedWeights = + tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights); + weightsShape = tf.shape(reshapedWeights); + weightsRank = tf.rank(reshapedWeights); + + Operand validNonscalar = + canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape); + + Operand isValidShape = tf.select(isScalar, isScalar, validNonscalar); + + return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data); + } + + /** + * Gets an operand that tests if the shapes have the same rank and valid dimensions. + * + * @param tf the TensorFlow Ops + * @param weightsRank the operand for the rank of the sample weights + * @param weightsShape the operand for the shape of the sample weights + * @param valuesRank the operand for the rank of the sample weights + * @param valuesShape the operand for the shape of the sample weights + * @param the data type for the operands + * @return a boolean operand to determine if the Shape is scalar or not. + */ + private static Operand canBroadcastNonscalarShapes( + Ops tf, + Operand weightsRank, + Operand weightsShape, + Operand valuesRank, + Operand valuesShape) { + tf = tf.withSubScope("canBroadcastNonscalarShapes"); + Operand isSameRank = tf.math.equal(valuesRank, weightsRank); + return tf.select(isSameRank, canBroadcastDims(tf, weightsShape, valuesShape), isSameRank); + } + + /** + * Gets an operand that tests if the shapes have valid dimensions or not. + * + * @param tf the TensorFlow Ops + * @param weightsShape the operand for the shape of the sample weights + * @param valuesShape the operand for the shape of the values + * @param the data type for the operands + * @return a boolean operand to determine if the shapes have valid dimensions or not. + */ + private static Operand canBroadcastDims( + Ops tf, Operand weightsShape, Operand valuesShape) { + tf = tf.withSubScope("canBroadcastDims"); + Operand valuesShape2d = tf.expandDims(valuesShape, tf.constant(-1)); + Operand validDims = + tf.concat(Arrays.asList(valuesShape2d, tf.onesLike(valuesShape2d)), tf.constant(1)); + Operand weightsShape2D = tf.expandDims(weightsShape, tf.constant(-1)); + + Operand diffResult = SetsOps.difference(tf, weightsShape2D, validDims); + Operand numInvalidDims = tf.size(diffResult); + return tf.math.equal(tf.constant(0), numInvalidDims); + } + + /** + * Broadcast weights to the same shape as values. + * + * @param tf the TensorFlow ops + * @param weights Operand whose shape is broadcastable to values. + * @param values Operand of any shape + * @param the type of Operands + * @return weights broadcast to values shape + */ + public static Operand broadcastWeights( + Ops tf, Operand weights, Operand values) { + + Shape weightsShape = weights.shape(); + Shape valuesShape = values.shape(); + + if (!weightsShape.hasUnknownDimension() + && !valuesShape.hasUnknownDimension() + && weightsShape.isCompatibleWith(valuesShape)) { + return weights; + } + + Ops ctf = + tf.withSubScope("broadcastWeights") + .withControlDependencies( + Collections.singletonList(assertBroadcastable(tf, weights, tf.onesLike(values)))); + return ctf.math.mul(weights, tf.onesLike(values)); + } + + // aliases for mean + + /** + * Calculate the mean of the operand, along all axes and keepDims is false + * + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param the type of the Operand. + * @return the mean of the operand + */ + public static Operand mean(Ops tf, Operand x) { + return mean(tf, x, null, false); + } + + /** + * Calculate the mean of the operand, alongside the specified axis with keepDims is + * false + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param axes Axes to compute the mean. + * @param the type of the Operand. + * @param the type of the axes. + * @return the mean of the operand, along the specified axes. + */ + public static Operand mean( + Ops tf, Operand x, Operand axes) { + return mean(tf, x, axes, false); + } + + /** + * Calculates the mean of the operand, along all axes. + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param keepDims Indicates whether to keep the dimensions or not. If keepdims is + * false, the rank of the tensor is reduced by 1 for each entry in axes + * . If keepdims is true, the reduced dimensions are retained + * with length 1. + * @param the type of the operand + * @return the mean of elements of x. + */ + public static Operand mean( + Ops tf, Operand x, boolean keepDims) { + return mean(tf, x, null, keepDims); + } + + + + /** + * Calculates the mean of the operand, alongside the specified axes. + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param axes Axes to compute the mean. + * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the + * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the + * * reduced dimensions are retained with length 1. + * @param the data type of the Operand + * @param the data type of the axes + * @return the mean of elements of x. + */ + + public static Operand mean( + Ops tf, Operand x, Operand axes, boolean keepDims) { + if (axes == null) { + axes = (Operand) allAxes(tf, x); + } + return tf.math.mean(x, axes, Mean.keepDims(keepDims)); + } + + /** + * Calculate the mean of the operand, along all axes and keepDims is false + * + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @return the mean of the operand containing floating point numbers + */ + public static Operand booleanMean(Ops tf, Operand x) { + return booleanMean(tf, x, null, false); + } + + /** + * Calculate the mean of the operand, alongside the specified axis with keepDims is + * false + * + * @param tf the TensorFlow Ops + * @param x the Operand used to calculate the mean + * @param axes Axes to compute the mean. + * @param the type of the axes. + * @return the mean of the operand, along the specified axes, containing floating point numbers + */ + public static Operand booleanMean( + Ops tf, Operand x,Operand axes) { + return booleanMean(tf, x, axes, false); + } + + /** + * Calculates the mean of the boolean operand, alongside all axes. + * + * @param x the boolean Operand used to calculate the mean + * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the + * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the + * * reduced dimensions are retained with length 1. + * @param the data type of the axes + * @return the mean of elements of x containing floating point numbers + */ + public static Operand booleanMean( + Ops tf, Operand x, boolean keepDims) { + return booleanMean(tf, x, null, keepDims); + } + + /** + * Calculates the mean of the boolean operand, alongside the specified axes. + * + * @param x the boolean Operand used to calculate the mean + * @param axes Axes to compute the mean. + * @param keepDims Indicates whether to keep the dimensions or not. If `keepdims` is `false`, the + * * rank of the tensor is reduced by 1 for each entry in `axes`. If `keepdims` is `true`, the + * * reduced dimensions are retained with length 1. + * @param the data type of the axes + * @return the mean of elements of x containing floating point numbers + */ + public static Operand booleanMean( + Ops tf, Operand x, Operand axes, boolean keepDims) { + Operand xf = cast(tf, x, TFloat64.class); + return mean(tf, xf, axes, keepDims); + } + +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java new file mode 100644 index 00000000000..8e48cb4e573 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/Reduce.java @@ -0,0 +1,240 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossTuple; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.framework.metrics.Metric; +import org.tensorflow.framework.metrics.MetricReduction; +import org.tensorflow.framework.utils.CastHelper; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.family.TNumber; + +import java.util.ArrayList; +import java.util.List; + +/** + * Encapsulates metrics that perform a reduce operation on the metric values. + * + * @param The data type for the metric values + * @param The data type for the metric result + */ +public abstract class Reduce extends Metric { + public static final String TOTAL = "total"; + public static final String COUNT = "count"; + protected final MetricReduction reduction; + private final String totalName; + private final String countName; + + private final Class resultType; + /** the variable that holds the total of the metric values */ + protected Variable total; + /** the variable that holds the count of the metric values. + * For {@link MetricReduction#WEIGHTED_MEAN}, this count may be weighted */ + protected Variable count; + + /** + * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM} + * + * @param tf the TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param resultType the type for the variables and result + */ + protected Reduce(Ops tf, String name, long seed, Class resultType) { + this(tf, name, MetricReduction.SUM, seed, resultType); + } + + /** + * @param tf The TensorFlow Ops + * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}. + * @param reduction The type of metric reduction to apply + * @param seed the seed for random number generation. An initializer created with a given seed + * will always produce the same random tensor for a given shape and data type. + * @param resultType the type for the variables and result + */ + protected Reduce(Ops tf, String name, MetricReduction reduction, long seed, Class resultType) { + super(tf, name, seed); + this.reduction = reduction; + this.totalName = this.getVariableName(TOTAL); + this.countName = this.getVariableName(COUNT); + this.resultType = resultType; + setupVars(); + } + /** Initializes the Variables */ + private void setupVars() { + if (total == null) { + total = getTF().withName(totalName).variable(Shape.scalar(), resultType); + } + if (reduction == MetricReduction.SUM_OVER_BATCH_SIZE + || reduction == MetricReduction.WEIGHTED_MEAN) { + if (count == null) { + count = getTF().withName(countName).variable(Shape.scalar(), resultType); + } + } + } + + /** {@inheritDoc} */ + public Op resetStates() { + List controls = new ArrayList<>(); + if (total != null) { + controls.add( + getTF().assign(total, CastHelper.cast(getTF(), getTF().constant(0), total.type()))); + } + if (count != null) { + controls.add( + getTF().assign(count, CastHelper.cast(getTF(), getTF().constant(0), count.type()))); + } + return getTF().withControlDependencies(controls).noOp(); + } + + /** + * Updates the metric variables based on the inputs. At least one input arg required for + * values, an optional additional input for the sampleWeights + * + * @param values the inputs to be passed to update state, this may not be null + * @param sampleWeights sample weights to be applied to values, may be null. + * @return the result with a control dependency on update state Operands + * @throws IllegalArgumentException if values is null + */ + @Override + public List updateStateList(Operand values, Operand sampleWeights) { + + if (values == null) { + throw new IllegalArgumentException("values is required."); + } + List updateOperations = new ArrayList<>(); + // cast everything to match the variables + Operand lSampleWeights = null; + Operand lValues = values; + + if (sampleWeights != null) { + lSampleWeights = CastHelper.cast(getTF(), sampleWeights, lValues.type()); + LossTuple tuple = + LossesHelper.squeezeOrExpandDimensions(getTF(), null, lValues, lSampleWeights); + lValues = tuple.getTarget(); + lSampleWeights = tuple.getSampleWeights(); + try { + lSampleWeights = MetricsHelper.broadcastWeights(getTF(), lSampleWeights, lValues); + } catch (IllegalArgumentException ex) { + // if we get here we have static shapes with either + // different ranks or different dimension sizes. + // first, reduce the values down to the rank of the samples + int valuesRank = lValues.shape().numDimensions(); + int weightsRank = lSampleWeights.shape().numDimensions(); + int numAxes = Math.min(0, valuesRank - weightsRank); + if (numAxes + > 0) { // values rank is greater than weights rank, reduce values to weights rank. + int[] axes = new int[numAxes]; + for (int i = 0; i < numAxes; i++) axes[i] = i + weightsRank; + if (reduction == MetricReduction.SUM) { + lValues = getTF().reduceSum(lValues, getTF().constant(axes)); + } else { + lValues = getTF().math.mean(lValues, getTF().constant(axes)); + } + } + } + lValues = getTF().math.mul(lValues, lSampleWeights); + } + + Operand weightedValueSum = + getTF().reduceSum(lValues, LossesHelper.allAxes(getTF(), lValues)); + Operand totalUpdate = + getTF().assignAdd(total, CastHelper.cast(getTF(), weightedValueSum, total.type())); + updateOperations.add(totalUpdate); + Operand numValues; + if (reduction != MetricReduction.SUM) { + switch (reduction) { + case SUM_OVER_BATCH_SIZE: + numValues = + CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + break; + case WEIGHTED_MEAN: + if (lSampleWeights == null) { + numValues = + CastHelper.cast(getTF(), getTF().constant(lValues.shape().size()), resultType); + } else { + numValues = + CastHelper.cast( + getTF(), + getTF() + .reduceSum(lSampleWeights, LossesHelper.allAxes(getTF(), lSampleWeights)), + resultType); + } + break; + default: + throw new UnsupportedOperationException( + String.format("reduction [%s] not implemented", reduction)); + } + Operand totalCount = getTF().assignAdd(this.count, numValues); + + updateOperations.add(totalCount); + } + + return updateOperations; + } + + /** {@inheritDoc} */ + @Override + public Operand result() { + Operand fResult; + + switch (this.reduction) { + case SUM: + fResult = getTF().identity(total); + break; + case WEIGHTED_MEAN: + case SUM_OVER_BATCH_SIZE: + fResult = getTF().math.divNoNan(total, CastHelper.cast(getTF(), count, resultType)); + break; + default: + throw new UnsupportedOperationException( + String.format("reduction [%s] not implemented", reduction)); + } + return fResult; + } + + /** + * Gets the total variable + * + * @return the total variable + */ + public Variable getTotal() { + return total; + } + + /** + * Gets the count variable + * + * @return the count variable + */ + public Variable getCount() { + return count; + } + + /** + * Gets the type for the variables + * + * @return the type for the variables + */ + public Class getResultType() { + return resultType; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java new file mode 100644 index 00000000000..1841c7ee238 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/SetsOps.java @@ -0,0 +1,146 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.op.SparseOps; +import org.tensorflow.op.sparse.DenseToDenseSetOperation; +import org.tensorflow.types.family.TNumber; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** Implementation of set operations */ +public class SetsOps { + + /** + * Enumeration containing the string operation values to be passed to the TensorFlow Sparse Ops + * function {@link SparseOps#denseToDenseSetOperation} + */ + public enum Operation { + A_MINUS_B("a-b"), + B_MINUS_A("b-a"), + INTERSECTION("intersection"), + UNION("union"); + + private final String setOperation; + + Operation(String setOperation) { + this.setOperation = setOperation; + } + + /** + * Gets the set operation String value used to pass as the stringOperation value to {@link + * SparseOps#denseToDenseSetOperation} + * + * @return the set operation String value + */ + public String getSetOperation() { + return setOperation; + } + } + + /** + * Computes set difference of elements in last dimension of a and b with + * aMinusB set to true. + * + *

All but the last dimension of a and b must match + * + * @param tf the TensorFlow Ops + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the * same. Elements along the last dimension contain the results of the set + * operation. + */ + public static Operand difference(Ops tf, Operand a, Operand b) { + return difference(tf, a, b, true); + } + /** + * Computes set difference of elements in last dimension of a and b. + * + *

All but the last dimension of a and b must match + * + * @param tf the TensorFlow Ops + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param aMinusB whether to subtract b from a, vs vice versa. + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the * same. Elements along the last dimension contain the results of the set + * operation. + */ + public static Operand difference( + Ops tf, Operand a, Operand b, boolean aMinusB) { + return setOperation(tf, a, b, aMinusB ? Operation.A_MINUS_B : Operation.B_MINUS_A); + } + + /** + * Computes set union of elements in last dimension of a and b. + * + * @param tf the TensorFlow Ops + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the * same. Elements along the last dimension contain the results of the set + * operation. + */ + public static Operand union(Ops tf, Operand a, Operand b) { + return setOperation(tf, a, b, Operation.UNION); + } + + /** + * Computes set intersection of elements in last dimension of a and b. + * + * @param tf the TensorFlow Ops + * @param a The first operand representing set a + * @param b The other operand representing set b + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the * same. Elements along the last dimension contain the results of the set + * operation. + */ + public static Operand intersection(Ops tf, Operand a, Operand b) { + return setOperation(tf, a, b, Operation.INTERSECTION); + } + + /** + * Compute set operation of elements in last dimension of a and b. + * + * @param tf the TensorFlow Ops + * @param a The first set operation operand + * @param b The other et operation operand + * @param setOperation The set operation to perform, {@link Operation}. + * @param the data type for the sets + * @return An Operand with the same rank as a and b, and all but the + * last dimension the same. Elements along the last dimension contain the results of the set + * operation. + */ + public static Operand setOperation( + Ops tf, Operand a, Operand b, Operation setOperation) { + + DenseToDenseSetOperation setOperationResult = + tf.sparse.denseToDenseSetOperation( + a, b, setOperation.getSetOperation(), DenseToDenseSetOperation.validateIndices(true)); + + return tf.sparse.sparseToDense( + setOperationResult.resultIndices(), + setOperationResult.resultShape(), + setOperationResult.resultValues(), + cast(tf, tf.constant(0), a.type())); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java new file mode 100644 index 00000000000..7ceedded018 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/BinaryCrossentropyTest.java @@ -0,0 +1,151 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class BinaryCrossentropyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryCrossentropy instance = + new BinaryCrossentropy<>(tf, "BCE_testUnweighted", false, 0, 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[] trueArray = {1, 0, 1, 0}; + float[] predictionArray = {1, 1, 1, 0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 2))); + Op op = instance.updateState(labels, yPrediction, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(7.71247434F, total); + session.evaluate(2, count); + session.evaluate(3.85623717F, result); + } + } + + @Test + public void testUnweightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryCrossentropy instance = + new BinaryCrossentropy<>(tf, "BCE_testUnweightedLogits", true, 0, 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[] trueArray = {1, 0, 1, 0, 1, 1}; + double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, logits, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(66.66667, total); + session.evaluate(2, count); + session.evaluate(33.333332, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryCrossentropy instance = + new BinaryCrossentropy<>(tf, "BCE_testWeighted", false, 0, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 0, 1, 0}; + float[] predictionArray = {1, 1, 1, 0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 2))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 2))); + Operand sampleWeight = tf.constant(new float[] {1.5f, 2.f}); + Op op = instance.updateState(labels, yPrediction, sampleWeight); + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(11.499929f, total); + session.evaluate(3.5f, count); + session.evaluate(3.285694f, result); + } + } + + @Test + public void testWeightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + BinaryCrossentropy instance = + new BinaryCrossentropy<>(tf, "BCE_testWeightedLogits", true, 0, 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[] trueArray = {1, 0, 1, 0, 1, 1}; + double[] logitsArray = {100.0, -100.0, 100.0, 100.0, 100.0, -100.0}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(new double[] {2, 2.5}); + + Op op = instance.updateState(labels, logits, sampleWeight); + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(166.66666, total); + session.evaluate(4.5, count); + session.evaluate(37.037033, result); + } + } + + @Test + public void testLabelSmoothing() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float labelSmoothing = 0.1F; + BinaryCrossentropy instance = + new BinaryCrossentropy<>( + tf, "BCE_testWeightedLabS", true, labelSmoothing, 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[] trueArray = {1, 0, 1}; + double[] logitsArray = {100., -100., -100.}; + Operand labels = tf.constant(trueArray); + Operand logits = tf.constant(logitsArray); + + Op op = instance.updateState(labels, logits, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + + session.evaluate(35, total); + session.evaluate(1, count); + session.evaluate(35, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java new file mode 100644 index 00000000000..2b4a1d75467 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalCrossentropyTest.java @@ -0,0 +1,151 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class CategoricalCrossentropyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>( + tf, "CCE_testUnweighted", false, 0, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {0.05, 0.95, 0, 0.1, 0.8, 0.1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.3538785, total); + session.evaluate(2, count); + session.evaluate(1.1769392, result); + } + } + + @Test + public void testUnweightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>( + tf, "CCE_testUnweightedLogits", true, 0, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(7.0022807, total); + session.evaluate(2, count); + session.evaluate(3.5011404, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>( + tf, "CCE_testWeighted", false, 0, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {0.05f, 0.95f, 0f, 0.1f, 0.8f, 0.1f}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(new double[] {1.5f, 2.}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4.6821095, total); + session.evaluate(3.5, count); + session.evaluate(1.3377455, result); + } + } + + @Test + public void testWeightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>(tf, "CCE_testWeighted", true, 0, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(new double[] {1.5, 2.f}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(14.004333, total); + session.evaluate(3.5, count); + session.evaluate(4.0012328, result); + } + } + + @Test + public void testLabelSmoothing() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + float labelSmoothing = 0.1F; + CategoricalCrossentropy instance = + new CategoricalCrossentropy<>( + tf, "CCE_testWeighted", true, labelSmoothing, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 0, 0, 1}; + double[] predArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(7.3356137, total); + session.evaluate(2, count); + session.evaluate(3.6678069, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java new file mode 100644 index 00000000000..87248d95e48 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CategoricalHingeTest.java @@ -0,0 +1,97 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class CategoricalHingeTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalHinge instance = + new CategoricalHinge<>(tf, "CH_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 5))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2., total); + session.evaluate(4, count); + session.evaluate(0.5, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CategoricalHinge instance = + new CategoricalHinge<>(tf, "CH_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(3.5F, total); + session.evaluate(7, count); + session.evaluate(0.5, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java new file mode 100644 index 00000000000..a9721ef2f8f --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/CosineSimilarityTest.java @@ -0,0 +1,101 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +class CosineSimilarityTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CosineSimilarity instance = + new CosineSimilarity<>(tf, "CS_testUnweighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.3744381F, total); + session.evaluate(2, count); + session.evaluate(0.18721905F, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + CosineSimilarity instance = + new CosineSimilarity<>(tf, "CS_testWeighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new float[] {1.2f, 3.4f}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(-0.3119840621948241F, total); + session.evaluate(4.6, count); + session.evaluate(-0.06782262221626612F, result); + } + } + + @Test + public void test_axis() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + int axis = 1; + CosineSimilarity instance = + new CosineSimilarity<>(tf, "CS_testWeighted", axis, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.3744381F, total); + session.evaluate(2, count); + session.evaluate(0.18721905F, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java new file mode 100644 index 00000000000..6af5fed4889 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/HingeTest.java @@ -0,0 +1,84 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class HingeTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Hinge instance = + new Hinge<>(tf, "Hinge_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {0, 1, 0, 1, 0, 0, 1, 1}; + double[] predArray = {-0.3, 0.2, -0.1, 1.6, -0.25, -1., 0.5, 0.6}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.0125, total); + session.evaluate(2, count); + session.evaluate(.5062500, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Hinge instance = + new Hinge<>(tf, "Hinge_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = { + -1, 1, -1, 1, + -1, -1, 1, 1 + }; + float[] predArray = { + -0.3f, 0.2f, -0.1f, 1.6f, + -0.25f, -1.f, 0.5f, 0.6f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + + Operand sampleWeight = tf.constant(new double[] {1.5, 2.}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.7250f, total); + session.evaluate(3.5, count); + session.evaluate(.49285714f, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java new file mode 100644 index 00000000000..28020c0fa1c --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/KLDivergenceTest.java @@ -0,0 +1,83 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class KLDivergenceTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + KLDivergence instance = + new KLDivergence<>(tf, "KLD_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[][] trueArray = {{.5f, .8f, .12f}, {.7f, .43f, .8f}}; + float[][] predArray = {{.4f, .9f, .12f}, {.36f, .3f, .4f}}; + Operand labels = tf.constant(trueArray); + Operand predictions = tf.constant(predArray); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.1921477, total); + session.evaluate(2, count); + session.evaluate(0.5960738, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + KLDivergence instance = + new KLDivergence<>(tf, "KLD_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[] trueArray = { + .5f, .8f, .12f, + .7f, .43f, .8f + }; + float[] predArray = { + .4f, .9f, .12f, + .36f, .3f, .4f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new double[][] {{1.2}, {3.4}}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4.015142, total); + session.evaluate(4.6, count); + session.evaluate(0.872857, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java new file mode 100644 index 00000000000..31c043e0473 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/LogCoshErrorTest.java @@ -0,0 +1,80 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class LogCoshErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + LogCoshError instance = + new LogCoshError<>(tf, "LogCosh_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + float[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4.829245, result); + session.evaluate(9.65849, total); + session.evaluate(2, count); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + LogCoshError instance = + new LogCoshError<>(tf, "LogCosh_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 9, 2, -5, -2, 6}; + float[] predArray = {4, 8, 12, 8, 1, 3}; + double[][] sampleArray = {{1.2}, {3.4}}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Operand sampleWeight = tf.constant(sampleArray); + + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(5.2178759, result); + session.evaluate(24.002228, total); + session.evaluate(4.6, count); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java new file mode 100644 index 00000000000..73241ecbe9f --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsoluteErrorTest.java @@ -0,0 +1,116 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class MeanAbsoluteErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanAbsoluteError instance = + new MeanAbsoluteError<>(tf, "MAE_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0f, instance.getTotal()); + session.evaluate(0f, instance.getCount()); + session.evaluate(0.f, instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + Op op = instance.updateState(yTrue, yPrediction, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.0, total); + session.evaluate(4, count); + session.evaluate(0.5, result); + + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanAbsoluteError instance = + new MeanAbsoluteError<>(tf, "MAE_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); + Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(3.8, total); + session.evaluate(7, count); + session.evaluate(0.54285, result); + + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java new file mode 100644 index 00000000000..4c92844b217 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageErrorTest.java @@ -0,0 +1,115 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +class MeanAbsolutePercentageErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + session.setEpsilon(1E-6f); + Ops tf = session.getTF(); + MeanAbsolutePercentageError instance = + new MeanAbsolutePercentageError<>(tf, "MAPE_testUnweighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.evaluate(0.0f, instance.getTotal()); + session.evaluate(0f, instance.getCount()); + session.evaluate(0.f, instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Op op = instance.updateState(yTrue, yPrediction, null); + + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.4E9f, total); + session.evaluate(4f, count); + session.evaluate(35e7f, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + session.setEpsilon(1E-6f); + Ops tf = session.getTF(); + MeanAbsolutePercentageError instance = + new MeanAbsolutePercentageError<>(tf, "MAPE_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0, instance.getCount()); + + long[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1.f, 1.5f, 2.f, 2.5f}); + Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + + session.run(op); + + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.800000067278928E9, total); + session.evaluate(7, count); + session.evaluate(4.000000096112754E8, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java new file mode 100644 index 00000000000..0b760213015 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredErrorTest.java @@ -0,0 +1,107 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; + +class MeanSquaredErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanSquaredError instance = + new MeanSquaredError<>(tf, "MSE_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + Op op = instance.updateState(yTrue, yPrediction, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.0, total); + session.evaluate(4, count); + session.evaluate(0.5, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanSquaredError instance = + new MeanSquaredError<>(tf, "MSE_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + + long[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); + Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(3.8, total); + session.evaluate(7, count); + session.evaluate(0.542857, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java new file mode 100644 index 00000000000..098a5cb9725 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicErrorTest.java @@ -0,0 +1,106 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class MeanSquaredLogarithmicErrorTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanSquaredLogarithmicError instance = + new MeanSquaredLogarithmicError<>(tf, "MSLE_testUnweighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + session.evaluate(0.0f, instance.getTotal()); + session.evaluate(0f, instance.getCount()); + session.evaluate(0.f, instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + float[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + Op op = instance.updateState(yTrue, yPrediction, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.96090573f, total); + session.evaluate(4f, count); + session.evaluate(0.24022f, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + MeanSquaredLogarithmicError instance = + new MeanSquaredLogarithmicError<>(tf, "MSLE_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + session.evaluate(0.0, instance.getTotal()); + session.evaluate(0, instance.getCount()); + session.evaluate(0., instance.getCount()); + + int[] trueArray = { + 0, 1, 0, 1, 0, + 0, 0, 1, 1, 1, + 1, 1, 1, 1, 0, + 0, 0, 0, 0, 1 + }; + double[] predictionArray = { + 0, 0, 1, 1, 0, + 1, 1, 1, 1, 1, + 0, 1, 0, 1, 0, + 1, 1, 1, 1, 1 + }; + Operand yTrue = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(4, 5))); + Operand yPrediction = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(4, 5))); + + Operand sampleWeight = tf.constant(new double[] {1., 1.5, 2., 2.5}); + Op op = instance.updateState(yTrue, yPrediction, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.8257208, total); + session.evaluate(7, count); + session.evaluate(0.26082, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java new file mode 100644 index 00000000000..cf3c3e44719 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/PoissonTest.java @@ -0,0 +1,79 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class PoissonTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Poisson instance = + new Poisson<>(tf, "Poisson_testUnweighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {4, 8, 12, 8, 1, 3}; + float[] predArray = {1, 9, 2, 5, 2, 6}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(-6.6131644, total); + session.evaluate(2, count); + session.evaluate(-3.3065822, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Poisson instance = + new Poisson<>(tf, "Poisson_testWeighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {4, 8, 12, 8, 1, 3}; + float[] predArray = {1, 9, 2, 5, 2, 6}; + + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 3))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new float[] {1.2f, 3.4f}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(-12.29468f, total); + session.evaluate(4.6f, count); + session.evaluate(-2.6727562f, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java new file mode 100644 index 00000000000..87af1bd8448 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropyTest.java @@ -0,0 +1,129 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class SparseCategoricalCrossentropyTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy<>( + tf, "SCE_testUnweighted", false, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2}; + double[] predictionArray = {0.05, 0.95, 0, 0.1, 0.8, 0.1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); + Operand predictions = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(2.3538785, total); + session.evaluate(2, count); + session.evaluate(1.1769392, result); + } + } + + @Test + public void testUnweightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy<>( + tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2}; + double[] logitsArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); + Operand logits = tf.reshape(tf.constant(logitsArray), tf.constant(Shape.of(2, 3))); + Op op = instance.updateState(labels, logits, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(7.002277, total); + session.evaluate(2, count); + session.evaluate(3.501135, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy<>( + tf, "SCE_testWeighted", false, -1, 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2}; + double[] predictionArray = {0.05, 0.95, 0, 0.1, 0.8, 0.1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); + Operand predictions = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new float[] {1.5F, 2.F}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(4.6821103f, total); + session.evaluate(3.5f, count); + session.evaluate(1.3377458f, result); + } + } + + @Test + public void testWeightedLogits() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SparseCategoricalCrossentropy instance = + new SparseCategoricalCrossentropy<>( + tf, "SCE_testWeighted", true, -1, 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = {1, 2}; + double[] predictionArray = {1, 9, 0, 1, 8, 1}; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2))); + Operand predictions = + tf.reshape(tf.constant(predictionArray), tf.constant(Shape.of(2, 3))); + + Operand sampleWeight = tf.constant(new double[] {1.5, 2}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(14.004333, total); + session.evaluate(3.5, count); + session.evaluate(4.001232, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java new file mode 100644 index 00000000000..e3376c224f3 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/SquaredHingeTest.java @@ -0,0 +1,90 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.op.core.Variable; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; + +class SquaredHingeTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + @Test + public void testUnweighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SquaredHinge instance = + new SquaredHinge<>(tf, "SCE_testUnweighted", 1001L, TFloat32.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 1, 0, 1, + 0, 0, 1, 1 + }; + float[] predArray = { + -0.3f, 0.2f, -0.1f, 1.6f, + -0.25f, -1.f, 0.5f, 0.6f + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + Op op = instance.updateState(labels, predictions, null); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(0.72812f, total); + session.evaluate(2f, count); + session.evaluate(0.3640625f, result); + } + } + + @Test + public void testWeighted() { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + SquaredHinge instance = + new SquaredHinge<>(tf, "SCE_testWeighted", 1001L, TFloat64.class); + session.run(instance.resetStates()); + int[] trueArray = { + 0, 1, 0, 1, + 0, 0, 1, 1 + }; + double[] predArray = { + -0.3, 0.2, -0.1, 1.6, + -0.25, -1., 0.5, 0.6 + }; + Operand labels = tf.reshape(tf.constant(trueArray), tf.constant(Shape.of(2, 4))); + Operand predictions = + tf.reshape(tf.constant(predArray), tf.constant(Shape.of(2, 4))); + + Operand sampleWeight = tf.constant(new double[] {1.5f, 2.f}); + Op op = instance.updateState(labels, predictions, sampleWeight); + session.run(op); + Variable total = instance.getTotal(); + Variable count = instance.getCount(); + Operand result = instance.result(); + session.evaluate(1.2137499, total); + session.evaluate(3.5, count); + session.evaluate(0.3467857, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java new file mode 100644 index 00000000000..63d666f8640 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/AssertBroadcastableTest.java @@ -0,0 +1,292 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Op; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class AssertBroadcastableTest { + + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + int[][][] valueArrayI = + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + long[][][] valueArrayL = + new long[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + float[][][] valueArrayF = + new float[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + double[][][] valueArrayD = + new double[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + + private void testValid( + TestSession testSession, Ops tf, Operand weights, Operand values, Class type) { + + Op staticOp = MetricsHelper.assertBroadcastable(tf, weights, values); + + // dynamic test + Operand weightsPlaceholder = tf.placeholder(type); + Operand valuesPlaceholder = tf.placeholder(type); + + List tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); + try (Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1)) { + Op dynamicOp = MetricsHelper.assertBroadcastable(tf, weightsPlaceholder, valuesPlaceholder); + + testSession + .getGraphSession() + .runner() + .feed(weightsPlaceholder, weightsTensor) + .feed(valuesPlaceholder, valuesTensor) + .addTarget(dynamicOp) + .run(); + } + } + + @Test + public void testValidScalar() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + Operand values = tf.constant(valueArrayF); + Operand weights = tf.constant(5f); + testValid(testSession, tf, weights, values, TFloat32.class); + } + } + + @Test + public void test1x1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + Operand values = tf.constant(valueArrayD); + Operand weights = tf.constant(new double[][][] {{{5}}}); + testValid(testSession, tf, weights, values, TFloat64.class); + } + } + + @Test + public void test1x1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayL); + Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); + testValid(testSession, tf, weights, values, TInt64.class); + } + } + + @Test + public void test1xNx1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void test1xNxN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testNx1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testNx1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + Operand values = tf.constant(valueArrayI); + Operand weights = + tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + @Test + public void testNxNxN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + + Operand values = tf.constant(valueArrayI); + Operand weights = + tf.constant( + new int[][][] { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + } + + // Note: For invalid tests, either NotBroadcastableException is thrown for static shapes or + // TFInvalidInvalidException is thrown for dynamic shapes. Both of these extend + // IllegalArgumentException, + // To simply the assertThrows, only IllegalArgumentException is expected. + // The private method, testValid, tests for both static and dynamic shapes. + @Test + public void testInvalid1x1() { + + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][] {{5}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidOnesExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][][] {{{{5}}}}); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + + Operand weights = + tf.constant( + new int[][][][] { + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, + {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = + tf.constant( + new int[][][][] { + { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + } + }); + testValid(testSession, tf, weights, values, TInt32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java new file mode 100644 index 00000000000..3322a81fe5b --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/BroadcastWeightsTest.java @@ -0,0 +1,380 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.metrics.impl; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.Tensor; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; + +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class BroadcastWeightsTest { + private final TestSession.Mode tfMode = TestSession.Mode.GRAPH; + + int[][][] valueArrayI = + new int[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + long[][][] valueArrayL = + new long[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + float[][][] valueArrayF = + new float[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + double[][][] valueArrayD = + new double[][][] { + {{1, 2, 3, 4}, {5, 6, 7, 8}}, + {{9, 10, 11, 12}, {13, 14, 15, 16}}, + {{17, 18, 19, 20}, {21, 22, 23, 24}} + }; + + private void testValid( + TestSession testSession, + Ops tf, + Operand weights, + Operand values, + Number[] expected, // flattened array + Class type) { + + Operand staticOp = MetricsHelper.broadcastWeights(tf, weights, values); + if (expected != null) { + testSession.evaluate(expected, staticOp); + } else { + testSession.run(staticOp); + } + + // dynamic test + Operand weightsPlaceholder = tf.placeholder(type); + Operand valuesPlaceholder = tf.placeholder(type); + + List tensors = + testSession.getGraphSession().runner().fetch(weights).fetch(values).run(); + try (Tensor weightsTensor = tensors.get(0); + Tensor valuesTensor = tensors.get(1)) { + + Operand dynamicOp = + MetricsHelper.broadcastWeights(tf, weightsPlaceholder, valuesPlaceholder); + + List result = + testSession + .getGraphSession() + .runner() + .feed(weightsPlaceholder, weightsTensor) + .feed(valuesPlaceholder, valuesTensor) + .fetch(dynamicOp) + .run(); + + if (expected != null) { + if (type.equals(TInt32.class)) { + TInt32 intT = (TInt32) result.get(0); + AtomicInteger i = new AtomicInteger(); + intT.scalars() + .forEachIndexed( + (idx, f) -> assertEquals(expected[i.getAndIncrement()].intValue(), f.getInt())); + } else if (type.equals(TInt64.class)) { + TInt64 floatT = (TInt64) result.get(0); + AtomicInteger i = new AtomicInteger(); + floatT + .scalars() + .forEachIndexed( + (idx, f) -> assertEquals(expected[i.getAndIncrement()].longValue(), f.getLong())); + } else if (type.equals(TFloat32.class)) { + TFloat32 floatT = (TFloat32) result.get(0); + AtomicInteger i = new AtomicInteger(); + floatT + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].floatValue(), f.getFloat(), 1e-5F)); + } else if (type.equals(TFloat64.class)) { + TFloat64 doubleT = (TFloat64) result.get(0); + AtomicInteger i = new AtomicInteger(); + doubleT + .scalars() + .forEachIndexed( + (idx, f) -> + assertEquals( + expected[i.getAndIncrement()].doubleValue(), f.getDouble(), 1e-5F)); + } + } + } + } + + @Test + public void testValidScalar() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayF); + Operand weights = tf.constant(5f); + Float[] expected = { + 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, 5f, + 5f + }; + testValid(testSession, tf, weights, values, expected, TFloat32.class); + } + } + + @Test + public void test1x1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayD); + Operand weights = tf.constant(new double[][][] {{{5}}}); + Double[] expected = { + 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., 5., + 5. + }; + + testValid(testSession, tf, weights, values, expected, TFloat64.class); + } + } + + @Test + public void test1x1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayL); + Operand weights = tf.constant(new long[][][] {{{5, 7, 11, 3}}}); + Long[] expected = { + 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, 11L, 3L, 5L, 7L, + 11L, 3L, + }; + testValid(testSession, tf, weights, values, expected, TInt64.class); + } + } + + @Test + public void test1xNx1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][] {{{5}, {11}}}); + Integer[] expected = { + 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11, 5, 5, 5, 5, 11, 11, 11, 11 + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + @Test + public void test1xNxN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][] {{{5, 7, 11, 3}, {2, 13, 7, 5}}}); + Integer[] expected = { + 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, 5, 7, 11, 3, 2, 13, 7, 5, + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + @Test + public void testNx1x1() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][] {{{5}}, {{7}}, {{11}}}); + Integer[] expected = { + 5, 5, 5, 5, 5, 5, 5, 5, 7, 7, 7, 7, 7, 7, 7, 7, 11, 11, 11, 11, 11, 11, 11, 11 + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + @Test + public void testNx1xN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = + tf.constant(new int[][][] {{{5, 7, 11, 3}}, {{2, 12, 7, 5}}, {{2, 17, 11, 3}}}); + Integer[] expected = { + 5, 7, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3 + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + @Test + public void testNxNxN() { + // no exception should be thrown + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + + Operand weights = + tf.constant( + new int[][][] { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + }); + Integer[] expected = { + 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5 + }; + testValid(testSession, tf, weights, values, expected, TInt32.class); + } + } + + // Note: For invalid tests, either NotBroadcastableException is thrown for static shapes or + // TFInvalidInvalidException is thrown for dynamic shapes. Both of these extend + // IllegalArgumentException, + // To simply the assertThrows, only IllegalArgumentException is expected. + // The private method, testValid, tests for both static and dynamic shapes. + @Test + public void testInvalid1() { + + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[] {5}); + + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalid1x1() { + + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][] {{5}}); + + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][] {{5, 7}, {11, 3}, {2, 12}}); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatch() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][] {{5, 7, 11, 3}, {2, 12, 7, 5}}); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidOnesExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = tf.constant(new int[][][][] {{{{5}}}}); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidPrefixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + + Operand weights = + tf.constant( + new int[][][][] { + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}}, + {{{2}, {17}, {11}, {3}}, {{2}, {17}, {11}, {3}}}, + {{{5}, {7}, {11}, {3}}, {{2}, {12}, {7}, {5}}} + }); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } + + @Test + public void testInvalidSuffixMatchExtraDim() { + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession testSession = TestSession.createTestSession(tfMode)) { + Ops tf = testSession.getTF(); + Operand values = tf.constant(valueArrayI); + Operand weights = + tf.constant( + new int[][][][] { + { + {{5, 7, 11, 3}, {2, 12, 7, 5}}, + {{2, 17, 11, 3}, {2, 17, 11, 3}}, + {{5, 7, 11, 3}, {2, 12, 7, 5}} + } + }); + testValid(testSession, tf, weights, values, null, TInt32.class); + } + }); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java new file mode 100644 index 00000000000..eceff2797f8 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/metrics/impl/SetsOpsTest.java @@ -0,0 +1,120 @@ +package org.tensorflow.framework.metrics.impl; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.TInt64; +import org.tensorflow.types.TUint8; +import org.tensorflow.types.family.TType; + +import java.util.Arrays; +import java.util.List; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +class SetsOpsTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + List> types = Arrays.asList(TInt32.class, TInt64.class, TUint8.class); + + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) + public void testSetIntersectionMultirow2() { + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); + Operand b = tf.constant(new int[][] {{1, 9}, {1, 5}}); + int[][] expected = new int[][] {{1, 9}, {0, 0}}; + Shape expectedShape = Shape.of(2, 2); + for (Class type : types) { + Operand aa = cast(tf, a, type); + Operand bb = cast(tf, b, type); + Operand intersection = SetsOps.intersection(tf, aa, bb); + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + } + } + } + + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) + public void testSetIntersectionDuplicates2d() { + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand a = tf.constant(new int[][] {{1, 1, 3}}); + Operand b = tf.constant(new int[][] {{1, 1}}); + int[][] expected = {{1}}; + Shape expectedShape = Shape.of(1, 1); + for (Class type : types) { + Operand aa = cast(tf, a, type); + Operand bb = cast(tf, b, type); + Operand intersection = SetsOps.intersection(tf, aa, bb); + + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + } + } + } + + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) + public void testDenseSetDifferenceMultirow2d() { + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand a = tf.constant(new int[][] {{1, 5, 9}, {4, 5, 3}}); + Operand b = tf.constant(new int[][] {{1, 2, 6}, {1, 2, 2}}); + + for (Class type : types) { + Operand aa = cast(tf, a, type); + Operand bb = cast(tf, b, type); + int[][] expected = {{5, 9, 0}, {3, 4, 5}}; + // a- b + Shape expectedShape = Shape.of(2, 3); + Operand intersection = SetsOps.difference(tf, aa, bb); + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + + // b - a + expected = new int[][] {{2, 6}, {1, 2}}; + expectedShape = Shape.of(2, 2); + intersection = SetsOps.difference(tf, aa, bb, false); + + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + } + } + } + + @Test + @SuppressWarnings({"unchecked", "rawtypes"}) + public void testDenseUnionMultirow2d() { + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Operand a = tf.constant(new int[][] {{9, 1, 5}, {2, 4, 3}}); + Operand b = tf.constant(new int[][] {{1, 9}, {1, 2}}); + int[][] expected = new int[][] {{5, 0}, {3, 4}}; + for (Class type : types) { + Operand aa = cast(tf, a, type); + Operand bb = cast(tf, b, type); + Shape expectedShape = Shape.of(2, 2); + // a- b + Operand intersection = SetsOps.difference(tf, aa, bb); + session.evaluate(cast(tf, tf.constant(expected), type), intersection); + session.evaluate(tf.constant(expectedShape), tf.shape(intersection, TInt64.class)); + } + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java index 33c4e064e69..43c0642939e 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/GraphTestSession.java @@ -213,10 +213,13 @@ public void evaluate(double expected, Operand input) { @Override public void evaluate(Number[] expected, Output input) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + if (size != Shape.UNKNOWN_SIZE) { + assertEquals( + expected.length, + size, + () -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); + } Class inputType = input.type(); if (inputType == TFloat32.class) { AtomicInteger index = new AtomicInteger(); @@ -425,10 +428,13 @@ public void evaluate(FloatNdArray expected, Output input) { @Override public void evaluate(String[] expected, Output input) { int size = input.shape().size() == 0 ? 1 : (int) input.shape().size(); - assertEquals( - expected.length, - size, - () -> String.format("expected length (%d) != to input length (%d)", expected.length, size)); + if (size != Shape.UNKNOWN_SIZE) { + assertEquals( + expected.length, + size, + () -> + String.format("expected length (%d) != to input length (%d)", expected.length, size)); + } AtomicInteger index = new AtomicInteger(); if (debug) { try (TString result = @@ -1025,7 +1031,7 @@ public void print(PrintWriter writer, Output input) { (TFloat64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %f\n", index.getAndIncrement(), ((Output) input).asTensor().getDouble()); + "%d). %f\n", index.getAndIncrement(), result.getDouble()); } else { result .scalars() @@ -1040,7 +1046,7 @@ public void print(PrintWriter writer, Output input) { (TInt32)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output) input).asTensor().getInt()); + "%d). %d\n", index.getAndIncrement(),result.getInt()); } else { result .scalars() @@ -1055,7 +1061,7 @@ public void print(PrintWriter writer, Output input) { (TInt64)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %d\n", index.getAndIncrement(), ((Output) input).asTensor().getLong()); + "%d). %d\n", index.getAndIncrement(), result.getLong()); } else { result .scalars() @@ -1070,7 +1076,7 @@ public void print(PrintWriter writer, Output input) { (TUint8)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %x\n", index.getAndIncrement(), ((Output) input).asTensor().getByte()); + "%d). %x\n", index.getAndIncrement(), result.getByte()); } else { result .scalars() @@ -1085,7 +1091,7 @@ public void print(PrintWriter writer, Output input) { (TBool)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %b\n", index.getAndIncrement(), ((Output) input).asTensor().getBoolean()); + "%d). %b\n", index.getAndIncrement(), result.getBoolean()); } else { result .scalars() @@ -1100,7 +1106,7 @@ public void print(PrintWriter writer, Output input) { (TString)this.getGraphSession().runner().fetch(input).run().get(0)) { if (isScalar) { writer.printf( - "%d). %s\n", index.getAndIncrement(), ((Output) input).asTensor().getObject()); + "%d). %s\n", index.getAndIncrement(), result.getObject()); } else { result .scalars() diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java index 3fccd0f0506..db39a330522 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/TestSession.java @@ -492,6 +492,16 @@ public void evaluate(FloatNdArray input, Predicate predicate) { input.scalars().forEach(f -> assertTrue(predicate.test(f.getFloat()))); } + /** + * Print the input to standard out + * + + * @param input the operand to print + * @param the data type of the input + */ + public void print(Operand input) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.asOutput()); + } /** * Print the input * @@ -503,6 +513,15 @@ public void print(OutputStream out, Operand input) { print(new PrintWriter(new OutputStreamWriter(out)), input.asOutput()); } + /** + * Print the input to standard out + * + * @param input the op to print + */ + public void print(Op input) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input.op().output(0)); + } + /** * Print the input * @@ -513,6 +532,16 @@ public void print(OutputStream out, Op input) { print(new PrintWriter(new OutputStreamWriter(out)), input.op().output(0)); } + /** + * Print the input to standard out + * + * @param input the op to print + * @param the data type of the input + */ + public void print(Output input) { + print(new PrintWriter(new OutputStreamWriter(System.out)), input); + } + /** * Print the input *