nfeybesse
Contributor

No description provided.

@Craigacp
Collaborator

Craigacp commented Mar 4, 2024

Can you add a test that triggers this bug to the cross-entropy tests (https://github.com/tensorflow/java/blob/master/tensorflow-framework/src/test/java/org/tensorflow/framework/losses/CategoricalCrossentropyTest.java)? I think this used to work, so I worry that the breakage came from a TF upgrade and our tests didn't catch it.

@nfeybesse
Contributor Author

No, the problem is older than that; it is probably the dynamic batch size that triggers it. I will try to write a test case.

@nfeybesse
Contributor Author

@Test
public void testCategoricalCrossEntropyWithDynamicBatchSize() {
  try (Graph graph = new Graph()) {
    Ops tf = Ops.create(graph);
    // The batch dimension is left unknown (-1), so the batch size is only fixed at run time.
    Operand<TFloat32> yPred = tf.placeholder(TFloat32.class, Placeholder.shape(Shape.of(-1, 3)));
    Operand<TFloat32> yTrue =
        tf.reshape(tf.constant(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}), tf.array(3, 3));
    CategoricalCrossentropy instance = new CategoricalCrossentropy(true);
    // Without the fix, this call throws TFInvalidArgumentException.
    Operand<TFloat32> loss = instance.call(tf, yTrue, yPred);
    try (Session session = new Session(graph);
        TFloat32 input =
            TFloat32.tensorOf(
                Shape.of(3, 3), DataBuffers.of(new float[] {1f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 1f}));
        TFloat32 result =
            (TFloat32) session.runner().feed(yPred, input).fetch(loss).run().get(0)) {
      if (Math.abs(0.5514477f - result.getFloat()) > 0.01) {
        throw new IllegalStateException("Invalid result: " + result.getFloat());
      }
    }
  }
}
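
For reference, the expected value in the test above can be checked by hand. The following is only a minimal sketch in plain Java with no TensorFlow code: it assumes that the `true` constructor argument means the predictions are treated as logits, and that the default reduction averages the per-row losses over the batch (neither is stated in this thread). The class and method names are hypothetical.

```java
public class SoftmaxCrossEntropySketch {

  // Cross entropy of one row: -sum_i labels[i] * log(softmax(logits)[i]).
  // The log-softmax is computed via log-sum-exp, shifted by the row maximum for stability.
  static double rowLoss(double[] labels, double[] logits) {
    double max = Double.NEGATIVE_INFINITY;
    for (double z : logits) {
      max = Math.max(max, z);
    }
    double sumExp = 0.0;
    for (double z : logits) {
      sumExp += Math.exp(z - max);
    }
    double logSumExp = max + Math.log(sumExp);
    double loss = 0.0;
    for (int i = 0; i < logits.length; i++) {
      loss -= labels[i] * (logits[i] - logSumExp);
    }
    return loss;
  }

  // Assumed reduction: mean of the per-row losses. The batch size is read from the
  // data at run time, mirroring the unknown (-1) batch dimension in the test.
  static double meanLoss(double[][] labels, double[][] logits) {
    double total = 0.0;
    for (int r = 0; r < logits.length; r++) {
      total += rowLoss(labels[r], logits[r]);
    }
    return total / logits.length;
  }

  public static void main(String[] args) {
    // Same data as the test: a 3x3 identity matrix used as both labels and logits.
    double[][] eye = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}};
    System.out.println(meanLoss(eye, eye));
  }
}
```

Under these assumptions every row contributes log(e + 2) - 1, roughly 0.5514, which lies within the 0.01 tolerance the test allows.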

@nfeybesse
Contributor Author

I am confused, because I have the impression that no test so far feeds a model with batches of dynamic size. I know from experience that this largely works, but a few small bugs still have to be tracked down. How should I integrate my test so that it fits in?

@Craigacp
Collaborator

Craigacp commented Mar 6, 2024

Add it next to the other tests for that loss. If there are more issues then let's fix them.

More of the framework was in flight a couple of years ago, but we didn't get all of it merged, so I assume that some of those things were tested in the original codebase before it was broken up into smaller PRs.

@Craigacp Craigacp merged commit 2b6d83f into tensorflow:master Mar 8, 2024
@Craigacp
Collaborator

Craigacp commented Mar 8, 2024

Thanks
