categoricalCross
tLabels = smoothCategoricalLabels(tf, tLabels, labelSmoothing);
}
if (fromLogits) {
- return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, -1);
+ return tf.nn.softmaxCrossEntropyWithLogits(tLabels, predictions, axis);
}
/* TODO
if (!(predictions instanceof Variable) && (!tf.scope().env().isEager())) {
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java
new file mode 100644
index 00000000000..651a6fac0b0
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/BinaryCrossentropy.java
@@ -0,0 +1,66 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A Metric that computes the binary cross-entropy loss between true labels and predicted labels.
+ *
+ * This is the crossentropy metric class to be used when there are only two label classes (0 and
+ * 1).
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result
+ */
+public class BinaryCrossentropy
+ extends MeanMetricWrapper implements LossMetric {
+
+ private final boolean fromLogits;
+ private final float labelSmoothing;
+
+ /**
+ * Creates a BinaryCrossentropy metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution.
+ * @param labelSmoothing value used to smooth labels, When 0, no smoothing occurs. When > 0,
+ * compute the loss between the predicted labels and a smoothed version of the true labels,
+ * where the smoothing squeezes the labels towards 0.5. Larger values of label_smoothing
+ * correspond to heavier smoothing.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public BinaryCrossentropy(
+ Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ this.fromLogits = fromLogits;
+ this.labelSmoothing = labelSmoothing;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.binaryCrossentropy(getTF(), labels, predictions, fromLogits, labelSmoothing);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java
new file mode 100644
index 00000000000..c330ea88eaa
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalCrossentropy.java
@@ -0,0 +1,105 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A Metric that computes the categorical cross-entropy loss between true labels and predicted
+ * labels.
+ *
+ * This is the crossentropy metric class to be used when there are multiple label classes (2 or
+ * more). The labels should be given as a one_hot representation. eg., When labels values are
+ * [2, 0, 1]
, the labels Operand contains = [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
+ *
.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result
+ */
+public class CategoricalCrossentropy
+ extends MeanMetricWrapper implements LossMetric {
+
+ private final boolean fromLogits;
+ private final float labelSmoothing;
+ private final int axis;
+
+ /**
+ * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the
+ * labels and predictions.
+ *
+ * Uses a {@link Losses#CHANNELS_LAST} for the channel axis.
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param fromLogits Whether to interpret predictions as a tensor of logit values oras opposed to a probability distribution.
+ * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed,
+ * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2
+ * means that we will use a value of 0.1
for label 0
and 0.9
+ *
for label 1
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public CategoricalCrossentropy(
+ Ops tf, String name, boolean fromLogits, float labelSmoothing, long seed, Class type) {
+ this(tf, name, fromLogits, labelSmoothing, Losses.CHANNELS_LAST, seed, type);
+ }
+
+ /**
+ * Creates a CategoricalCrossentropy metric that computes the crossentropy metric between the
+ * labels and predictions.
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution.
+ * @param labelSmoothing value used to smooth labels, When > 0, label values are smoothed,
+ * meaning the confidence on label values are relaxed. e.g. labelSmoothing=0.2
+ * means that we will use a value of 0.1
for label 0
and 0.9
+ *
for label 1
+ * @param axis Int specifying the channels axis. axis={@link Losses#CHANNELS_LAST}
+ * corresponds to data format channels_last
, and
+ * axis={@link Losses#CHANNELS_FIRST}
corresponds to data format
+ * channels_first
.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public CategoricalCrossentropy(
+ Ops tf,
+ String name,
+ boolean fromLogits,
+ float labelSmoothing,
+ int axis,
+ long seed,
+ Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ this.fromLogits = fromLogits;
+ this.labelSmoothing = labelSmoothing;
+ this.axis = axis;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.categoricalCrossentropy(
+ getTF(), labels, predictions, fromLogits, labelSmoothing, axis);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java
new file mode 100644
index 00000000000..2741a36edb6
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CategoricalHinge.java
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A Metric that computes the categorical hinge loss metric between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result
+ */
+public class CategoricalHinge extends MeanMetricWrapper
+ implements LossMetric {
+
+ /**
+ * Creates a CategoricalHinge metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public CategoricalHinge(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.categoricalHinge(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java
new file mode 100644
index 00000000000..458de092bec
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/CosineSimilarity.java
@@ -0,0 +1,83 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the cosine similarity metric between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class CosineSimilarity extends MeanMetricWrapper
+ implements LossMetric {
+ public static final int DEFAULT_AXIS = -1;
+ private final int[] axis;
+
+ /**
+ * Creates a metric that computes the cosine similarity metric between labels and predictions with
+ * a default axis, {@link #DEFAULT_AXIS}
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public CosineSimilarity(Ops tf, String name, long seed, Class type) {
+ this(tf, name, DEFAULT_AXIS, seed, type);
+ }
+
+ /**
+ * Creates a metric that computes the cosine similarity metric between labels and predictions.
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param axis The dimension along which the cosine similarity is computed.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public CosineSimilarity(Ops tf, String name, int axis, long seed, Class type) {
+ this(tf, name, new int[] {axis}, seed, type);
+ }
+ /**
+ * Creates a CosineSimilarity metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param axis The dimension along which the cosine similarity is computed.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public CosineSimilarity(Ops tf, String name, int[] axis, long seed, Class type) {
+ super(tf, name, seed, type);
+ this.axis = axis;
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ // NOTE: cosineProximity is a different algorithm than Losses.cosineSimilarity
+ return Metrics.cosineProximity(getTF(), labels, predictions, axis);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java
new file mode 100644
index 00000000000..baf9ad8ab7d
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Hinge.java
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the hinge loss metric between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class Hinge extends MeanMetricWrapper
+ implements LossMetric {
+
+ /**
+ * Creates a Hinge metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public Hinge(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.hinge(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java
new file mode 100644
index 00000000000..efcbbcbb7f0
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/KLDivergence.java
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the Kullback-Leibler divergence loss metric between labels and
+ * predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class KLDivergence extends MeanMetricWrapper
+ implements LossMetric {
+
+ /**
+ * Creates a KLDivergence metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public KLDivergence(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.kullbackLeiblerDivergence(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java
new file mode 100644
index 00000000000..3df8505d54b
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/LogCoshError.java
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the logarithm of the hyperbolic cosine of the prediction error metric
+ * between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class LogCoshError extends MeanMetricWrapper
+ implements LossMetric {
+
+ /**
+ * Creates a LogCoshError metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public LogCoshError(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.logCosh(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java
new file mode 100644
index 00000000000..de1f5a5629e
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Mean.java
@@ -0,0 +1,41 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.framework.metrics.impl.Reduce;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that that implements a weighted mean {@link MetricReduction#WEIGHTED_MEAN }
+ *
+ * @param The data type for the metric values
+ * @param The data type for the metric result
+ */
+public class Mean extends Reduce {
+
+ /**
+ * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#SUM}
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ protected Mean(Ops tf, String name, long seed, Class type) {
+ super(tf, name, MetricReduction.WEIGHTED_MEAN, seed, type);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java
new file mode 100644
index 00000000000..e27676932ff
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsoluteError.java
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the mean of absolute difference between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class MeanAbsoluteError extends MeanMetricWrapper
+ implements LossMetric {
+
+ /**
+ * Creates a Mean Absolute Error metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public MeanAbsoluteError(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.meanAbsoluteError(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java
new file mode 100644
index 00000000000..84fa9b627b2
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanAbsolutePercentageError.java
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the mean of absolute difference between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class MeanAbsolutePercentageError
+ extends MeanMetricWrapper implements LossMetric {
+
+ /**
+ * Creates a Mean Absolute Error metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public MeanAbsolutePercentageError(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.meanAbsolutePercentageError(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java
new file mode 100644
index 00000000000..c7edd6ebe93
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredError.java
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the mean of absolute difference between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class MeanSquaredError extends MeanMetricWrapper
+ implements LossMetric {
+
+ /**
+ * Creates a Mean Absolute Error metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public MeanSquaredError(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.meanSquaredError(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java
new file mode 100644
index 00000000000..199b6e0e114
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MeanSquaredLogarithmicError.java
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the mean of absolute difference between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class MeanSquaredLogarithmicError
+ extends MeanMetricWrapper implements LossMetric {
+
+ /**
+ * Creates a Mean Absolute Error metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public MeanSquaredLogarithmicError(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.meanSquaredLogarithmicError(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java
new file mode 100644
index 00000000000..bbb2aa73da2
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metric.java
@@ -0,0 +1,193 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Base class for Metrics
+ *
+ * @param The data type for the metric values
+ * @param The data type for the metric result
+ */
+public abstract class Metric {
+
+ /** The TensorFlow Ops */
+ private final Ops tf;
+
+ /** The seed for random number generation */
+ private final long seed;
+
+ /** The name for this metric. Defaults to {@link Class#getSimpleName()}. */
+ private final String name;
+
+ /**
+ * Creates a Metric with a name of {@link Class#getSimpleName()}
+ *
+ * @param tf the TensorFlow Ops
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ */
+ protected Metric(Ops tf, long seed) {
+ this(tf, null, seed);
+ }
+
+ /**
+ * Creates a Metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ */
+ protected Metric(Ops tf, String name, long seed) {
+ if (!tf.scope().env().isGraph()) {
+ throw new IllegalArgumentException("Metrics are required to execute in Graph mode.");
+ }
+ this.seed = seed;
+ this.name = name != null ? name : this.getClass().getSimpleName();
+ this.tf = tf.withName(this.getClass().getSimpleName());
+ }
+
+ /**
+ * Creates a List of Operations to update the metric state based on input values.
+ *
+ * This is an empty implementation that should be overridden in a subclass, if needed.
+ *
+ * @param values the inputs to be passed to update state, this may not be null
+ * @param sampleWeights sample weights to be applied to values, may be null.
+ * @return a List of Operations to update the metric state
+ * @param the data type for sampleWeights
+ */
+ @SuppressWarnings({"unchecked", "unused"})
+ public List updateStateList(Operand values, Operand sampleWeights) {
+ return Collections.EMPTY_LIST;
+ }
+
+ /**
+ * Creates a List of Operations to update the metric state based on labels and predictions.
+ *
+ * This is an empty implementation that should be overridden in a sub class, if needed.
+ *
+ * @param labels the labels
+ * @param predictions the predictions
+ * @param sampleWeights sample weights to be applied to values, may be null.
+ * @param the data type for the labels
+ * @param the data type for the sampleWeights
+ * @return a List of Operations to update the metric state
+ */
+ @SuppressWarnings({"unchecked", "unused"})
+ public List updateStateList(
+ Operand labels, Operand predictions, Operand sampleWeights) {
+ return Collections.EMPTY_LIST;
+ }
+
+ /**
+ * Creates a NoOp Operation with control dependencies to update the metric state
+ *
+ * @param values the inputs to be passed to update state, this may not be null
+ * @param sampleWeights sample weights to be applied to values, may be null.
+ * @param the data type for sampleWeights
+ * @return the Operation to update the metric state
+ */
+ public final Op updateState(Operand values, Operand sampleWeights) {
+ List controlOps = updateStateList(values, sampleWeights);
+ return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp();
+ }
+
+ /**
+ * Creates a NoOp Operation with control dependencies to update the metric state
+ *
+ * @param labels the labels
+ * @param predictions the predictions
+ * @param sampleWeights sample weights to be applied to values, may be null.
+ * @param the data type for the labels
+ * @param the data type for the sampleWeights
+ * @return the Operation to update the metric state
+ */
+ public final Op updateState(
+ Operand labels, Operand predictions, Operand sampleWeights) {
+ List controlOps = updateStateList(labels, predictions, sampleWeights);
+ return tf.withSubScope("updateState").withControlDependencies(controlOps).noOp();
+ }
+
+ /**
+ * Gets the current result of the metric
+ *
+ * @return the result, possibly with control dependencies
+ */
+ public abstract Operand result();
+
+ /**
+ * Resets any state variables to their initial values
+ *
+ * @return the control operation for doing the reset
+ */
+ public abstract Op resetStates();
+
+ /**
+ * Calls update state once, followed by a call to get the result
+ *
+ * @param values the inputs to be passed to update state, this may not be null
+ * @param sampleWeights sample weights to be applied to values, may be null.
+ * @return the result, possibly with control dependencies
+ * @param the data type for the sampleWeights.
+ */
+ public final Operand callOnce(
+ Operand values, Operand sampleWeights) {
+ List controlOps = updateStateList(values, sampleWeights);
+ Ops ltf = tf.withSubScope("callOnce").withControlDependencies(controlOps);
+ return ltf.identity(result());
+ }
+
+ /**
+ * Gets a formatted name for a variable, in the form {@link #name} + "_" + varName.
+ *
+ * @param varName the base name for the variable
+ * @return the formatted variable name
+ */
+ protected String getVariableName(String varName) {
+ return String.format("%s_%s", this.name, varName);
+ }
+
+ /**
+ * Gets the TensorFlow Ops
+ *
+ * @return the TensorFlow Ops
+ */
+ public Ops getTF() {
+ return tf;
+ }
+
+ /**
+ * Gets the name of this metric.
+ *
+ * @return the name of this metric
+ */
+ public String getName() {
+ return name;
+ }
+
+ /** The random number generator seed value */
+ public long getSeed() {
+ return seed;
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java
new file mode 100644
index 00000000000..d837ff626b3
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/MetricReduction.java
@@ -0,0 +1,26 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+/** Defines the different types of metric reductions */
+public enum MetricReduction {
+
+ /** Scalar sum of weighted values. */
+ SUM,
+ /** Scalar sum of weighted values divided by number of elements. */
+ SUM_OVER_BATCH_SIZE,
+ /** Scalar sum of weighted values divided by sum of weights. */
+ WEIGHTED_MEAN
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java
new file mode 100644
index 00000000000..0169bc6b8bc
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Metrics.java
@@ -0,0 +1,134 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.CastHelper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.core.ReduceSum;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.family.TNumber;
+
+/** Helper class with built-in metrics functions. */
+public class Metrics {
+
+ public static final float L2_NORM_EPSILON = 1e-12f;
+
+ /**
+ * Computes how often targets are in the top K predictions.
+ *
+ * Standalone usage:
+ *
+ *
+ * Operand<TInt32> labels = tf.constant(new int[][]
+ * {{0, 0, 1}, {0, 1, 0}});
+ * Operand<TFloat32> predictions = tf.constant(new float[][]
+ * {{0.1f, 0.9f, 0.8f}, {0.05f, 0.95f, 0f}});
+ * Operand<TFloat32> m = Metrics.topKCategoricalAccuracy(
+ * labels, predictions, 3)
+ * //m.shape().toString == "[2]"
+ *
+ *
+ * @param tf the TensorFlow Ops.
+ * @param labels the ground truth values.
+ * @param predictions The prediction values.
+ * @param k Number of top elements to look at for computing accuracy.
+ * @param the data type for the predictions and results
+ * @param the data type ofr the labels.
+ * @return the Operand for the Top K categorical accuracy value.
+ */
+ public static Operand topKCategoricalAccuracy(
+ Ops tf, Operand labels, Operand predictions, long k) {
+ Operand fPredictions = CastHelper.cast(tf, predictions, TFloat32.class);
+ return CastHelper.cast(
+ tf,
+ tf.nn.inTopK(fPredictions, tf.math.argMax(labels, tf.constant(-1)), tf.constant(k)),
+ predictions.type());
+ }
+
+ /**
+ * Computes the cosine similarity between labels and predictions.
+ *
+ * @param tf the TensorFlow Ops
+ * @param labels The ground truth values.
+ * @param predictions The prediction values.
+ * @param axes The dimensions along which the cosine similarity is computed.
+ * @param the data type for the labels
+ * @param the data type for the predictions and result
+ * @return Cosine similarity value.
+ */
+ public static Operand cosineProximity(
+ Ops tf, Operand labels, Operand predictions, int[] axes) {
+ Operand labelsNorm = CastHelper.cast(tf, labels, predictions.type());
+ labelsNorm = l2Normalize(tf, labelsNorm, axes);
+
+ Operand predictionsNorm = l2Normalize(tf, predictions, axes);
+ Operand mathMul = tf.math.mul(labelsNorm, predictionsNorm);
+ return tf.reduceSum(mathMul, tf.constant(axes), ReduceSum.keepDims(Boolean.FALSE));
+ }
+
+ /**
+ * Normalizes along dimension axis
using an L2 norm with an epsilon of {@link
+ * #L2_NORM_EPSILON}.
+ *
+ * For a 1-D tensor with axis = 0
, computes
+ *
+ *
+ * output = x / sqrt(max(sum(x**2), epsilon))
+ *
+ *
+ * For x
with more dimensions, independently normalizes each 1-D slice along
+ * dimension axis
.
+ *
+ * @param tf The TensorFlow ops
+ * @param x The operand to normalize
+ * @param axes Dimension(s) along which to normalize.
+ * @param The data type for x.
+ * @return the normalized values of x.
+ */
+ public static Operand l2Normalize(Ops tf, Operand x, int[] axes) {
+ return l2Normalize(tf, x, axes, L2_NORM_EPSILON);
+ }
+
+ /**
+ * Normalizes along dimension axis
using an L2 norm.
+ *
+ *
For a 1-D tensor with axis = 0
, computes
+ *
+ *
+ * output = x / sqrt(max(sum(x**2), epsilon))
+ *
+ *
+ * For x
with more dimensions, independently normalizes each 1-D slice along
+ * dimension axis
.
+ *
+ * @param tf The TensorFlow ops
+ * @param x The operand to normalize
+ * @param axes Dimension(s) along which to normalize.
+ * @param epsilon A lower bound value for the norm. Will use sqrt(epsilon)
as the
+ * divisor if norm < sqrt(epsilon)
.
+ * @param The data type for the values.
+ * @return the normalized values of x.
+ */
+ public static Operand l2Normalize(
+ Ops tf, Operand x, int[] axes, float epsilon) {
+ Operand squareSum =
+ tf.reduceSum(tf.math.square(x), tf.constant(axes), ReduceSum.keepDims(Boolean.TRUE));
+ Operand y =
+ tf.math.rsqrt(
+ tf.math.maximum(squareSum, CastHelper.cast(tf, tf.constant(epsilon), x.type())));
+ return tf.math.mul(x, y);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java
new file mode 100644
index 00000000000..75a2031fbb5
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/Poisson.java
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the poisson loss metric between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class Poisson extends MeanMetricWrapper
+ implements LossMetric {
+
+ /**
+ * Creates a Poisson metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public Poisson(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.poisson(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java
new file mode 100644
index 00000000000..2e01f722de6
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SparseCategoricalCrossentropy.java
@@ -0,0 +1,61 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the sparse categorical cross-entropy loss between true labels and
+ * predicted labels.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class SparseCategoricalCrossentropy
+ extends MeanMetricWrapper implements LossMetric {
+
+ private final boolean fromLogits;
+ private final int axis;
+
+ /**
+ * Creates a SparseCategoricalCrossentropy metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param fromLogits Whether to interpret predictions as a tensor of logit values as opposed to a probability distribution.
+ * @param axis The dimension along which the entropy is computed.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public SparseCategoricalCrossentropy(
+ Ops tf, String name, boolean fromLogits, int axis, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ this.fromLogits = fromLogits;
+ this.axis = axis;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.sparseCategoricalCrossentropy(getTF(), labels, predictions, fromLogits, axis);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java
new file mode 100644
index 00000000000..430dbbcc229
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/SquaredHinge.java
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Losses;
+import org.tensorflow.framework.metrics.impl.LossMetric;
+import org.tensorflow.framework.metrics.impl.MeanMetricWrapper;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A metric that computes the squared hinge loss metric between labels and predictions.
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result.
+ */
+public class SquaredHinge extends MeanMetricWrapper
+ implements LossMetric {
+
+ /**
+ * Creates a SquaredHinge metric
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this metric, if null then metric name is {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ public SquaredHinge(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ setLoss(this);
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public Operand call(Operand labels, Operand predictions) {
+ return Losses.squaredHinge(getTF(), labels, predictions);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java
new file mode 100644
index 00000000000..66640e72f50
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/exceptions/NotBroadcastableException.java
@@ -0,0 +1,52 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics.exceptions;
+
+import org.tensorflow.ndarray.Shape;
+
+/**
+ * Exception that indicates that static shapes are not able to broadcast among each other during
+ * arithmetic operations. Static shapes do not have unknown rank or any unknown dimensions {@link
+ * Shape#hasUnknownDimension()}. The term broadcasting describes how TensorFlow treats arrays with
+ * different shapes during arithmetic operations.
+ *
+ * Broadcasting is the process of making arrays to have compatible shapes for arithmetic
+ * operations. Two shapes are compatible if for each dimension pair they are either equal or one of
+ * them is one. When trying to broadcast a Tensor to a shape, it starts with the trailing
+ * dimensions, and works its way forward.
+ *
+ * @see Numpy Broadcasting
+ */
+public class NotBroadcastableException extends IllegalArgumentException {
+
+ /**
+ * Creates a new NotBroadcastableException exception with the specified detail message
+ *
+ * @param message the detail message.
+ */
+ public NotBroadcastableException(String message) {
+ super(message);
+ }
+
+ /**
+ * Creates a new NotBroadcastableException exception with the specified detail message
+ *
+ * @param message the detail message.
+ * @param cause the cause
+ */
+ public NotBroadcastableException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java
new file mode 100644
index 00000000000..b7b87d313aa
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/LossMetric.java
@@ -0,0 +1,36 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics.impl;
+
+import org.tensorflow.Operand;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * Interface for Metrics that wrap Loss functions.
+ *
+ * @param The data type of the predictions.
+ */
+public interface LossMetric {
+
+ /**
+ * Calculates the weighted loss between labels
and predictions
+ *
+ * @param labels the truth values or labels
+ * @param predictions the predictions
+ * @param The data type of the labels.
+ * @return the loss
+ */
+ Operand call(Operand labels, Operand predictions);
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java
new file mode 100644
index 00000000000..17c209a8fed
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MeanMetricWrapper.java
@@ -0,0 +1,106 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics.impl;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.metrics.Mean;
+import org.tensorflow.framework.metrics.MetricReduction;
+import org.tensorflow.framework.utils.CastHelper;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+import java.util.List;
+
+/**
+ * A class that bridges a stateless loss function with the {@link Mean} metric using a reduction of
+ * {@link MetricReduction#WEIGHTED_MEAN}.
+ *
+ * The loss function calculates the loss between the labels
and predictions
+ *
then passes this loss to the {@link Mean} metric to calculate the weighted mean of the
+ * loss over many iterations or epochs
+ *
+ * @param the data type for the predictions.
+ * @param The data type for the metric result
+ */
+public class MeanMetricWrapper extends Mean {
+
+ /** The loss function interface */
+ protected LossMetric loss;
+
+ /**
+ * Creates a Reducible Metric with a metric reductions of {@link MetricReduction#WEIGHTED_MEAN}
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name for this metric. If null, name defaults to {@link Class#getSimpleName()}.
+ * @param seed the seed for random number generation. An initializer created with a given seed
+ * will always produce the same random tensor for a given shape and data type.
+ * @param type the type for the variables and result
+ */
+ protected MeanMetricWrapper(Ops tf, String name, long seed, Class type) {
+ super(tf, name, seed, type);
+ }
+
+ /**
+ * Gets the loss function.
+ *
+ * @return the loss function.
+ */
+ public LossMetric getLoss() {
+ return loss;
+ }
+
+ /**
+ * Sets the Loss function for this wrapper.
+ *
+ * @param loss the loss function.
+ */
+ protected void setLoss(LossMetric loss) {
+ this.loss = loss;
+ }
+
+ /**
+ * Creates Operations that update the state of the mean metric, by calling the loss function and
+ * passing the loss to the Mean metric to calculate the weighted mean of the loss over many
+ * iterations.
+ *
+ * @param labels the truth values or labels
+ * @param predictions the predictions
+ * @param sampleWeights Optional sampleWeights acts as a coefficient for the loss. If a scalar is
+ * provided, then the loss is simply scaled by the given value. If sampleWeights is a tensor
+ * of size [batch_size], then the total loss for each sample of the batch is rescaled by the
+ * corresponding element in the sampleWeights vector. If the shape of sampleWeights is
+ * [batch_size, d0, .. dN-1] (or can be broadcasted to this shape), then each loss element of
+ * predictions is scaled by the corresponding value of sampleWeights. (Note on dN-1: all loss
+ * functions reduce by 1 dimension, usually axis=-1.)
+ * @param the datatype of the labels
+ * @param the data type for sampleWeights
+ * @return a List of control operations that updates the Mean state variables.
+ */
+ public List updateStateList(
+ Operand labels, Operand predictions, Operand sampleWeights) {
+ if (labels == null || predictions == null) {
+ throw new IllegalArgumentException("missing required inputs for labels and predictions");
+ }
+
+ Operand tLabels = CastHelper.cast(getTF(), labels, getResultType());
+ Operand tPredictions = CastHelper.cast(getTF(), predictions, getResultType());
+
+ Operand losses = loss.call(tLabels, tPredictions);
+
+ return super.updateStateList(
+ CastHelper.cast(getTF(), losses, predictions.type()), sampleWeights);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
new file mode 100644
index 00000000000..ad8ff58e417
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/metrics/impl/MetricsHelper.java
@@ -0,0 +1,348 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.metrics.impl;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.metrics.exceptions.NotBroadcastableException;
+import org.tensorflow.ndarray.Shape;
+import org.tensorflow.op.Op;
+import org.tensorflow.op.Ops;
+import org.tensorflow.op.math.Mean;
+import org.tensorflow.types.TBool;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+import org.tensorflow.types.TInt32;
+import org.tensorflow.types.family.TIntegral;
+import org.tensorflow.types.family.TNumber;
+import org.tensorflow.types.family.TType;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.tensorflow.framework.losses.impl.LossesHelper.allAxes;
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * These are helper methods for Metrics and will be module private when Java modularity is applied
+ * to TensorFlow Java. These methods should not be used outside of the metrics packages.
+ */
+public class MetricsHelper {
+ public static final float NEG_INF = -1e10f;
+ private static final String ASSERT_BROADCAST_ERROR_PREFIX =
+ "weights can not be broadcast to values.";
+
+ /**
+ * Asserts that the sampleWeights
can be broadcast to the same shape as values
+ *
+ *
+ * In losses and metrics, limited weight broadcasting is supported. Weights must be either
+ * scalar, or the same rank as the target values, with each dimension either 1, or the same as the
+ * corresponding values dimension.
+ *
+ * @param tf the TensorFlow Ops
+ * @param sampleWeights the sample weights.
+ * @param values the values to which weights are applied.
+ * @return Operation
with control dependencies to ensure sampleWeight
+ * can be broadcast to values
+ * @param the type of Operand
+ * @throws NotBroadcastableException If static checks determine sampleWeights
has an
+ * incorrect shape that prohibit broadcasting to values
+ */
+ @SuppressWarnings("unchecked")
+ public static Op assertBroadcastable(
+ Ops tf, Operand sampleWeights, Operand values) {
+
+ // try static check for exact match
+
+ Shape weightsShapeStatic = sampleWeights.shape();
+ int weightsRankStatic = weightsShapeStatic.numDimensions();
+
+ Shape valuesShapeStatic = values.shape();
+ int valuesRankStatic = valuesShapeStatic.numDimensions();
+
+ // if (weightsRankStatic != Shape.UNKNOWN_SIZE && valuesRankStatic != Shape.UNKNOWN_SIZE) {
+ if (!weightsShapeStatic.isUnknown()
+ && !valuesShapeStatic.isUnknown()
+ && !weightsShapeStatic.hasUnknownDimension()
+ && !valuesShapeStatic.hasUnknownDimension()) {
+ if (weightsRankStatic == 0) {
+ return tf.withSubScope("staticScalarCheckSuccess")
+ .withControlDependencies(Collections.EMPTY_LIST)
+ .noOp();
+ }
+ if (weightsRankStatic != valuesRankStatic) {
+ throw new NotBroadcastableException(
+ String.format(
+ "%s values.rank=%d. weights.rank=%d. values.shape=%s. weights.shape=%s.",
+ ASSERT_BROADCAST_ERROR_PREFIX,
+ valuesRankStatic,
+ weightsRankStatic,
+ valuesShapeStatic.toString(),
+ weightsShapeStatic.toString()));
+ }
+
+ for (int i = 0; i < valuesRankStatic; i++) {
+ if (valuesShapeStatic.size(i) != weightsShapeStatic.size(i)
+ && weightsShapeStatic.size(i) != 1) {
+ throw new NotBroadcastableException(
+ String.format(
+ "%s Mismatch at dim %d. values.shape=%s weights.shape=%s.",
+ ASSERT_BROADCAST_ERROR_PREFIX,
+ i,
+ valuesShapeStatic.toString(),
+ weightsShapeStatic.toString()));
+ }
+ }
+ return tf.withSubScope("staticDimsCheckSuccess")
+ .withControlDependencies(Collections.EMPTY_LIST)
+ .noOp();
+ }
+ // Dynamic checks.
+ Operand weightsShape = tf.shape(sampleWeights);
+ Operand weightsRank = tf.rank(sampleWeights);
+ Operand valuesShape = tf.shape(values);
+ Operand valuesRank = tf.rank(values);
+
+ Operand isScalar = tf.math.equal(weightsRank, tf.constant(0));
+ List> data =
+ Arrays.asList(
+ tf.constant(ASSERT_BROADCAST_ERROR_PREFIX),
+ tf.constant("weights.shape="),
+ weightsShape,
+ tf.constant("values.shape="),
+ valuesShape,
+ tf.constant("isScalar="),
+ isScalar);
+
+ // hack to work around the non-lazy select for isValidShape, otherwise validNonscalar fails on a
+ // scalar weight. If select was lazy, that branch wouldn't get executed when iScalar is true.
+ Operand reshapedWeights =
+ tf.select(isScalar, tf.math.mul(sampleWeights, tf.onesLike(values)), sampleWeights);
+ weightsShape = tf.shape(reshapedWeights);
+ weightsRank = tf.rank(reshapedWeights);
+
+ Operand validNonscalar =
+ canBroadcastNonscalarShapes(tf, weightsRank, weightsShape, valuesRank, valuesShape);
+
+ Operand isValidShape = tf.select(isScalar, isScalar, validNonscalar);
+
+ return tf.withSubScope("broadcastWeights-dynamic").assertThat(isValidShape, data);
+ }
+
+ /**
+ * Gets an operand that tests if the shapes have the same rank and valid dimensions.
+ *
+ * @param tf the TensorFlow Ops
+ * @param weightsRank the operand for the rank of the sample weights
+ * @param weightsShape the operand for the shape of the sample weights
+ * @param valuesRank the operand for the rank of the sample weights
+ * @param valuesShape the operand for the shape of the sample weights
+ * @param the data type for the operands
+ * @return a boolean operand to determine if the Shape is scalar or not.
+ */
+ private static Operand canBroadcastNonscalarShapes(
+ Ops tf,
+ Operand weightsRank,
+ Operand