From 5dd4b30a037369c6e73e8f2ef039c277d9a29ebd Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 30 Apr 2021 14:58:24 -0400 Subject: [PATCH 1/3] Setting all the optimizers to have useLocking = True, like Keras. Adding a determinism test that's currently failing. --- .../framework/optimizers/AdaDelta.java | 4 +- .../framework/optimizers/AdaGrad.java | 6 +- .../framework/optimizers/AdaGradDA.java | 4 +- .../tensorflow/framework/optimizers/Adam.java | 4 +- .../framework/optimizers/Adamax.java | 3 +- .../framework/optimizers/GradientDescent.java | 6 +- .../framework/optimizers/Momentum.java | 3 +- .../framework/optimizers/RMSProp.java | 8 +- .../optimizers/GradientDescentTest.java | 137 ++++++++++++++++++ 9 files changed, 166 insertions(+), 9 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java index aadbfeea54b..e5bab9228b4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaDelta.java @@ -20,6 +20,7 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.op.train.ApplyAdadelta; import org.tensorflow.types.family.TType; import java.util.List; @@ -160,7 +161,8 @@ protected Op applyDense(Output gradient, Output variable tf.dtypes.cast(tf.constant(learningRate), gradient.type()), tf.dtypes.cast(tf.constant(rho), gradient.type()), tf.dtypes.cast(tf.constant(epsilon), gradient.type()), - gradient); + gradient, + ApplyAdadelta.useLocking(true)); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index 2dd05ef31b3..a4e42e5cf72 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -19,6 +19,7 @@ import org.tensorflow.Operand; import org.tensorflow.Output; import org.tensorflow.op.Op; +import org.tensorflow.op.train.ApplyAdagrad; import org.tensorflow.op.core.Variable; import org.tensorflow.types.family.TType; @@ -42,6 +43,9 @@ public class AdaGrad extends Optimizer { public static final float LEARNING_RATE_DEFAULT = 0.001f; public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f; + private static final ApplyAdagrad.Options[] opts = new ApplyAdagrad.Options[]{ + ApplyAdagrad.updateSlots(true),ApplyAdagrad.useLocking(true)}; + private final float learningRate; private final float initialAccumulatorValue; @@ -140,7 +144,7 @@ private void createAdaGradSlot(Output v) { protected Op applyDense(Output gradient, Output variable) { Variable slot = getSlot(variable, ACCUMULATOR).get(); return tf.train.applyAdagrad( - variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient); + variable, slot, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient, opts); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java index 7114c33339f..64473b00f69 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java +++ 
b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGradDA.java @@ -22,6 +22,7 @@ import org.tensorflow.op.Op; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Variable; +import org.tensorflow.op.train.ApplyAdagradDa; import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; @@ -219,7 +220,8 @@ protected Op applyDense(Output gradient, Output variable tf.dtypes.cast(tf.constant(learningRate), gradient.type()), tf.dtypes.cast(tf.constant(l1Strength), gradient.type()), tf.dtypes.cast(tf.constant(l2Strength), gradient.type()), - globalStep); + globalStep, + ApplyAdagradDa.useLocking(true)); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java index 72598d12543..ce581e41397 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adam.java @@ -26,6 +26,7 @@ import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; import org.tensorflow.op.core.Variable; +import org.tensorflow.op.train.ApplyAdam; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -237,7 +238,8 @@ protected Op applyDense(Output gradient, Output variable tf.dtypes.cast(betaOneConst, gradient.type()), tf.dtypes.cast(betaTwoConst, gradient.type()), tf.dtypes.cast(epsilonConst, gradient.type()), - gradient); + gradient, + ApplyAdam.useLocking(true)); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java index 0ecc1ac1451..70b1497c2d8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Adamax.java @@ -170,7 +170,8 @@ protected Op applyDense(Output gradient, Output variable tf.dtypes.cast(betaOneConst, gradient.type()), tf.dtypes.cast(betaTwoConst, gradient.type()), tf.dtypes.cast(epsilonConst, gradient.type()), - gradient); + gradient, + ApplyAdaMax.useLocking(true)); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java index a373b2e5b55..7e2ec9593ed 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/GradientDescent.java @@ -18,6 +18,7 @@ import org.tensorflow.Graph; import org.tensorflow.Output; import org.tensorflow.op.Op; +import org.tensorflow.op.train.ApplyGradientDescent; import org.tensorflow.types.family.TType; /** @@ -66,7 +67,10 @@ public GradientDescent(Graph graph, String name, float learningRate) { @Override protected Op applyDense(Output gradient, Output variable) { return tf.train.applyGradientDescent( - variable, tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient); + variable, + tf.dtypes.cast(tf.constant(learningRate), gradient.type()), + gradient, + ApplyGradientDescent.useLocking(true)); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java index 
f6640409d60..ca53bd0c7e8 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/Momentum.java @@ -139,7 +139,8 @@ protected Op applyDense(Output gradient, Output variable tf.dtypes.cast(tf.constant(learningRate), gradient.type()), gradient, tf.dtypes.cast(tf.constant(momentum), gradient.type()), - ApplyMomentum.useNesterov(useNesterov)); + ApplyMomentum.useNesterov(useNesterov), + ApplyMomentum.useLocking(true)); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java index e86e64971a4..79ced52dc08 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/RMSProp.java @@ -20,6 +20,8 @@ import org.tensorflow.Output; import org.tensorflow.op.Op; import org.tensorflow.op.core.Variable; +import org.tensorflow.op.train.ApplyCenteredRmsProp; +import org.tensorflow.op.train.ApplyRmsProp; import org.tensorflow.types.family.TType; import java.util.List; @@ -202,7 +204,8 @@ protected Op applyDense(Output gradient, Output variable tf.dtypes.cast(tf.constant(decay), gradient.type()), tf.dtypes.cast(tf.constant(momentum), gradient.type()), tf.dtypes.cast(tf.constant(epsilon), gradient.type()), - gradient); + gradient, + ApplyCenteredRmsProp.useLocking(true)); } return tf.train.applyRmsProp( variable, @@ -212,7 +215,8 @@ protected Op applyDense(Output gradient, Output variable tf.dtypes.cast(tf.constant(decay), gradient.type()), tf.dtypes.cast(tf.constant(momentum), gradient.type()), tf.dtypes.cast(tf.constant(epsilon), gradient.type()), - gradient); + gradient, + ApplyRmsProp.useLocking(true)); } /** {@inheritDoc} */ diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index aefcc537979..8c6bb5ac668 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -2,13 +2,25 @@ import org.junit.jupiter.api.*; import org.tensorflow.Graph; +import org.tensorflow.Session; +import org.tensorflow.Tensor; +import org.tensorflow.framework.initializers.Glorot; +import org.tensorflow.framework.initializers.VarianceScaling; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; +import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Op; import org.tensorflow.op.Ops; import org.tensorflow.op.core.Assign; import org.tensorflow.op.core.Constant; +import org.tensorflow.op.core.Init; +import org.tensorflow.op.core.Placeholder; import org.tensorflow.op.core.Variable; +import org.tensorflow.op.math.Add; +import org.tensorflow.op.math.Mean; +import org.tensorflow.op.nn.Relu; +import org.tensorflow.proto.framework.ConfigProto; +import org.tensorflow.proto.framework.GraphDef; import org.tensorflow.types.TFloat32; import org.tensorflow.types.family.TType; @@ -97,4 +109,129 @@ public void testBasic() { session.evaluate(expectedVar1, var1); } } + + // This test fails due to initialization and gradient issues. It should not, but it seems to be a + // problem + // in TF-core. 
+ @Disabled + @Test + public void testDeterminism() { + ConfigProto config = + ConfigProto.newBuilder() + .setIntraOpParallelismThreads(1) + .setInterOpParallelismThreads(1) + .build(); + + GraphDef def; + String initName; + String trainName; + + String fcWeightName, fcBiasName, outputWeightName, outputBiasName; + + try (Graph g = new Graph()) { + Ops tf = Ops.create(g); + + Glorot initializer = + new Glorot<>(tf, VarianceScaling.Distribution.TRUNCATED_NORMAL, 1L); + // Inputs + Placeholder input = + tf.withName("input").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 20))); + + // Fully connected layer + Variable fcWeights = + tf.variable(initializer.call(tf.array(20L, 200L), TFloat32.class)); + fcWeightName = fcWeights.op().name(); + Variable fcBiases = tf.variable(tf.fill(tf.array(200), tf.constant(0.1f))); + fcBiasName = fcBiases.op().name(); + Relu relu = tf.nn.relu(tf.math.add(tf.linalg.matMul(input, fcWeights), fcBiases)); + + // Output layer + Variable outputWeights = + tf.variable(initializer.call(tf.array(200L, 2L), TFloat32.class)); + outputWeightName = outputWeights.op().name(); + Variable outputBiases = tf.variable(tf.fill(tf.array(2L), tf.constant(0.1f))); + outputBiasName = outputBiases.op().name(); + Add output = tf.math.add(tf.linalg.matMul(relu, outputWeights), outputBiases); + + // Loss + Placeholder placeholder = + tf.withName("output").placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 2))); + Mean loss = + tf.math.mean( + tf.nn.raw.softmaxCrossEntropyWithLogits(output, placeholder).loss(), tf.constant(0)); + + GradientDescent gd = new GradientDescent(g, 0.1f); + Op trainingOp = gd.minimize(loss); + trainName = trainingOp.op().name(); + + // Create the init op + Init init = tf.init(); + initName = init.op().name(); + + def = g.toGraphDef(); + } + + float[] data = + new float[] { + 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, -8.0f, -9.0f, 10.0f, 11.0f, 12.0f, 13.0f, + -14.0f, -15.0f, 0.16f, 0.17f, 0.18f, 1.9f, 0.2f + }; + TFloat32 dataTensor = TFloat32.tensorOf(Shape.of(1, 20), DataBuffers.of(data)); + float[] target = new float[] {0.0f, 1.0f}; + TFloat32 targetTensor = TFloat32.tensorOf(Shape.of(1, 2), DataBuffers.of(target)); + + int numRuns = 10; + List> initialized = new ArrayList<>(numRuns); + List> trained = new ArrayList<>(numRuns); + + for (int i = 0; i < numRuns; i++) { + try (Graph g = new Graph(); + Session s = new Session(g, config)) { + g.importGraphDef(def); + s.run(initName); + + initialized.add( + s.runner() + .fetch(fcWeightName) + .fetch(fcBiasName) + .fetch(outputWeightName) + .fetch(outputBiasName) + .run()); + + s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .run(); + + trained.add( + s.runner() + .fetch(fcWeightName) + .fetch(fcBiasName) + .fetch(outputWeightName) + .fetch(outputBiasName) + .run()); + } + } + + for (int i = 1; i < numRuns; i++) { + assertEquals( + initialized.get(0), + initialized.get(i), + "Variables not initialized identically (0," + i + ")"); + assertEquals( + trained.get(0), trained.get(i), "Variables not trained identically (0," + i + ")"); + } + + for (List curInit : initialized) { + for (Tensor t : curInit) { + t.close(); + } + } + for (List curTrained : trained) { + for (Tensor t : curTrained) { + t.close(); + } + } + } } From c0fc3510837a07f55adf924c82ce4c4b2c50de6d Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Fri, 30 Apr 2021 15:32:54 -0400 Subject: [PATCH 2/3] More work on the GradientDescentTest. 
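
Fetch the loss at each of the two training steps and assert that it is
identical across runs, raise the learning rate to 10.0 so that any
divergence in the computed gradients shows up clearly in the updated
weights, soften the target to {0.2, 0.8} so the loss stays non-zero,
increase the number of runs to 20, and add temporary debug printing of the
output bias (removed again in PATCH 3/3).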
--- .../optimizers/GradientDescentTest.java | 42 ++++++++++++++++--- 1 file changed, 37 insertions(+), 5 deletions(-) diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 8c6bb5ac668..86308b267c7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -7,6 +7,7 @@ import org.tensorflow.framework.initializers.Glorot; import org.tensorflow.framework.initializers.VarianceScaling; import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.ndarray.FloatNdArray; import org.tensorflow.ndarray.Shape; import org.tensorflow.ndarray.buffer.DataBuffers; import org.tensorflow.op.Op; @@ -25,8 +26,10 @@ import org.tensorflow.types.family.TType; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; /** Test cases for GradientDescent Optimizer */ @@ -125,6 +128,7 @@ public void testDeterminism() { GraphDef def; String initName; String trainName; + String lossName; String fcWeightName, fcBiasName, outputWeightName, outputBiasName; @@ -159,8 +163,9 @@ public void testDeterminism() { Mean loss = tf.math.mean( tf.nn.raw.softmaxCrossEntropyWithLogits(output, placeholder).loss(), tf.constant(0)); + lossName = loss.op().name(); - GradientDescent gd = new GradientDescent(g, 0.1f); + GradientDescent gd = new GradientDescent(g, 10.0f); Op trainingOp = gd.minimize(loss); trainName = trainingOp.op().name(); @@ -177,12 +182,14 @@ public void testDeterminism() { -14.0f, -15.0f, 0.16f, 0.17f, 0.18f, 1.9f, 0.2f }; TFloat32 dataTensor = TFloat32.tensorOf(Shape.of(1, 20), DataBuffers.of(data)); - float[] target = new float[] {0.0f, 1.0f}; + float[] target = new float[] {0.2f, 0.8f}; TFloat32 targetTensor = TFloat32.tensorOf(Shape.of(1, 2), DataBuffers.of(target)); - int numRuns = 10; + int numRuns = 20; List> initialized = new ArrayList<>(numRuns); List> trained = new ArrayList<>(numRuns); + float[] initialLoss = new float[numRuns]; + float[] postTrainingLoss = new float[numRuns]; for (int i = 0; i < numRuns; i++) { try (Graph g = new Graph(); @@ -197,12 +204,16 @@ public void testDeterminism() { .fetch(outputWeightName) .fetch(outputBiasName) .run()); + System.out.println("Initialized - " + ndArrToString((TFloat32)initialized.get(i).get(3))); - s.runner() + TFloat32 lossVal = (TFloat32) s.runner() .addTarget(trainName) .feed("input", dataTensor) .feed("output", targetTensor) - .run(); + .fetch(lossName) + .run().get(0); + initialLoss[i] = lossVal.getFloat(); + lossVal.close(); trained.add( s.runner() @@ -211,10 +222,25 @@ public void testDeterminism() { .fetch(outputWeightName) .fetch(outputBiasName) .run()); + System.out.println("Initialized - " + ndArrToString((TFloat32)initialized.get(i).get(3))); + System.out.println("Trained - " + ndArrToString((TFloat32)trained.get(i).get(3))); + + lossVal = (TFloat32) s.runner() + .addTarget(trainName) + .feed("input", dataTensor) + .feed("output", targetTensor) + .fetch(lossName) + .run().get(0); + postTrainingLoss[i] = lossVal.getFloat(); + lossVal.close(); } } for (int i = 1; i < numRuns; i++) { + assertEquals(initialLoss[0],initialLoss[i]); + assertEquals(postTrainingLoss[0],postTrainingLoss[i]); + // 
Because the weights are references not copies. + assertEquals(initialized.get(i),trained.get(i)); assertEquals( initialized.get(0), initialized.get(i), @@ -234,4 +260,10 @@ public void testDeterminism() { } } } + + private static String ndArrToString(FloatNdArray ndarray) { + StringBuffer sb = new StringBuffer(); + ndarray.scalars().forEachIndexed((idx,array) -> sb.append(Arrays.toString(idx)).append(" = ").append(array.getFloat()).append("\n")); + return sb.toString(); + } } From c11c911cb1180e3e225bd513ac8137588409f33c Mon Sep 17 00:00:00 2001 From: Adam Pocock Date: Mon, 3 May 2021 21:41:36 -0400 Subject: [PATCH 3/3] Tidying up the test. --- .../framework/optimizers/AdaGrad.java | 2 +- .../optimizers/GradientDescentTest.java | 21 ++++++------------- 2 files changed, 7 insertions(+), 16 deletions(-) diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java index a4e42e5cf72..66a170efcc2 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java @@ -44,7 +44,7 @@ public class AdaGrad extends Optimizer { public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f; private static final ApplyAdagrad.Options[] opts = new ApplyAdagrad.Options[]{ - ApplyAdagrad.updateSlots(true),ApplyAdagrad.useLocking(true)}; + ApplyAdagrad.updateSlots(true), ApplyAdagrad.useLocking(true)}; private final float learningRate; diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java index 86308b267c7..d6786b71972 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java @@ -113,9 +113,9 @@ public void testBasic() { } } - // This test fails due to initialization and gradient issues. It should not, but it seems to be a - // problem - // in TF-core. + // This test fails due to incorrect gradients being generated some of the time, when + // using an identical graph on identical data. It should not, but it seems to be a + // problem in TF-core. @Disabled @Test public void testDeterminism() { @@ -204,7 +204,6 @@ public void testDeterminism() { .fetch(outputWeightName) .fetch(outputBiasName) .run()); - System.out.println("Initialized - " + ndArrToString((TFloat32)initialized.get(i).get(3))); TFloat32 lossVal = (TFloat32) s.runner() .addTarget(trainName) @@ -222,8 +221,6 @@ public void testDeterminism() { .fetch(outputWeightName) .fetch(outputBiasName) .run()); - System.out.println("Initialized - " + ndArrToString((TFloat32)initialized.get(i).get(3))); - System.out.println("Trained - " + ndArrToString((TFloat32)trained.get(i).get(3))); lossVal = (TFloat32) s.runner() .addTarget(trainName) @@ -237,10 +234,10 @@ public void testDeterminism() { } for (int i = 1; i < numRuns; i++) { - assertEquals(initialLoss[0],initialLoss[i]); - assertEquals(postTrainingLoss[0],postTrainingLoss[i]); + assertEquals(initialLoss[0], initialLoss[i]); + assertEquals(postTrainingLoss[0], postTrainingLoss[i]); // Because the weights are references not copies. 
- assertEquals(initialized.get(i),trained.get(i)); + assertEquals(initialized.get(i), trained.get(i)); assertEquals( initialized.get(0), initialized.get(i), @@ -260,10 +257,4 @@ public void testDeterminism() { } } } - - private static String ndArrToString(FloatNdArray ndarray) { - StringBuffer sb = new StringBuffer(); - ndarray.scalars().forEachIndexed((idx,array) -> sb.append(Arrays.toString(idx)).append(" = ").append(array.getFloat()).append("\n")); - return sb.toString(); - } }
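
Note on the optimizer changes in PATCH 1/3: they are mechanical. Each
applyDense now forwards useLocking(true) through the generated Options
varargs of its raw training op, matching Keras per the commit message.
AdaGrad is the one optimizer that needs two options, so it pre-builds a
static ApplyAdagrad.Options[] (updateSlots(true) restates the op's default;
useLocking(true) is the actual change) and passes the array as the varargs
argument. A minimal sketch of the resulting call shape, using a hypothetical
variable and stand-in gradient values against the same tf.train API the
patches touch:

    import org.tensorflow.Graph;
    import org.tensorflow.op.Ops;
    import org.tensorflow.op.core.Variable;
    import org.tensorflow.op.train.ApplyGradientDescent;
    import org.tensorflow.types.TFloat32;

    /** Sketch only: names and values are illustrative, not part of the patch. */
    public class UseLockingSketch {
      public static void main(String[] args) {
        try (Graph g = new Graph()) {
          Ops tf = Ops.create(g);
          // A two-element float variable standing in for a model weight.
          Variable<TFloat32> weights = tf.variable(tf.array(1.0f, 2.0f));
          // Before this series the op used its default, useLocking = false:
          //   tf.train.applyGradientDescent(weights, tf.constant(0.01f), grad);
          // With the patches applied, every optimizer passes the option
          // explicitly, so concurrent updates of `weights` take a lock:
          tf.train.applyGradientDescent(
              weights,
              tf.constant(0.01f),    // learning rate
              tf.array(0.1f, 0.1f),  // stand-in gradient
              ApplyGradientDescent.useLocking(true));
        }
      }
    }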
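
Why the test stays @Disabled: even with ConfigProto pinning intra- and
inter-op parallelism to a single thread, and with the initialization
assertions in place to confirm that every run starts from identical weights,
the post-training weights and losses still differ between some runs of the
same GraphDef on the same data. As the final comment in the test records,
that points at nondeterministic gradient computation in TF core rather than
at anything in this repository.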