Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[common][pyTorch]Add zero_centered_gamma option to RMSNorm #631

Merged
merged 12 commits into from
Feb 3, 2024

Conversation

ptrendx
Copy link
Member

@ptrendx ptrendx commented Jan 26, 2024

Fixes #577.

Changes:

  • [common] Change the kernel to support zero_centered_gamma option
  • [pyTorch] Removed restrictions around zero_centered_gamma usage
  • [pyTorch] Added tests (also added test for LayerNorm as for some reason there wasn't one in test_numerics.py) and tightened tolerances

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested review from timmoon10 and ksivaman January 26, 2024 00:33
@ptrendx
Copy link
Member Author

ptrendx commented Jan 26, 2024

/te-ci pytorch

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx
Copy link
Member Author

ptrendx commented Jan 26, 2024

/te-ci pytorch

1 similar comment
@ptrendx
Copy link
Member Author

ptrendx commented Jan 26, 2024

/te-ci pytorch

Copy link
Collaborator

@timmoon10 timmoon10 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM aside from some pedantic suggestions.

ksivaman and others added 4 commits January 30, 2024 10:29
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Copy link
Member

/te-ci pytorch

ptrendx and others added 2 commits February 1, 2024 12:26
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ksivaman
Copy link
Member

ksivaman commented Feb 3, 2024

/te-ci pytorch

@ksivaman ksivaman merged commit d68028c into NVIDIA:main Feb 3, 2024
9 of 10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Does transformer_engine.pytorch.RMSNorm support zero_centered_gamma?
3 participants