Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FusedLinerCrossEntropy support for Phi3 #103

Merged
merged 32 commits into from
Aug 28, 2024

Conversation

tyler-romero
Copy link
Collaborator

@tyler-romero tyler-romero commented Aug 26, 2024

Summary

Add FusedLinearCrossEntropy support for Phi3. #98

Testing Done

  • Hardware Type: 4090
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@tyler-romero tyler-romero changed the title Tyler/fused ce phi3 Add FusedLinerCrossEntropy support for Phi3 Aug 26, 2024
@tyler-romero
Copy link
Collaborator Author

Need to rebase after merging #76

@ByronHsu
Copy link
Collaborator

@tyler-romero are you on discord? please say hi! https://discord.gg/nSeNms8u

@lancerts lancerts requested a review from shimizust August 26, 2024 19:29
@tyler-romero tyler-romero marked this pull request as ready for review August 27, 2024 16:52
@tyler-romero
Copy link
Collaborator Author

> make all
python -m pytest --disable-warnings test/ --ignore=test/convergence
=========================================================================================== test session starts ===========================================================================================
platform linux -- Python 3.10.13, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tromero/workspace/Liger-Kernel
plugins: devtools-0.12.2
collected 141 items                                                                                                                                                                                       

test/transformers/test_auto_model.py .                                                                                                                                                              [  0%]
test/transformers/test_cross_entropy.py ........................................................ss                                                                                                  [ 41%]
test/transformers/test_fused_linear_cross_entropy.py ......                                                                                                                                         [ 46%]
test/transformers/test_geglu.py ........                                                                                                                                                            [ 51%]
test/transformers/test_monkey_patch.py .....                                                                                                                                                        [ 55%]
test/transformers/test_rms_norm.py ................................                                                                                                                                 [ 78%]
test/transformers/test_rope.py ............                                                                                                                                                         [ 86%]
test/transformers/test_swiglu.py ................                                                                                                                                                   [ 97%]
test/transformers/test_trainer_integration.py .                                                                                                                                                     [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                                                                          [100%]

================================================================================ 139 passed, 2 skipped in 73.13s (0:01:13) ================================================================================
HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence
=========================================================================================== test session starts ===========================================================================================
platform linux -- Python 3.10.13, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tromero/workspace/Liger-Kernel
plugins: devtools-0.12.2
collected 28 items                                                                                                                                                                                        

test/convergence/test_mini_models.py ..............                                                                                                                                                 [ 50%]
test/convergence/test_mini_models_no_logits.py ..............                                                                                                                                       [100%]

===================================================================================== 28 passed in 151.42s (0:02:31) ======================================================================================
flake8 .; flake8_status=$?; \
isort .; isort_status=$?; \
black .; black_status=$?; \
if [ $flake8_status -ne 0 ] || [ $isort_status -ne 0 ] || [ $black_status -ne 0 ]; then \
        exit 1; \
fi
Skipped 2 files
All done! ✨ 🍰 ✨
58 files left unchanged.



# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
"gemma": apply_liger_kernel_to_gemma,
"gemma2": apply_liger_kernel_to_gemma2,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was breaking one of the monkeypatch tests on main

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tyler-romero what is the root cause? Is it still breaking on the current main?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes as of now its still broken on main:

> make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
=========================================================================================== test session starts ===========================================================================================
platform linux -- Python 3.10.13, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tromero/workspace/Liger-Kernel
plugins: devtools-0.12.2
collected 141 items                                                                                                                                                                                       

test/transformers/test_auto_model.py .                                                                                                                                                              [  0%]
test/transformers/test_cross_entropy.py ........................................................ss                                                                                                  [ 41%]
test/transformers/test_fused_linear_cross_entropy.py ......                                                                                                                                         [ 46%]
test/transformers/test_geglu.py ........                                                                                                                                                            [ 51%]
test/transformers/test_monkey_patch.py ....F                                                                                                                                                        [ 55%]
test/transformers/test_rms_norm.py ................................                                                                                                                                 [ 78%]
test/transformers/test_rope.py ............                                                                                                                                                         [ 86%]
test/transformers/test_swiglu.py ................                                                                                                                                                   [ 97%]
test/transformers/test_trainer_integration.py .                                                                                                                                                     [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                                                                          [100%]

================================================================================================ FAILURES =================================================================================================
__________________________________________________________________________________ test_patching_apis_match_auto_mapping __________________________________________________________________________________

    def test_patching_apis_match_auto_mapping():
        # Test that all of the patching APIs present also have a corresponding entry in the auto mapping
        patching_functions = [
            func
            for name, func in inspect.getmembers(monkey_patch, inspect.isfunction)
            if name.startswith("apply_liger_kernel_to_")
        ]
    
>       assert set(patching_functions) == set(MODEL_TYPE_TO_APPLY_LIGER_FN.values())
E       assert {<function ap...86dbbe0>, ...} == {<function ap...70e7586db7f0>}
E         
E         Extra items in the left set:
E         <function apply_liger_kernel_to_gemma2 at 0x70e7586dba30>
E         Use -v to get more diff

test/transformers/test_monkey_patch.py:95: AssertionError
========================================================================================= short test summary info =========================================================================================
FAILED test/transformers/test_monkey_patch.py::test_patching_apis_match_auto_mapping - assert {<function ap...86dbbe0>, ...} == {<function ap...70e7586db7f0>}

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test is just checking for the presence of this function in the mapping

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there were simultaneous commits merged with this test added and a new model type added. Are you able to fix the test? It's just making sure all the patching APIs are accounted for in the mapping (used with AutoModel class)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes the test is also fixed by this PR!

@tyler-romero tyler-romero requested a review from lancerts August 27, 2024 22:19
@lancerts lancerts requested a review from ByronHsu August 27, 2024 22:22
Copy link
Collaborator

@shimizust shimizust left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, thanks for the contribution!

@shimizust shimizust merged commit 54c8bc1 into linkedin:main Aug 28, 2024
1 check passed
DocShotgun added a commit to DocShotgun/axolotl that referenced this pull request Aug 28, 2024
winglian pushed a commit to DocShotgun/axolotl that referenced this pull request Aug 30, 2024
winglian added a commit to axolotl-ai-cloud/axolotl that referenced this pull request Sep 1, 2024
* Update supported models for Liger Kernel

Add Mistral LCE, Gemma LCE, Gemma 2 without LCE (softcapping is not yet implemented for Gemma in Liger Kernel LCE forward), Phi3 without LCE

* move import to their appropriate conditions

* Integrate Phi3 LCE support

linkedin/Liger-Kernel#103

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
djsaunde pushed a commit to axolotl-ai-cloud/axolotl that referenced this pull request Dec 17, 2024
* Update supported models for Liger Kernel

Add Mistral LCE, Gemma LCE, Gemma 2 without LCE (softcapping is not yet implemented for Gemma in Liger Kernel LCE forward), Phi3 without LCE

* move import to their appropriate conditions

* Integrate Phi3 LCE support

linkedin/Liger-Kernel#103

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants