Skip to content

Commit c11c911

Browse files
committed
Tidying up the test.
1 parent c0fc351 commit c11c911

File tree

2 files changed

+7
-16
lines changed

2 files changed

+7
-16
lines changed

tensorflow-framework/src/main/java/org/tensorflow/framework/optimizers/AdaGrad.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public class AdaGrad extends Optimizer {
4444
public static final float INITIAL_ACCUMULATOR_DEFAULT = 0.01f;
4545

4646
private static final ApplyAdagrad.Options[] opts = new ApplyAdagrad.Options[]{
47-
ApplyAdagrad.updateSlots(true),ApplyAdagrad.useLocking(true)};
47+
ApplyAdagrad.updateSlots(true), ApplyAdagrad.useLocking(true)};
4848

4949
private final float learningRate;
5050

tensorflow-framework/src/test/java/org/tensorflow/framework/optimizers/GradientDescentTest.java

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,9 +113,9 @@ public void testBasic() {
113113
}
114114
}
115115

116-
// This test fails due to initialization and gradient issues. It should not, but it seems to be a
117-
// problem
118-
// in TF-core.
116+
// This test fails due to incorrect gradients being generated some of the time, when
117+
// using an identical graph on identical data. It should not, but it seems to be a
118+
// problem in TF-core.
119119
@Disabled
120120
@Test
121121
public void testDeterminism() {
@@ -204,7 +204,6 @@ public void testDeterminism() {
204204
.fetch(outputWeightName)
205205
.fetch(outputBiasName)
206206
.run());
207-
System.out.println("Initialized - " + ndArrToString((TFloat32)initialized.get(i).get(3)));
208207

209208
TFloat32 lossVal = (TFloat32) s.runner()
210209
.addTarget(trainName)
@@ -222,8 +221,6 @@ public void testDeterminism() {
222221
.fetch(outputWeightName)
223222
.fetch(outputBiasName)
224223
.run());
225-
System.out.println("Initialized - " + ndArrToString((TFloat32)initialized.get(i).get(3)));
226-
System.out.println("Trained - " + ndArrToString((TFloat32)trained.get(i).get(3)));
227224

228225
lossVal = (TFloat32) s.runner()
229226
.addTarget(trainName)
@@ -237,10 +234,10 @@ public void testDeterminism() {
237234
}
238235

239236
for (int i = 1; i < numRuns; i++) {
240-
assertEquals(initialLoss[0],initialLoss[i]);
241-
assertEquals(postTrainingLoss[0],postTrainingLoss[i]);
237+
assertEquals(initialLoss[0], initialLoss[i]);
238+
assertEquals(postTrainingLoss[0], postTrainingLoss[i]);
242239
// Because the weights are references not copies.
243-
assertEquals(initialized.get(i),trained.get(i));
240+
assertEquals(initialized.get(i), trained.get(i));
244241
assertEquals(
245242
initialized.get(0),
246243
initialized.get(i),
@@ -260,10 +257,4 @@ public void testDeterminism() {
260257
}
261258
}
262259
}
263-
264-
private static String ndArrToString(FloatNdArray ndarray) {
265-
StringBuffer sb = new StringBuffer();
266-
ndarray.scalars().forEachIndexed((idx,array) -> sb.append(Arrays.toString(idx)).append(" = ").append(array.getFloat()).append("\n"));
267-
return sb.toString();
268-
}
269260
}

0 commit comments

Comments
 (0)