
create rms norm tensor at input.device instead of device 0 #21

Merged: 4 commits, Aug 15, 2024

Conversation

@ByronHsu (Collaborator) commented Aug 15, 2024

## Summary

Found a bug (?) in rmsnorm by eyeballing the code. We always create the tensor on "cuda", which means it always lands on device 0. This can cause issues for multi-GPU training because GPU 0 will need more memory than the others. Not sure why we haven't seen issues in end-to-end training, but the fix should still be safe to apply.
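For context, the fix amounts to allocating the output on `X.device` instead of a hard-coded `"cuda"`. A minimal pure-PyTorch sketch of the pattern (the real code is a Triton kernel wrapper; `rms_norm_ref` and its shapes here are illustrative, not the actual Liger-Kernel API):

```python
import torch

def rms_norm_ref(X: torch.Tensor, W: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Allocate the output on the *input's* device rather than a hard-coded
    # "cuda" (which always resolves to cuda:0). On a multi-GPU node this
    # keeps each rank's buffers on its own GPU instead of piling onto GPU 0.
    Y = torch.empty_like(X, device=X.device)  # the fix: was device="cuda"
    rms = torch.sqrt(torch.mean(X * X, dim=-1, keepdim=True) + eps)
    Y.copy_((X / rms) * W)
    return Y
```

With the hard-coded string, `torch.empty` on a rank pinned to `cuda:1` would still allocate on `cuda:0`; with `X.device` the output follows the input wherever it lives, including CPU.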

Along the way, I discovered two issues:

  1. With Triton 3.0.0, the geglu kernel breaks cc @yundai424
  2. With torch 2.5.0 dev, Mixtral convergence breaks cc @qingquansong

And made two modifications:

  1. Disabled verbose mode in pytest because the logs would be too long to paste in the PR
  2. Adjusted the PR template to hide the instructions in a comment

## Testing Done

  • run `make test` to ensure correctness
  • run `make checkstyle` to ensure code style
  • run `make test-convergence` to ensure convergence
```bash
jobuser [ ~/Liger-Kernel ]$ make test
pytest --disable-warnings test/ --ignore=test/convergence
=============================== test session starts ===============================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 111 items

test/transformers/test_cross_entropy.py ..........................................................  [ 52%]
test/transformers/test_fused_linear_cross_entropy.py ......                                         [ 57%]
test/transformers/test_geglu.py ........                                                            [ 64%]
test/transformers/test_rms_norm.py ................                                                 [ 79%]
test/transformers/test_rope.py ............                                                         [ 90%]
test/transformers/test_swiglu.py ........                                                           [ 97%]
test/transformers/test_transformers_monkey_patch.py .                                               [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                          [100%]

========================= 111 passed in 61.43s (0:01:01) ==========================
jobuser [ ~/Liger-Kernel ]$ make test-convergence
HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence
=============================== test session starts ===============================
platform linux -- Python 3.10.14, pytest-7.1.2, pluggy-1.0.0
rootdir: /home/jobuser/Liger-Kernel
plugins: lipy-config-base-30.6.1, lipy-fabric-35.2.3, lipy-test-8.0.52, datadir-1.3.1, lipy-mp-34.4.191
collected 8 items

test/convergence/test_mini_models.py ......                                                         [ 75%]
test/convergence/test_mini_models_no_logits.py ..                                                   [100%]

========================== 8 passed in 97.62s (0:01:37) ===========================
jobuser [ ~/Liger-Kernel ]$ make checkstyle
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
Skipped 1 files
All done! ✨ 🍰 ✨
42 files left unchanged.
```

qingquansong previously approved these changes Aug 15, 2024
@lancerts lancerts merged commit fc8e95d into main Aug 15, 2024
@ByronHsu ByronHsu deleted the byhsu/rms-fix branch August 15, 2024 20:38
yundai424 pushed a commit that referenced this pull request Aug 16, 2024