
Conversation

@hugolatendresse
Contributor

@hugolatendresse hugolatendresse commented Mar 16, 2025

Batch_norm is a different operator in training and eval mode. The previous interface defaulted to training mode and required changing the ingested PyTorch program itself to use eval mode. This is suboptimal, especially since torch.export explicitly records whether batch_norm should run in training or eval mode in a given program.

This PR automates the selection of training/eval mode in the exported program translator and achieves correctness for eval mode.

Future TODO: something is wrong with batch_norm in training mode. It does not pass a correctness test even when taken straight from the main branch (there is an issue with tensor dimensions). I added a note to address this later, as training mode is probably not a high priority.
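For context, here is a minimal sketch (assuming PyTorch 2.x; the `Net` module and tensor shapes are made-up examples, not from this PR) of how torch.export communicates the mode: exporting a model in eval mode produces the `aten._native_batch_norm_legit_no_training` op in the graph, so a translator can choose eval semantics from the op target alone, without the user editing the source program.

```python
# Sketch only: demonstrates that torch.export encodes the train/eval
# distinction for batch_norm in the op target itself (assumes PyTorch 2.x).
import torch

class Net(torch.nn.Module):  # made-up example module
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(4)

    def forward(self, x):
        return self.bn(x)

model = Net().eval()  # export in eval mode
ep = torch.export.export(model, (torch.randn(2, 4, 8, 8),))

# In eval mode the graph contains aten._native_batch_norm_legit_no_training;
# exporting in training mode yields a training variant instead. Inspecting
# the op targets shows which mode the translator should select.
for node in ep.graph.nodes:
    if node.op == "call_function":
        print(node.target)
```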

@hugolatendresse hugolatendresse changed the title [Relax] Fix batch norm ingestion [Relax] Batch norm correctness on eval mode Mar 16, 2025
@hugolatendresse hugolatendresse marked this pull request as ready for review March 24, 2025 04:29
@hugolatendresse
Contributor Author

cc: @MasterJH5574 this is ready for review

@hugolatendresse
Contributor Author

@tvm-bot rerun

Contributor

@MasterJH5574 MasterJH5574 left a comment


Looks good. Thank you @hugolatendresse for the enhancement!

########## Neural Network ##########

def _batch_norm_legit_no_training(self, node: fx.Node) -> relax.Var:
def _batch_norm(self, node: fx.Node, training) -> relax.Var:
Contributor


Good to add a type annotation in a follow-up PR.

Suggested change
def _batch_norm(self, node: fx.Node, training) -> relax.Var:
def _batch_norm(self, node: fx.Node, training: bool) -> relax.Var:

Contributor Author


Got it, will do, thanks

@MasterJH5574 MasterJH5574 merged commit 51d4b6b into apache:main Mar 26, 2025
10 checks passed
ShiboXing pushed a commit to ShiboXing/tvm that referenced this pull request Aug 10, 2025
