-
Notifications
You must be signed in to change notification settings - Fork 202
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
create rms norm tensor at input.device instead of device 0 #21
Merged
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
qingquansong
previously approved these changes
Aug 15, 2024
lancerts
approved these changes
Aug 15, 2024
yundai424
pushed a commit
that referenced
this pull request
Aug 16, 2024
## Summary create rms norm tensor at input.device instead of device 0 <!--- ## Details This is an optional section; is there anything specific that reviewers should be aware of? ---> Found a bug (?) in rmsnorm by eyeballing through the code. We always create tensor on "cuda", which means it is always on device 0. It can cause issue for multi-gpu training because gpu 0 will need more memory than others. Not sure why we haven't seen issue for end-to-end training. However, the fix should still be safe to apply. Along the way, i discover two issues 1. With triton 3.0.0, geglu kernel breaks cc @yundai424 2. With torch 2.5.0 dev, mixtral convergence break cc @qingquansong And did two modifications 1. Disable verbose mode in pytest because the logs will be too long to paste in PR 2. Adjust pr template to hide the instruction in comment ## Testing Done <!--- This is a required section; please describe how this change was tested. ---> - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence ```bash jobuser [ ~/Liger-Kernel ]$ make test pytest --disable-warnings test/ --ignore=test/convergence =============================================================================== test session starts =============================================================================== platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0 rootdir: /home/jobuser/Liger-Kernel plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191 collected 111 items test/transformers/test_cross_entropy.py .......................................................... [ 52%] test/transformers/test_fused_linear_cross_entropy.py ...... [ 57%] test/transformers/test_geglu.py ........ [ 64%] test/transformers/test_rms_norm.py ................ [ 79%] test/transformers/test_rope.py ............ [ 90%] test/transformers/test_swiglu.py ........ [ 97%] test/transformers/test_transformers_monkey_patch.py . [ 98%] test/triton/test_triton_monkey_patch.py .. [100%] ========================================================================= 111 passed in 61.43s (0:01:01) ========================================================================== jobuser [ ~/Liger-Kernel ]$ make test-convergence HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence =============================================================================== test session starts =============================================================================== platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0 rootdir: /home/jobuser/Liger-Kernel plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191 collected 8 items test/convergence/test_mini_models.py ...... [ 75%] test/convergence/test_mini_models_no_logits.py .. [100%] ========================================================================== 8 passed in 97.62s (0:01:37) =========================================================================== jobuser [ ~/Liger-Kernel ]$ make checkstyle flake8 .; flake8_status=$?; \ isort .; isort_status=$?; \ black .; black_status=$?; \ if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \ exit 1; \ fi Skipped 1 files All done! ✨ 🍰 ✨ 42 files left unchanged. ```
3 tasks
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary
Found a bug (?) in rmsnorm by eyeballing through the code. We always create tensor on "cuda", which means it is always on device 0. It can cause issue for multi-gpu training because gpu 0 will need more memory than others. Not sure why we haven't seen issue for end-to-end training. However, the fix should still be safe to apply.
Along the way, i discover two issues
And did two modifications
Testing Done
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence