diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java new file mode 100644 index 00000000000..7c8c2a1360a --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1.java @@ -0,0 +1,47 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.op.Ops; + +/** + * A regularizer that applies an L1 or Lasso(least absolute shrinkage and selection operator) + * Regression, regularization penalty. + * + *

The L1 regularization penalty is computed as: loss = l1 * reduceSum(abs(x)) + */ +public class L1 extends L1L2 { + + /** + * Create a regularizer that applies an L1 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} + * + * @param tf the TensorFlow Ops + */ + public L1(Ops tf) { + this(tf, DEFAULT_REGULARIZATION_PENALTY); + } + + /** + * Create a regularizer that applies an L1 regularization penalty + * + * @param tf the TensorFlow Ops + * @param l1 the L1 regularization penalty + * @throws IllegalArgumentException if the l1 regularization factor is NaN or is infinite. + */ + public L1(Ops tf, float l1) { + super(tf, l1, 0f); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java new file mode 100644 index 00000000000..29e411f9897 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L1L2.java @@ -0,0 +1,120 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.impl.LossesHelper; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A regularizer that applies both L1 and L2 regularization penalties. + * + *

The L1 regularization penalty is computed as: + * + *

loss = l1 * reduceSum(abs(x))
+ * + *

The L2 regularization penalty is computed as + * + *

loss = l2 * reduceSum(square(x))
+ * + */ +public class L1L2 extends Regularizer { + + private final float l1; + private final float l2; + + /** + * Creates an L1L2 regularizer with no l1 or l2 penalty with zero penalty + * + * @param tf the TensorFlow Ops + */ + public L1L2(Ops tf) { + this(tf, DEFAULT_REGULARIZATION_PENALTY, DEFAULT_REGULARIZATION_PENALTY); + } + + /** + * Creates an L1L2 regularizer + * + * @param tf the TensorFlow Ops + * @param l1 L1 regularization factor, if null it is set to 0. + * @param l2 L2 regularization factor, if null it is set to 0. + * @throws IllegalArgumentException if the l1 or l2 regularization factor is {@link Float#isNaN} + * of {@link Float#isInfinite} + */ + public L1L2(Ops tf, float l1, float l2) { + super(tf); + if (Float.isNaN(l1) || Float.isInfinite(l1)) { + throw new IllegalArgumentException( + String.format( + "L1 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l1)); + } + this.l1 = l1; + + if (Float.isNaN(l2) || Float.isInfinite(l2)) { + throw new IllegalArgumentException( + String.format( + "L2 Value: %f is not a valid regularization penalty number, a positive/negative infinity or NaN is not a property value", + l2)); + } + this.l2 = l2; + } + + + /** {@inheritDoc} */ + @Override + public Operand call(Operand input) { + Ops tf = getTF(); + if (this.getL1() == 0f && this.getL2() == 0f) { + return tf.dtypes.cast(tf.constant(0), input.type()); + } + Operand regularization = tf.dtypes.cast(tf.constant(0), input.type()); + + if (this.getL1() != 0.f) { + Operand l1Op = tf.dtypes.cast(tf.constant(this.getL1()), input.type()); + Operand abs = tf.math.abs(input); + Operand reduceSum = tf.reduceSum(abs, LossesHelper.allAxes(tf, input)); + regularization = tf.math.add(regularization, tf.math.mul(l1Op, reduceSum)); + } + + if (this.getL2() != 0.f) { + Operand l2Op = tf.dtypes.cast(tf.constant(this.getL2()), input.type()); + Operand sqr = tf.math.square(input); + Operand reduceSum = 
tf.reduceSum(sqr, LossesHelper.allAxes(tf, input)); + regularization = tf.math.add(regularization, tf.math.mul(l2Op, reduceSum)); + } + + return regularization; + } + + /** + * Gets the L1 regularization factor + * + * @return the L1 regularization factor + */ + public float getL1() { + return l1; + } + + /** + * Gets the L2 regularization factor + * + * @return the L2 regularization factor + */ + public float getL2() { + return l2; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java new file mode 100644 index 00000000000..7b8f5b28a70 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/L2.java @@ -0,0 +1,46 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.op.Ops; + +/** + * A regularizer that applies a L2 (Ridge Regression) regularization penalty. + * + *

The L2 regularization penalty is computed as: loss = l2 * reduceSum(square(x)) + */ +public class L2 extends L1L2 { + + /** + * Create a regularizer that applies an L2 regularization penalty of {@link + * #DEFAULT_REGULARIZATION_PENALTY} + * + * @param tf the TensorFlow Ops + */ + public L2(Ops tf) { + this(tf, DEFAULT_REGULARIZATION_PENALTY); + } + + /** + * Create a regularizer that applies an L1 regularization penalty + * + * @param tf the TensorFlow Ops + * @param l2 the L2 regularization penalty + * @throws IllegalArgumentException if the l2 regularization factor is NaN or is infinite. + */ + public L2(Ops tf, float l2) { + super(tf, 0f, l2); + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java new file mode 100644 index 00000000000..5d9ff0e3e10 --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/Regularizer.java @@ -0,0 +1,91 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * Base class for Regularizers + * + *

Regularizers allow you to apply penalties on layer parameters or layer activity during + * optimization. These penalties are summed into the loss function that the network optimizes. + */ +public abstract class Regularizer { + + public static final float DEFAULT_REGULARIZATION_PENALTY = 0.01f; + + private final Ops tf; + private final String name; + + /** + * Creates a Regularizer, using {@link Class#getSimpleName()} for the name + * + * @param tf the TensorFlow ops. + */ + protected Regularizer(Ops tf) { + this(tf, null); + } + /** + * Creates a Regularizer + * + * @param tf the TensorFlow ops. + * @param name the name of this regularizer, if null use {@link Class#getSimpleName()} for the + * name. + */ + protected Regularizer(Ops tf, String name) { + this.tf = tf; + this.name = name == null ? this.getClass().getSimpleName() : name; + } + + /** + * Returns this Regularizer as a Loss This is a convenience to use regularize a loss. Only + * sampleWeights are applied to the regularizer. + * + * @return this Regularizer as a Loss + */ + public Loss asLoss() { + return new RegularizerLoss(this.tf, this); + } + + /** + * Computes a regularization penalty from an input. 
+ * + * @param input the weighted input + * @return the result of computing the regularization penalty + * @param the data type of the input and result + */ + public abstract Operand call(Operand input); + + /** + * Gets the TensorFlow Ops + * + * @return the TensorFlow Ops + */ + public Ops getTF() { + return tf; + } + + /** + * Gets the name for this regularizer + * + * @return the name for this regularizer + */ + public String getName() { + return name; + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java new file mode 100644 index 00000000000..582cd038f8f --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/regularizers/RegularizerLoss.java @@ -0,0 +1,64 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.regularizers; + +import org.tensorflow.Operand; +import org.tensorflow.framework.losses.Loss; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TNumber; + +/** + * A Regularizer call wrapped as a Loss instance + * + *

This class facilitates using a regularizer as a loss, only sampleWeights are + * regularized. + */ +class RegularizerLoss extends Loss { + + private final Regularizer regularizer; + + /** + * Creates a Loss using {@link Class#getSimpleName()} as the name and a Loss Reduction of {@link + * Loss#REDUCTION_DEFAULT} + * + * @param tf the TensorFlow Ops + * @param regularizer the regularizer used to calculate the loss + */ + public RegularizerLoss(Ops tf, Regularizer regularizer) { + this(tf, null, regularizer); + } + + /** + * Creates a Loss using a Loss Reduction of {@link Loss#REDUCTION_DEFAULT} + * + * @param tf the TensorFlow Ops + * @param name the name of this Loss, if null the name will be {@link Class#getSimpleName()}. + * @param regularizer the regularizer used to calculate the loss + */ + public RegularizerLoss(Ops tf, String name, Regularizer regularizer) { + super(tf, name); + this.regularizer = regularizer; + } + + /** {@inheritDoc} */ + @Override + public Operand call( + Operand labels, Operand predictions, Operand sampleWeights) { + if (sampleWeights == null) { + throw new IllegalArgumentException("sampleWeights cannot be null"); + } + return regularizer.call(sampleWeights); + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java new file mode 100644 index 00000000000..63ecc155fd1 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/CommonTest.java @@ -0,0 +1,63 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================*/
package org.tensorflow.framework.regularizers;

import org.tensorflow.framework.utils.ND;
import org.tensorflow.ndarray.DoubleNdArray;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.StdArrays;

/** Shared helpers that compute the expected L1/L2 penalties for the regularizer tests. */
public class CommonTest {

  /**
   * Computes the expected combined penalty: l1 * sum(|w|) + l2 * sum(w^2).
   *
   * @param w the weight matrix
   * @param l1 the L1 factor
   * @param l2 the L2 factor
   * @return the expected combined penalty
   */
  protected float regularizeL1L2(float[][] w, float l1, float l2) {
    return regularizeL1(w, l1) + regularizeL2(w, l2);
  }

  /**
   * Computes the expected L1 penalty: l1 * sum(|w|).
   *
   * @param w the weight matrix
   * @param l1 the L1 factor
   * @return the expected L1 penalty
   */
  protected float regularizeL1(float[][] w, float l1) {
    FloatNdArray weights = StdArrays.ndCopyOf(w);
    return ND.mul(ND.sum(ND.abs(weights)), l1).getFloat();
  }

  /**
   * Computes the expected L2 penalty: l2 * sum(w^2).
   *
   * @param w the weight matrix
   * @param l2 the L2 factor
   * @return the expected L2 penalty
   */
  protected float regularizeL2(float[][] w, float l2) {
    FloatNdArray weights = StdArrays.ndCopyOf(w);
    return ND.mul(ND.sum(ND.square(weights)), l2).getFloat();
  }

  /**
   * Computes the expected combined penalty for double weights.
   *
   * @param w the weight matrix
   * @param l1 the L1 factor
   * @param l2 the L2 factor
   * @return the expected combined penalty
   */
  protected double regularizeL1L2(double[][] w, float l1, float l2) {
    return regularizeL1(w, l1) + regularizeL2(w, l2);
  }

  /**
   * Computes the expected L1 penalty for double weights.
   *
   * @param w the weight matrix
   * @param l1 the L1 factor
   * @return the expected L1 penalty
   */
  protected double regularizeL1(double[][] w, float l1) {
    DoubleNdArray weights = StdArrays.ndCopyOf(w);
    return ND.mul(ND.sum(ND.abs(weights)), l1).getDouble();
  }

  /**
   * Computes the expected L2 penalty for double weights.
   *
   * @param w the weight matrix
   * @param l2 the L2 factor
   * @return the expected L2 penalty
   */
  protected double regularizeL2(double[][] w, float l2) {
    DoubleNdArray weights = StdArrays.ndCopyOf(w);
    return ND.mul(ND.sum(ND.square(weights)), l2).getDouble();
  }
}
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java
b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java new file mode 100644 index 00000000000..181ae367f07 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1L2Test.java @@ -0,0 +1,124 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L1L2Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2(tf, 0.2f, 0.3f); + assertEquals(0.2f, instance.getL1()); + assertEquals(0.3f, instance.getL2()); + + instance = new L1L2(tf, 0, 0); + assertEquals(0.f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1L2(tf, 0.5f, 0); + assertEquals(0.5f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1L2(tf, 0, 0.5f); + assertEquals(0.f, instance.getL1()); + assertEquals(0.5f, instance.getL2()); + + instance = new L1L2(tf); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL2()); + } + } + + @Test + public void testCallDefaultsConstant() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2(tf); + Operand result = instance.call(tf.constant(555f)); + session.evaluate(3085.8f, result); + } + } + + @Test + public void testCallL1L2_0() { + for (TestSession.Mode tfMode : tfModes) + try 
(TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2(tf, 0, 0); + Operand weights = + tf.constant(new float[][] {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}); + Operand result = instance.call(weights); + session.evaluate(0, result); + } + } + + @Test + public void testCallL1L2TFloat32() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2(tf, 0.01f, 0.02f); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1L2(w, 0.01f, 0.02f); + session.setEpsilon(.09f); + session.evaluate(expected, result); + } + } + + @Test + public void testCallL1L2TFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2(tf, 0.01f, 0.02f); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL1L2(w, 0.01f, 0.02f); + session.setEpsilon(.09f); + session.evaluate(expected, result); + } + } + + @Test + public void testCallL2_0() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2(tf, 0.01f, 0); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1(w, 0.01f); + session.evaluate(expected, result); + } + } + + @Test + public void testCallL1_0() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 instance = new L1L2(tf, 0, 0.02f); + double[][] w = {{1.0, 0.9, 
0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL2(w, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java new file mode 100644 index 00000000000..0e42a257816 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L1Test.java @@ -0,0 +1,74 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L1Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1(tf, 0.2f); + assertEquals(0.2f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1(tf, 0f); + assertEquals(0.f, instance.getL1()); + assertEquals(0.f, instance.getL2()); + + instance = new L1(tf); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance.getL1()); + assertEquals(0.f, instance.getL2()); + } + } + + @Test + public void testCallL10() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1(tf, 0.0f); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + session.evaluate(0f, result); + } + 
} + + @Test + public void testCallL1TFloat32() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1(tf); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL1(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.evaluate(expected, result); + } + } + + @Test + public void testCallL1TFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1 instance = new L1(tf, 0.02f); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL1(w, 0.02f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java new file mode 100644 index 00000000000..aba036ee306 --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/L2Test.java @@ -0,0 +1,76 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class L2Test extends CommonTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2(tf, 0.2f); + assertEquals(0.2f, 
instance.getL2()); + assertEquals(0.f, instance.getL1()); + + instance = new L2(tf, 0f); + assertEquals(0.f, instance.getL2()); + assertEquals(0.f, instance.getL1()); + + L2 instance64 = new L2(tf); + assertEquals(Regularizer.DEFAULT_REGULARIZATION_PENALTY, instance64.getL2()); + assertEquals(0.f, instance64.getL1()); + } + } + + @Test + public void testCallL20() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2(tf, 0.0f); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + session.evaluate(0, result); + } + } + + @Test + public void testCallL2TFloat32() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2(tf); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + float expected = regularizeL2(w, Regularizer.DEFAULT_REGULARIZATION_PENALTY); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } + + @Test + public void testCallL2TFloat64() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L2 instance = new L2(tf, 0.02f); + double[][] w = {{1.0, 0.9, 0.8}, {1.2, 0.7, 1.1}}; + Operand weights = tf.constant(w); + Operand result = instance.call(weights); + double expected = regularizeL2(w, 0.02f); + session.setEpsilon(.01f); + session.evaluate(expected, result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java new file mode 100644 index 00000000000..fe2624cec3d --- /dev/null +++ 
b/tensorflow-framework/src/test/java/org/tensorflow/framework/regularizers/RegularizerLossTest.java @@ -0,0 +1,27 @@ +package org.tensorflow.framework.regularizers; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; + +class RegularizerLossTest { + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + @Test + public void testCreate() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + L1L2 regularizer = new L1L2(tf, 0.01f, 0f); + float[][] w = {{1.0f, 0.9f, 0.8f}, {1.2f, 0.7f, 1.1f}}; + Operand weights = tf.constant(w); + Operand regularizerResult = regularizer.call(weights); + RegularizerLoss lossInstance = new RegularizerLoss(tf, regularizer); + + Operand loss = lossInstance.call(null, null, weights); + session.evaluate(regularizerResult, loss); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java index ef8bb71d724..c0c0f12fbf9 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/utils/ND.java @@ -100,23 +100,6 @@ public static FloatNdArray sqrt(FloatNdArray a) { return result; } - /** - * Gets the square root of an array. - * - * @param a the array - * @return the square root of the array. - */ - public static DoubleNdArray sqrt(DoubleNdArray a) { - DoubleNdArray result = NdArrays.ofDoubles(a.shape()); - int nDims = a.shape().numDimensions(); - a.elements(nDims - 1) - .forEachIndexed( - (idx, v) -> { - result.setDouble(Math.sqrt(v.getDouble()), idx); - }); - return result; - } - /** * Gets the square of an array. 
* @@ -315,6 +298,64 @@ public static FloatNdArray mul(float scalar, FloatNdArray a) { return mul(a, scalar); } + /** + * Multiply 2 arrays + * + * @param a the first array + * @param b the second array + * @return the resulting array from the muliply operation + */ + public static DoubleNdArray mul(DoubleNdArray a, DoubleNdArray b) { + if (!a.shape().equals(b.shape())) + throw new IllegalArgumentException( + String.format( + "ValueError: operands do not have same shapes %s %s ", a.shape(), b.shape())); + boolean sameSize = a.shape().size() == b.shape().size(); + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + int nDims = a.shape().numDimensions(); + + a.elements(nDims - 1) + .forEachIndexed( + (idx, v) -> { + if (sameSize) { + result.setDouble(v.getDouble() * b.getDouble(idx), idx); + } else { + double value = v.getDouble() * b.getDouble(idx[0], 0L); + result.setDouble(value, idx); + } + }); + return result; + } + + /** + * Multiply an array with a scalar value + * + * @param a the array + * @param scalar the scalar value + * @return the resulting array from the Multiply operation + */ + public static DoubleNdArray mul(DoubleNdArray a, float scalar) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + if (a.shape().isScalar()) { + a.scalars().forEach(f -> result.setDouble(f.getDouble() * scalar)); + } else { + a.scalars().forEachIndexed((idx, f) -> result.setDouble(f.getDouble() * scalar, idx)); + } + + return result; + } + + /** + * Multiply a scalar value with an array + * + * @param scalar the scalar value + * @param a the array + * @return the resulting array from the Multiply operation + */ + public static DoubleNdArray mul(float scalar, DoubleNdArray a) { + return mul(a, scalar); + } + /** * Divide two arrays * @@ -487,7 +528,7 @@ public static FloatNdArray max(FloatNdArray a, FloatNdArray b) { a.elements(nDims - 1) .forEachIndexed( (idx, v) -> { - result.setFloat((float) Math.max(v.getFloat(), b.getFloat(idx)), idx); + 
result.setFloat(Math.max(v.getFloat(), b.getFloat(idx)), idx); }); return result; } @@ -506,7 +547,7 @@ public static FloatNdArray max(FloatNdArray a, float scalar) { a.elements(nDims - 1) .forEachIndexed( (idx, v) -> { - result.setFloat((float) Math.max(v.getFloat(), scalar), idx); + result.setFloat(Math.max(v.getFloat(), scalar), idx); }); return result; } @@ -539,7 +580,7 @@ public static FloatNdArray min(FloatNdArray a, FloatNdArray b) { a.elements(nDims - 1) .forEachIndexed( (idx, v) -> { - result.setFloat((float) Math.min(v.getFloat(), b.getFloat(idx)), idx); + result.setFloat(Math.min(v.getFloat(), b.getFloat(idx)), idx); }); return result; } @@ -558,7 +599,7 @@ public static FloatNdArray min(FloatNdArray a, float scalar) { a.elements(nDims - 1) .forEachIndexed( (idx, v) -> { - result.setFloat((float) Math.min(v.getFloat(), scalar), idx); + result.setFloat(Math.min(v.getFloat(), scalar), idx); }); return result; } @@ -583,20 +624,20 @@ public static FloatNdArray min(float scalar, FloatNdArray a) { */ public static FloatNdArray abs(FloatNdArray a) { FloatNdArray result = NdArrays.ofFloats(a.shape()); - a.scalars().forEachIndexed((idx, f) -> result.setFloat((float) Math.abs(f.getFloat()), idx)); + a.scalars().forEachIndexed((idx, f) -> result.setFloat(Math.abs(f.getFloat()), idx)); return result; } /** - * Sum all elements of an array + * Get the absolute value of each member of the array * * @param a the array - * @return an a array with one element containing the sum. + * @return the array with the absolute value of each item. 
*/ - public static FloatNdArray sum(FloatNdArray a) { - AtomicReference sum = new AtomicReference<>(0.f); - a.scalars().forEach(f -> sum.set(sum.get() + f.getFloat())); - return NdArrays.scalarOf(sum.get()); + public static DoubleNdArray abs(DoubleNdArray a) { + DoubleNdArray result = NdArrays.ofDoubles(a.shape()); + a.scalars().forEachIndexed((idx, f) -> result.setDouble(Math.abs(f.getDouble()), idx)); + return result; } /** @@ -605,9 +646,9 @@ public static FloatNdArray sum(FloatNdArray a) { * @param a the array * @return an a array with one element containing the sum. */ - public static DoubleNdArray sum(DoubleNdArray a) { - AtomicReference sum = new AtomicReference(0D); - a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble())); + public static FloatNdArray sum(FloatNdArray a) { + AtomicReference sum = new AtomicReference<>(0.f); + a.scalars().forEach(f -> sum.set(sum.get() + f.getFloat())); return NdArrays.scalarOf(sum.get()); } @@ -622,17 +663,6 @@ public static FloatNdArray sum(FloatNdArray a, int axis) { return sum(a, axis, false); } - /** - * Sum all elements of an array based on the specified axis - * - * @param a the array - * @param axis the axis to sum - * @return an a array the sum over the axis less the diemsnion - */ - public static DoubleNdArray sum(DoubleNdArray a, int axis) { - return sum(a, axis, false); - } - /** * Sum all elements of an array based on the specified axis * @@ -676,6 +706,58 @@ public static FloatNdArray sum(FloatNdArray a, int axis, boolean keepDims) { * Sum all elements of an array based on the specified axis * * @param a the array + * @param axes the axis to sum + * @param keepDims indicates whether the dimensions over the sum should be kept or not. 
+ * @return an a array the sum over the axis + */ + public static FloatNdArray sum(FloatNdArray a, Integer[] axes, boolean keepDims) { + Shape shape = a.shape(); + if (axes == null) { + FloatNdArray result = sum(a); + if (keepDims) { + float scalar = result.getFloat(0); + long[] dims = {1, 1}; + Shape bShape = Shape.of(dims); + FloatNdArray resultK = NdArrays.ofFloats(bShape); + resultK.setFloat(scalar, 0, 0); + return resultK; + } + return result; + } else if (axes.length == 1) { + return sum(a, axes[0], keepDims); + } else { + // TODO + throw new UnsupportedOperationException("Multi Axis Not implemented Yet"); + } + } + + /** + * Sum all elements of an array + * + * @param a the array + * @return an a array with one element containing the sum. + */ + public static DoubleNdArray sum(DoubleNdArray a) { + AtomicReference sum = new AtomicReference<>(0.); + a.scalars().forEach(f -> sum.set(sum.get() + f.getDouble())); + return NdArrays.scalarOf(sum.get()); + } + + /** + * Sum all elements of an array based on the specified axis + * + * @param a the array + * @param axis the axis to sum + * @return an a array the sum over the axis less the diemsnion + */ + public static DoubleNdArray sum(DoubleNdArray a, int axis) { + return sum(a, axis, false); + } + + /** + * Sum all elements of an array over on the specified axis + * + * @param a the array * @param axis the axis to sum * @param keepDims indicates whether the dimensions over the sum should be kept or not. * @return an a array the sum over the axis @@ -719,16 +801,16 @@ public static DoubleNdArray sum(DoubleNdArray a, int axis, boolean keepDims) { * @param keepDims indicates whether the dimensions over the sum should be kept or not. 
* @return an a array the sum over the axis */ - public static FloatNdArray sum(FloatNdArray a, Integer[] axes, boolean keepDims) { + public static DoubleNdArray sum(DoubleNdArray a, Integer[] axes, boolean keepDims) { Shape shape = a.shape(); if (axes == null) { - FloatNdArray result = sum(a); + DoubleNdArray result = sum(a); if (keepDims) { - float scalar = result.getFloat(0); + double scalar = result.getDouble(0); long[] dims = {1, 1}; Shape bShape = Shape.of(dims); - FloatNdArray resultK = NdArrays.ofFloats(bShape); - resultK.setFloat(scalar, 0, 0); + DoubleNdArray resultK = NdArrays.ofDoubles(bShape); + resultK.setDouble(scalar, 0, 0); return resultK; } return result;