@@ -113,9 +113,9 @@ public void testBasic() {
113
113
}
114
114
}
115
115
116
- // This test fails due to initialization and gradient issues. It should not, but it seems to be a
117
- // problem
118
- // in TF-core.
116
+ // This test fails due to incorrect gradients being generated some of the time, when
117
+ // using an identical graph on identical data. It should not, but it seems to be a
118
+ // problem in TF-core.
119
119
@ Disabled
120
120
@ Test
121
121
public void testDeterminism () {
@@ -204,7 +204,6 @@ public void testDeterminism() {
204
204
.fetch (outputWeightName )
205
205
.fetch (outputBiasName )
206
206
.run ());
207
- System .out .println ("Initialized - " + ndArrToString ((TFloat32 )initialized .get (i ).get (3 )));
208
207
209
208
TFloat32 lossVal = (TFloat32 ) s .runner ()
210
209
.addTarget (trainName )
@@ -222,8 +221,6 @@ public void testDeterminism() {
222
221
.fetch (outputWeightName )
223
222
.fetch (outputBiasName )
224
223
.run ());
225
- System .out .println ("Initialized - " + ndArrToString ((TFloat32 )initialized .get (i ).get (3 )));
226
- System .out .println ("Trained - " + ndArrToString ((TFloat32 )trained .get (i ).get (3 )));
227
224
228
225
lossVal = (TFloat32 ) s .runner ()
229
226
.addTarget (trainName )
@@ -237,10 +234,10 @@ public void testDeterminism() {
237
234
}
238
235
239
236
for (int i = 1 ; i < numRuns ; i ++) {
240
- assertEquals (initialLoss [0 ],initialLoss [i ]);
241
- assertEquals (postTrainingLoss [0 ],postTrainingLoss [i ]);
237
+ assertEquals (initialLoss [0 ], initialLoss [i ]);
238
+ assertEquals (postTrainingLoss [0 ], postTrainingLoss [i ]);
242
239
// Because the weights are references not copies.
243
- assertEquals (initialized .get (i ),trained .get (i ));
240
+ assertEquals (initialized .get (i ), trained .get (i ));
244
241
assertEquals (
245
242
initialized .get (0 ),
246
243
initialized .get (i ),
@@ -260,10 +257,4 @@ public void testDeterminism() {
260
257
}
261
258
}
262
259
}
263
-
264
- private static String ndArrToString (FloatNdArray ndarray ) {
265
- StringBuffer sb = new StringBuffer ();
266
- ndarray .scalars ().forEachIndexed ((idx ,array ) -> sb .append (Arrays .toString (idx )).append (" = " ).append (array .getFloat ()).append ("\n " ));
267
- return sb .toString ();
268
- }
269
260
}
0 commit comments