[ONNX] Add training mode support for BatchNormalization op #3597
Conversation
This does not pass e2e testing for the following ONNX node tests due to a significant numerics mismatch:
"test_batchnorm_epsilon_training_mode"
"test_batchnorm_example_training_mode"
Please try to debug the numerics mismatch when you get the chance. You can either use this branch of torch-mlir in your iree-build and run the iree_tests, or you can use alt_e2eshark and run `python run.py --torchtolinalg -t test_batch`.
From a glance, it looks like the numerics for …
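When debugging a mismatch like this, a plain NumPy oracle for the op can help isolate whether the lowering or the test expectation is off. The sketch below is my reading of the ONNX BatchNormalization spec in training mode (`training_mode=1`): normalize with the batch's population statistics, then blend the incoming running stats with the batch stats using `momentum`. The function name and layout assumptions (NCHW, channel axis 1) are illustrative, not part of torch-mlir.

```python
import numpy as np

def batchnorm_training_ref(x, scale, bias, input_mean, input_var,
                           epsilon=1e-5, momentum=0.9):
    """NumPy reference for ONNX BatchNormalization with training_mode=1."""
    # Reduce over every axis except the channel axis (axis 1 for NCHW).
    axes = tuple(i for i in range(x.ndim) if i != 1)
    current_mean = x.mean(axis=axes)
    # Population (biased) variance: divide by N, not N - 1.
    current_var = x.var(axis=axes)

    # Broadcast per-channel stats back over the input shape.
    shape = [1, -1] + [1] * (x.ndim - 2)
    y = (x - current_mean.reshape(shape)) / np.sqrt(
        current_var.reshape(shape) + epsilon)
    y = y * scale.reshape(shape) + bias.reshape(shape)

    # Running stats blend the incoming values with the batch stats.
    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)
    return y, running_mean, running_var
```

Comparing this oracle's three outputs against the lowered module on the failing `test_batchnorm_*_training_mode` inputs should show which output (Y, running_mean, or running_var) carries the mismatch.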
This commit extends the OnnxToTorch lowering for the BatchNormalization op to support the case when training=True. Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
Force-pushed from 3d1a486 to c030a0b
Hi @zjgarvey, the accuracy issue is fixed. After a lot of debugging, it turned out that the …
Force-pushed from c030a0b to 6ed8a18
@zjgarvey Can you please review this PR today?
As long as this is passing numerics, the implementation is clear and well-commented. Thanks, Vivek.
This commit extends the OnnxToTorch lowering for the BatchNormalization op to support the case when training=True.
Signed-off-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>