reduceSum = tf.reduceSum(sqr, LossesHelper.allAxes(tf, input));
+ regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum));
+ }
+
+ return regularization;
+ }
+
+ /**
+ * Gets the L1 regularization factor
+ *
+ * @return the L1 regularization factor
+ */
+ public float getL1() {
+ return l1;
+ }
+
+ /**
+ * Gets the L2 regularization factor
+ *
+ * @return the L2 regularization factor
+ */
+ public float getL2() {
+ return l2;
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java
new file mode 100644
index 00000000000..7b8f5b28a70
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java
@@ -0,0 +1,46 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.op.Ops;
+
+/**
+ * A regularizer that applies an L2 (Ridge Regression) regularization penalty.
+ *
+ * <p>The L2 regularization penalty is computed as: <code>loss = l2 * reduceSum(square(x))</code>
+ */
+public class L2 extends L1L2 {
+
+ /**
+ * Create a regularizer that applies an L2 regularization penalty of {@link
+ * #DEFAULT_REGULARIZATION_PENALTY}
+ *
+ * @param tf the TensorFlow Ops
+ */
+ public L2(Ops tf) {
+ this(tf, DEFAULT_REGULARIZATION_PENALTY);
+ }
+
+ /**
+ * Create a regularizer that applies an L2 regularization penalty
+ *
+ * @param tf the TensorFlow Ops
+ * @param l2 the L2 regularization penalty
+ * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite.
+ */
+ public L2(Ops tf, float l2) {
+ super(tf, 0f, l2);
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java
new file mode 100644
index 00000000000..5d9ff0e3e10
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java
@@ -0,0 +1,91 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Loss;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * Base class for Regularizers
+ *
+ *
+ * <p>Regularizers allow you to apply penalties on layer parameters or layer activity during
+ * optimization. These penalties are summed into the loss function that the network optimizes.
+ */
+public abstract class Regularizer {
+
+ public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f;
+
+ private final Ops tf;
+ private final String name;
+
+ /**
+ * Creates a Regularizer, using {@link Class#getSimpleName()} for the name
+ *
+ * @param tf the TensorFlow ops.
+ */
+ protected Regularizer(Ops tf) {
+ this(tf, null);
+ }
+ /**
+ * Creates a Regularizer
+ *
+ * @param tf the TensorFlow ops.
+ * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the
+ * name.
+ */
+ protected Regularizer(Ops tf, String name) {
+ this.tf = tf;
+ this.name = name == null ? this.getClass().getSimpleName() : name;
+ }
+
+ /**
+ * Returns this Regularizer as a Loss. This is a convenience for using a regularizer as a loss. Only
+ * sampleWeights are applied to the regularizer.
+ *
+ * @return this Regularizer as a Loss
+ */
+ public Loss asLoss() {
+ return new RegularizerLoss(this.tf, this);
+ }
+
+ /**
+ * Computes a regularization penalty from an input.
+ *
+ * @param input the weighted input
+ * @return the result of computing the regularization penalty
+ * @param <T> the data type of the input and result
+ */
+ public abstract <T extends TNumber> Operand<T> call(Operand<T> input);
+
+ /**
+ * Gets the TensorFlow Ops
+ *
+ * @return the TensorFlow Ops
+ */
+ public Ops getTF() {
+ return tf;
+ }
+
+ /**
+ * Gets the name for this regularizer
+ *
+ * @return the name for this regularizer
+ */
+ public String getName() {
+ return name;
+ }
+}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java
new file mode 100644
index 00000000000..582cd038f8f
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java
@@ -0,0 +1,64 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.Operand;
+import org.tensorflow.framework.losses.Loss;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TNumber;
+
+/**
+ * A Regularizer call wrapped as a Loss instance
+ *
+ * <p>This class facilitates using a regularizer as a loss, only sampleWeights are
+ * regularized.
+ */
+class RegularizerLoss extends Loss {
+
+ private final Regularizer regularizer;
+
+ /**
+ * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link
+ * Loss#REDUCTION_DEFAULT}
+ *
+ * @param tf the TensorFlow Ops
+ * @param regularizer the regularizer used to calculate the loss
+ */
+ public RegularizerLoss(Ops tf, Regularizer regularizer) {
+ this(tf, null, regularizer);
+ }
+
+ /**
+ * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT}
+ *
+ * @param tf the TensorFlow Ops
+ * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}.
+ * @param regularizer the regularizer used to calculate the loss
+ */
+ public RegularizerLoss(Ops tf, String name, Regularizer regularizer) {
+ super(tf, name);
+ this.regularizer = regularizer;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public <T extends TNumber> Operand<T> call(
+ Operand<? extends TNumber> labels, Operand<T> predictions, Operand<T> sampleWeights) {
+ if (sampleWeights == null) {
+ throw new IllegalArgumentException("sampleWeights cannot be null");
+ }
+ return regularizer.call(sampleWeights);
+ }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java
new file mode 100644
index 00000000000..63ecc155fd1
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java
@@ -0,0 +1,63 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.regularizers;
+
+import org.tensorflow.framework.utils.ND;
+import org.tensorflow.ndarray.DoubleNdArray;
+import org.tensorflow.ndarray.FloatNdArray;
+import org.tensorflow.ndarray.StdArrays;
+
+public class CommonTest {
+
+ protected float regularizeL1L2(float[][] w, float l1, float l2) {
+ return regularizeL1(w, l1) + regularizeL2(w, l2);
+ }
+
+ protected float regularizeL1(float[][] w, float l1) {
+ FloatNdArray fa = StdArrays.ndCopyOf(w);
+ fa = ND.abs(fa);
+ FloatNdArray sum = ND.sum(fa);
+ FloatNdArray mul = ND.mul(sum, l1);
+ return mul.getFloat();
+ }
+
+ protected float regularizeL2(float[][] w, float l2) {
+ FloatNdArray fa = StdArrays.ndCopyOf(w);
+ fa = ND.square(fa);
+ FloatNdArray sum = ND.sum(fa);
+ FloatNdArray mul = ND.mul(sum, l2);
+ return mul.getFloat();
+ }
+
+ protected double regularizeL1L2(double[][] w, float l1, float l2) {
+ return regularizeL1(w, l1) + regularizeL2(w, l2);
+ }
+
+ protected double regularizeL1(double[][] w, float l1) {
+ DoubleNdArray fa = StdArrays.ndCopyOf(w);
+ fa = ND.abs(fa);
+ DoubleNdArray sum = ND.sum(fa);
+ DoubleNdArray mul = ND.mul(sum, l1);
+ return mul.getDouble();
+ }
+
+ protected double regularizeL2(double[][] w, float l2) {
+ DoubleNdArray fa = StdArrays.ndCopyOf(w);
+ fa = ND.square(fa);
+ DoubleNdArray sum = ND.sum(fa);
+ DoubleNdArray mul = ND.mul(sum, l2);
+ return mul.getDouble();
+ }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java
new file mode 100644
index 00000000000..181ae367f07
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java
@@ -0,0 +1,124 @@
+package org.tensorflow.framework.regularizers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class L1L2Test extends CommonTest {
+ private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
+
+ @Test
+ public void testCreate() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1L2 instance = new L1L2(tf, 0.2f, 0.3f);
+ assertEquals(0.2f, instance.getL1());
+ assertEquals(0.3f, instance.getL2());
+
+ instance = new L1L2(tf, 0, 0);
+ assertEquals(0.f, instance.getL1());
+ assertEquals(0.f, instance.getL2());
+
+ instance = new L1L2(tf, 0.5f, 0);
+ assertEquals(0.5f, instance.getL1());
+ assertEquals(0.f, instance.getL2());
+
+ instance = new L1L2(tf, 0, 0.5f);
+ assertEquals(0.f, instance.getL1());
+ assertEquals(0.5f, instance.getL2());
+
+ instance = new L1L2(tf);
+ assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1());
+ assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2());
+ }
+ }
+
+ @Test
+ public void testCallDefaultsConstant() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1L2 instance = new L1L2(tf);
+ Operand<TFloat32> result = instance.call(tf.constant(555f));
+ session.evaluate(3085.8f, result);
+ }
+ }
+
+ @Test
+ public void testCallL1L2_0() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1L2 instance = new L1L2(tf, 0, 0);
+ Operand<TFloat32> weights =
+ tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}});
+ Operand<TFloat32> result = instance.call(weights);
+ session.evaluate(0, result);
+ }
+ }
+
+ @Test
+ public void testCallL1L2TFloat32() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1L2 instance = new L1L2(tf, 0.01f, 0.02f);
+ float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+ Operand<TFloat32> weights = tf.constant(w);
+ Operand<TFloat32> result = instance.call(weights);
+ float expected = regularizeL1L2(w, 0.01f, 0.02f);
+ session.setEpsilon(.09f);
+ session.evaluate(expected, result);
+ }
+ }
+
+ @Test
+ public void testCallL1L2TFloat64() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1L2 instance = new L1L2(tf, 0.01f, 0.02f);
+ double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+ Operand<TFloat64> weights = tf.constant(w);
+ Operand<TFloat64> result = instance.call(weights);
+ double expected = regularizeL1L2(w, 0.01f, 0.02f);
+ session.setEpsilon(.09f);
+ session.evaluate(expected, result);
+ }
+ }
+
+ @Test
+ public void testCallL2_0() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1L2 instance = new L1L2(tf, 0.01f, 0);
+ float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+ Operand<TFloat32> weights = tf.constant(w);
+ Operand<TFloat32> result = instance.call(weights);
+ float expected = regularizeL1(w, 0.01f);
+ session.evaluate(expected, result);
+ }
+ }
+
+ @Test
+ public void testCallL1_0() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1L2 instance = new L1L2(tf, 0, 0.02f);
+ double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+ Operand<TFloat64> weights = tf.constant(w);
+ Operand<TFloat64> result = instance.call(weights);
+ double expected = regularizeL2(w, 0.02f);
+ session.setEpsilon(.01f);
+ session.evaluate(expected, result);
+ }
+ }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java
new file mode 100644
index 00000000000..0e42a257816
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java
@@ -0,0 +1,74 @@
+package org.tensorflow.framework.regularizers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class L1Test extends CommonTest {
+ private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
+
+ @Test
+ public void testCreate() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1 instance = new L1(tf, 0.2f);
+ assertEquals(0.2f, instance.getL1());
+ assertEquals(0.f, instance.getL2());
+
+ instance = new L1(tf, 0f);
+ assertEquals(0.f, instance.getL1());
+ assertEquals(0.f, instance.getL2());
+
+ instance = new L1(tf);
+ assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1());
+ assertEquals(0.f, instance.getL2());
+ }
+ }
+
+ @Test
+ public void testCallL10() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1 instance = new L1(tf, 0.0f);
+ float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+ Operand<TFloat32> weights = tf.constant(w);
+ Operand<TFloat32> result = instance.call(weights);
+ session.evaluate(0f, result);
+ }
+ }
+
+ @Test
+ public void testCallL1TFloat32() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1 instance = new L1(tf);
+ float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+ Operand<TFloat32> weights = tf.constant(w);
+ Operand<TFloat32> result = instance.call(weights);
+ float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY);
+ session.evaluate(expected, result);
+ }
+ }
+
+ @Test
+ public void testCallL1TFloat64() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1 instance = new L1(tf, 0.02f);
+ double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+ Operand<TFloat64> weights = tf.constant(w);
+ Operand<TFloat64> result = instance.call(weights);
+ double expected = regularizeL1(w, 0.02f);
+ session.evaluate(expected, result);
+ }
+ }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java
new file mode 100644
index 00000000000..aba036ee306
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java
@@ -0,0 +1,76 @@
+package org.tensorflow.framework.regularizers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+import org.tensorflow.types.TFloat64;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+class L2Test extends CommonTest {
+ private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
+
+ @Test
+ public void testCreate() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L2 instance = new L2(tf, 0.2f);
+ assertEquals(0.2f, instance.getL2());
+ assertEquals(0.f, instance.getL1());
+
+ instance = new L2(tf, 0f);
+ assertEquals(0.f, instance.getL2());
+ assertEquals(0.f, instance.getL1());
+
+ L2 instance64 = new L2(tf);
+ assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2());
+ assertEquals(0.f, instance64.getL1());
+ }
+ }
+
+ @Test
+ public void testCallL20() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L2 instance = new L2(tf, 0.0f);
+ float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+ Operand<TFloat32> weights = tf.constant(w);
+ Operand<TFloat32> result = instance.call(weights);
+ session.evaluate(0, result);
+ }
+ }
+
+ @Test
+ public void testCallL2TFloat32() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L2 instance = new L2(tf);
+ float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+ Operand<TFloat32> weights = tf.constant(w);
+ Operand<TFloat32> result = instance.call(weights);
+ float expected = regularizeL2(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY);
+ session.setEpsilon(.01f);
+ session.evaluate(expected, result);
+ }
+ }
+
+ @Test
+ public void testCallL2TFloat64() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L2 instance = new L2(tf, 0.02f);
+ double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}};
+ Operand<TFloat64> weights = tf.constant(w);
+ Operand<TFloat64> result = instance.call(weights);
+ double expected = regularizeL2(w, 0.02f);
+ session.setEpsilon(.01f);
+ session.evaluate(expected, result);
+ }
+ }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java
new file mode 100644
index 00000000000..fe2624cec3d
--- /dev/null
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java
@@ -0,0 +1,27 @@
+package org.tensorflow.framework.regularizers;
+
+import org.junit.jupiter.api.Test;
+import org.tensorflow.Operand;
+import org.tensorflow.framework.utils.TestSession;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.TFloat32;
+
+class RegularizerLossTest {
+ private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH};
+
+ @Test
+ public void testCreate() {
+ for (TestSession.Mode tfMode : tfModes)
+ try (TestSession session = TestSession.createTestSession(tfMode)) {
+ Ops tf = session.getTF();
+ L1L2 regularizer = new L1L2(tf, 0.01f, 0f);
+ float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}};
+ Operand<TFloat32> weights = tf.constant(w);
+ Operand<TFloat32> regularizerResult = regularizer.call(weights);
+ RegularizerLoss lossInstance = new RegularizerLoss(tf, regularizer);
+
+ Operand<TFloat32> loss = lossInstance.call(null, null, weights);
+ session.evaluate(regularizerResult, loss);
+ }
+ }
+}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java
index ef8bb71d724..c0c0f12fbf9 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java
@@ -100,23 +100,6 @@ public static FloatNdArray sqrt(FloatNdArray a) {
return result;
}
- /**
- * Gets the square root of an array.
- *
- * @param a the array
- * @return the square root of the array.
- */
- public static DoubleNdArray sqrt(DoubleNdArray a) {
- DoubleNdArray result = NdArrays.ofDoubles(a.shape());
- int nDims = a.shape().numDimensions();
- a.elements(nDims - 1)
- .forEachIndexed(
- (idx, v) -> {
- result.setDouble(Math.sqrt(v.getDouble()), idx);
- });
- return result;
- }
-
/**
* Gets the square of an array.
*
@@ -315,6 +298,64 @@ public static FloatNdArray mul(float scalar, FloatNdArray a) {
return mul(a, scalar);
}
+ /**
+ * Multiply 2 arrays
+ *
+ * @param a the first array
+ * @param b the second array
+ * @return the resulting array from the multiply operation
+ */
+ public static DoubleNdArray mul(DoubleNdArray a, DoubleNdArray b) {
+ if (!a.shape().equals(b.shape()))
+ throw new IllegalArgumentException(
+ String.format(
+ "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape()));
+ boolean sameSize = a.shape().size() == b.shape().size();
+ DoubleNdArray result = NdArrays.ofDoubles(a.shape());
+ int nDims = a.shape().numDimensions();
+
+ a.elements(nDims - 1)
+ .forEachIndexed(
+ (idx, v) -> {
+ if (sameSize) {
+ result.setDouble(v.getDouble() * b.getDouble(idx), idx);
+ } else {
+ double value = v.getDouble() * b.getDouble(idx[0], 0L);
+ result.setDouble(value, idx);
+ }
+ });
+ return result;
+ }
+
+ /**
+ * Multiply an array with a scalar value
+ *
+ * @param a the array
+ * @param scalar the scalar value
+ * @return the resulting array from the Multiply operation
+ */
+ public static DoubleNdArray mul(DoubleNdArray a, float scalar) {
+ DoubleNdArray result = NdArrays.ofDoubles(a.shape());
+ if (a.shape().isScalar()) {
+ a.scalars().forEach(f -> result.setDouble(f.getDouble() * scalar));
+ } else {
+ a.scalars().forEachIndexed((idx, f) -> result.setDouble(f.getDouble() * scalar, idx));
+ }
+
+ return result;
+ }
+
+ /**
+ * Multiply a scalar value with an array
+ *
+ * @param scalar the scalar value
+ * @param a the array
+ * @return the resulting array from the Multiply operation
+ */
+ public static DoubleNdArray mul(float scalar, DoubleNdArray a) {
+ return mul(a, scalar);
+ }
+
/**
* Divide two arrays
*
@@ -487,7 +528,7 @@ public static FloatNdArray max(FloatNdArray a, FloatNdArray b) {
a.elements(nDims - 1)
.forEachIndexed(
(idx, v) -> {
- result.setFloat((float) Math.max(v.getFloat(), b.getFloat(idx)), idx);
+ result.setFloat(Math.max(v.getFloat(), b.getFloat(idx)), idx);
});
return result;
}
@@ -506,7 +547,7 @@ public static FloatNdArray max(FloatNdArray a, float scalar) {
a.elements(nDims - 1)
.forEachIndexed(
(idx, v) -> {
- result.setFloat((float) Math.max(v.getFloat(), scalar), idx);
+ result.setFloat(Math.max(v.getFloat(), scalar), idx);
});
return result;
}
@@ -539,7 +580,7 @@ public static FloatNdArray min(FloatNdArray a, FloatNdArray b) {
a.elements(nDims - 1)
.forEachIndexed(
(idx, v) -> {
- result.setFloat((float) Math.min(v.getFloat(), b.getFloat(idx)), idx);
+ result.setFloat(Math.min(v.getFloat(), b.getFloat(idx)), idx);
});
return result;
}
@@ -558,7 +599,7 @@ public static FloatNdArray min(FloatNdArray a, float scalar) {
a.elements(nDims - 1)
.forEachIndexed(
(idx, v) -> {
- result.setFloat((float) Math.min(v.getFloat(), scalar), idx);
+ result.setFloat(Math.min(v.getFloat(), scalar), idx);
});
return result;
}
@@ -583,20 +624,20 @@ public static FloatNdArray min(float scalar, FloatNdArray a) {
*/
public static FloatNdArray abs(FloatNdArray a) {
FloatNdArray result = NdArrays.ofFloats(a.shape());
- a.scalars().forEachIndexed((idx, f) -> result.setFloat((float) Math.abs(f.getFloat()), idx));
+ a.scalars().forEachIndexed((idx, f) -> result.setFloat(Math.abs(f.getFloat()), idx));
return result;
}
/**
- * Sum all elements of an array
+ * Get the absolute value of each member of the array
*
* @param a the array
- * @return an a array with one element containing the sum.
+ * @return the array with the absolute value of each item.
*/
- public static FloatNdArray sum(FloatNdArray a) {
- AtomicReference<Float> sum = new AtomicReference<>(0.f);
- a.scalars().forEach(f -> sum.set(sum.get() + f.getFloat()));
- return NdArrays.scalarOf(sum.get());
+ public static DoubleNdArray abs(DoubleNdArray a) {
+ DoubleNdArray result = NdArrays.ofDoubles(a.shape());
+ a.scalars().forEachIndexed((idx, f) -> result.setDouble(Math.abs(f.getDouble()), idx));
+ return result;
}
/**
@@ -605,9 +646,9 @@ public static FloatNdArray sum(FloatNdArray a) {
* @param a the array
* @return an a array with one element containing the sum.
*/
- public static DoubleNdArray sum(DoubleNdArray a) {
- AtomicReference<Double> sum = new AtomicReference(0D);
- a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble()));
+ public static FloatNdArray sum(FloatNdArray a) {
+ AtomicReference<Float> sum = new AtomicReference<>(0.f);
+ a.scalars().forEach(f -> sum.set(sum.get() + f.getFloat()));
return NdArrays.scalarOf(sum.get());
}
@@ -622,17 +663,6 @@ public static FloatNdArray sum(FloatNdArray a, int axis) {
return sum(a, axis, false);
}
- /**
- * Sum all elements of an array based on the specified axis
- *
- * @param a the array
- * @param axis the axis to sum
- * @return an a array the sum over the axis less the diemsnion
- */
- public static DoubleNdArray sum(DoubleNdArray a, int axis) {
- return sum(a, axis, false);
- }
-
/**
* Sum all elements of an array based on the specified axis
*
@@ -676,6 +706,58 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) {
* Sum all elements of an array based on the specified axis
*
* @param a the array
+ * @param axes the axis to sum
+ * @param keepDims indicates whether the dimensions over the sum should be kept or not.
+ * @return an a array the sum over the axis
+ */
+ public static FloatNdArray sum(FloatNdArray a, Integer[] axes, boolean keepDims) {
+ Shape shape = a.shape();
+ if (axes == null) {
+ FloatNdArray result = sum(a);
+ if (keepDims) {
+ float scalar = result.getFloat(0);
+ long[] dims = {1, 1};
+ Shape bShape = Shape.of(dims);
+ FloatNdArray resultK = NdArrays.ofFloats(bShape);
+ resultK.setFloat(scalar, 0, 0);
+ return resultK;
+ }
+ return result;
+ } else if (axes.length == 1) {
+ return sum(a, axes[0], keepDims);
+ } else {
+ // TODO
+ throw new UnsupportedOperationException("Multi Axis Not implemented Yet");
+ }
+ }
+
+ /**
+ * Sum all elements of an array
+ *
+ * @param a the array
+ * @return an a array with one element containing the sum.
+ */
+ public static DoubleNdArray sum(DoubleNdArray a) {
+ AtomicReference<Double> sum = new AtomicReference<>(0.);
+ a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble()));
+ return NdArrays.scalarOf(sum.get());
+ }
+
+ /**
+ * Sum all elements of an array based on the specified axis
+ *
+ * @param a the array
+ * @param axis the axis to sum
+ * @return an array of the sum over the axis less the dimension
+ */
+ public static DoubleNdArray sum(DoubleNdArray a, int axis) {
+ return sum(a, axis, false);
+ }
+
+ /**
+ * Sum all elements of an array over on the specified axis
+ *
+ * @param a the array
* @param axis the axis to sum
* @param keepDims indicates whether the dimensions over the sum should be kept or not.
* @return an a array the sum over the axis
@@ -719,16 +801,16 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) {
* @param keepDims indicates whether the dimensions over the sum should be kept or not.
* @return an a array the sum over the axis
*/
- public static FloatNdArray sum(FloatNdArray a, Integer[] axes, boolean keepDims) {
+ public static DoubleNdArray sum(DoubleNdArray a, Integer[] axes, boolean keepDims) {
Shape shape = a.shape();
if (axes == null) {
- FloatNdArray result = sum(a);
+ DoubleNdArray result = sum(a);
if (keepDims) {
- float scalar = result.getFloat(0);
+ double scalar = result.getDouble(0);
long[] dims = {1, 1};
Shape bShape = Shape.of(dims);
- FloatNdArray resultK = NdArrays.ofFloats(bShape);
- resultK.setFloat(scalar, 0, 0);
+ DoubleNdArray resultK = NdArrays.ofDoubles(bShape);
+ resultK.setDouble(scalar, 0, 0);
return resultK;
}
return result;