-
Notifications
You must be signed in to change notification settings - Fork 331
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Common] Moved framework agnostic THD kernels to common. #1339
base: main
Are you sure you want to change the base?
[Common] Moved framework agnostic THD kernels to common. #1339
Conversation
94a75ac
to
81718ac
Compare
685ad1b
to
3ee113e
Compare
/te-ci pytorch |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks.
Please make sure PyTorch CP tests (L1 tests) passed. Thanks.
17eca08
to
6308a03
Compare
/te-ci pytorch |
/te-ci pytorch L1 |
FYI, it seems like pipeline |
Probably needs to add |
71bd3cb
to
940fd65
Compare
/te-ci pytorch L1 |
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
03b0400
to
89fcca9
Compare
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
/te-ci pytorch L1 |
@@ -70,6 +70,7 @@ jobs: | |||
run: pip install . -v | |||
env: | |||
NVTE_FRAMEWORK: jax | |||
MAX_JOBS: 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cc @timmoon10
Signed-off-by: Michael Goldfarb <mgoldfarb@nvidia.com>
/te-ci pytorch L1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks.
Description
Moves several kernels that are used for context parallelism + THD fused attention to shared location so they can be called by Jax and other frameworks.
Type of change
Changes
Moves common code out of pytorch filder.
Checklist: