Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ONNX] Add training mode support for BatchNormalization op #3597

Merged
merged 2 commits into from
Aug 14, 2024

Conversation

vivekkhandelwal1
Copy link
Collaborator

This commit extends the OnnxToTorch lowering for BatchNormalization op for supporting the case when training=True.

Signed-Off By: Vivek Khandelwal vivekkhandelwal1424@gmail.com

@vivekkhandelwal1
Copy link
Collaborator Author

vivekkhandelwal1 commented Aug 6, 2024

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This does not pass e2e testing for the following onnx node tests due to a significant numerics mismatch:

"test_batchnorm_epsilon_training_mode"
"test_batchnorm_example_training_mode"

Please try to debug the numerics mismatching when you get the chance. You can either use this branch of torch-mlir in your iree-build and run the iree_tests, or you can use the alt_e2eshark and run python run.py --torchtolinalg -t test_batch.

@zjgarvey
Copy link
Collaborator

zjgarvey commented Aug 6, 2024

From a glance, it looks like the numerics for Y and running_var are off, but running_mean seems to match pretty well.

This commit extends the OnnxToTorch lowering for BatchNormalization op
for supporting the case when training=True.

Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
@vivekkhandelwal1
Copy link
Collaborator Author

This does not pass e2e testing for the following onnx node tests due to a significant numerics mismatch:

"test_batchnorm_epsilon_training_mode" "test_batchnorm_example_training_mode"

Please try to debug the numerics mismatching when you get the chance. You can either use this branch of torch-mlir in your iree-build and run the iree_tests, or you can use the alt_e2eshark and run python run.py --torchtolinalg -t test_batch.

Hi @zjgarvey, the accuracy issue is fixed. After a lot of debugging, it turned out that the unbiased has to be set to False, instead of True.

@vivekkhandelwal1
Copy link
Collaborator Author

@zjgarvey Can you please review this PR, today?

Copy link
Collaborator

@zjgarvey zjgarvey left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As long as this is passing numerics, the implementation is clear and well-commented on. Thanks, Vivek.

@vivekkhandelwal1 vivekkhandelwal1 merged commit 4a0bed0 into llvm:main Aug 14, 2024
3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants