[Relax] Batch norm correctness on eval mode #17752
Conversation
cc: @MasterJH5574 this is ready for review

@tvm-bot rerun
MasterJH5574 left a comment
Looks good. Thank you @hugolatendresse for the enhancement!
########## Neural Network ##########

- def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
+ def _batch_norm(self, node: fx.Node, training) -> relax.Var:
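For context, a minimal sketch of how a single handler can serve both the eval-mode and training-mode ATen ops by binding the flag at registration time. The class, method body, and convert-map style here are illustrative assumptions, not the PR's verbatim code:

```python
from functools import partial

class _ImporterSketch:
    """Toy stand-in for the exported-program translator; illustrative only."""

    def _batch_norm(self, node, training: bool):
        # A real handler would emit relax.op.nn.batch_norm here; this toy
        # version only shows that the mode arrives via the bound flag.
        mode = "training" if training else "eval"
        print(f"lowering {node} with batch_norm in {mode} mode")

    def build_convert_map(self):
        # One handler, two ATen entry points: the flag is bound when the
        # map is built, so the ingested PyTorch program never needs to be
        # edited to get eval semantics.
        return {
            "_native_batch_norm_legit_no_training": partial(self._batch_norm, training=False),
            "batch_norm": partial(self._batch_norm, training=True),
        }
```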
Good to add a type annotation in a follow-up PR.
- def _batch_norm(self, node: fx.Node, training) -> relax.Var:
+ def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:
Got it, will do, thanks
batch_norm is a different operator in training and eval mode. The previous interface defaulted to training mode and required changing the ingested PyTorch program itself to get eval mode. This is suboptimal, especially since torch.export explicitly communicates whether batch_norm should run in training or eval mode in a given torch program.
This PR automates the selection of training/eval mode in the exported program translator and achieves correctness for eval mode.
Future TODO: something is wrong with batch_norm in training mode. It does not pass a correctness test even when taken straight from the main branch (there is an issue with tensor dimensions). I added a note to address this later, since training mode is probably not high priority.
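As a concrete illustration of the torch.export point (a standalone demo, not part of this PR; requires a recent PyTorch with torch.export): exporting a module in eval mode lowers batch norm to the no-training ATen op, matching the `_batch_norm_legit_no_training` handler name in the diff above, so the translator can recover the mode from the op itself.

```python
import torch
import torch.nn as nn

# Export a batch-norm module in eval mode and inspect the graph: the call
# appears as aten._native_batch_norm_legit_no_training, so training/eval
# is recoverable from the exported program without editing the model.
bn = nn.BatchNorm2d(4)
bn.eval()
ep = torch.export.export(bn, (torch.randn(2, 4, 8, 8),))
print(ep.graph_module.graph)
# Expect a node like:
#   torch.ops.aten._native_batch_norm_legit_no_training.default(...)
```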