Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow BatchNorm on CUDA with track_stats=False #2427

Merged
merged 1 commit into from
Apr 29, 2024

Conversation

paulnovo
Copy link
Contributor

Enables BatchNorm without track_stats in training and test modes. Also, unit tests are added to ensure the CUDA implementation matches the CPU implementation.

This address #1606 but requires changes in NNlib PR 576.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@paulnovo
Copy link
Contributor Author

The "buildkite/flux-dot-jl" tests appear to be failing because NNlib 0.9.13 was installed. 0.9.14 includes my related fixes to NNlib for track_stats=False. Is there a preferred way forward here? Should I increase the compat version of NNlib from 0.9.1 to 0.9.14?

@ToucheSir
Copy link
Member

Yes, please do and then this should be ready if tests pass.

Enables BatchNorm without track_stats in training and test modes. Also,
unit tests are added to ensure the CUDA implementation matches the CPU
implementation.

Also, update required NNlib to 0.9.14 which includes fixes to batcnorm
when track_stats=False.
Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@ToucheSir ToucheSir merged commit 27bdab3 into FluxML:master Apr 29, 2024
5 of 8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants