Add FusedLinearCrossEntropy support for Phi3 #103
Need to rebase after merging #76
@tyler-romero are you on discord? please say hi! https://discord.gg/nSeNms8u
# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
    "gemma": apply_liger_kernel_to_gemma,
    "gemma2": apply_liger_kernel_to_gemma2,
This was breaking one of the monkeypatch tests on main
@tyler-romero what is the root cause? Is it still breaking on the current main?
Yes, as of now it's still broken on main:
> make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
=========================================================================================== test session starts ===========================================================================================
platform linux -- Python 3.10.13, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tromero/workspace/Liger-Kernel
plugins: devtools-0.12.2
collected 141 items
test/transformers/test_auto_model.py . [ 0%]
test/transformers/test_cross_entropy.py ........................................................ss [ 41%]
test/transformers/test_fused_linear_cross_entropy.py ...... [ 46%]
test/transformers/test_geglu.py ........ [ 51%]
test/transformers/test_monkey_patch.py ....F [ 55%]
test/transformers/test_rms_norm.py ................................ [ 78%]
test/transformers/test_rope.py ............ [ 86%]
test/transformers/test_swiglu.py ................ [ 97%]
test/transformers/test_trainer_integration.py . [ 98%]
test/triton/test_triton_monkey_patch.py .. [100%]
================================================================================================ FAILURES =================================================================================================
__________________________________________________________________________________ test_patching_apis_match_auto_mapping __________________________________________________________________________________
    def test_patching_apis_match_auto_mapping():
        # Test that all of the patching APIs present also have a corresponding entry in the auto mapping
        patching_functions = [
            func
            for name, func in inspect.getmembers(monkey_patch, inspect.isfunction)
            if name.startswith("apply_liger_kernel_to_")
        ]
>       assert set(patching_functions) == set(MODEL_TYPE_TO_APPLY_LIGER_FN.values())
E assert {<function ap...86dbbe0>, ...} == {<function ap...70e7586db7f0>}
E
E Extra items in the left set:
E <function apply_liger_kernel_to_gemma2 at 0x70e7586dba30>
E Use -v to get more diff
test/transformers/test_monkey_patch.py:95: AssertionError
========================================================================================= short test summary info =========================================================================================
FAILED test/transformers/test_monkey_patch.py::test_patching_apis_match_auto_mapping - assert {<function ap...86dbbe0>, ...} == {<function ap...70e7586db7f0>}
The test is just checking for the presence of this function in the mapping
I think two commits were merged at about the same time: one added this test and the other added a new model type. Are you able to fix the test? It's just making sure all the patching APIs are accounted for in the mapping (used with the AutoModel class).
Yes, the test is also fixed by this PR!
lgtm, thanks for the contribution!
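For readers following the thread, here is a minimal, self-contained sketch of the check that `test_patching_apis_match_auto_mapping` performs. The `monkey_patch` module below is a hypothetical stand-in built with empty placeholder functions, and `unmapped_patching_functions` is an illustrative helper that is not part of Liger-Kernel. The real test asserts set equality, while this sketch uses a set difference so the missing entry is named.

```python
import inspect
import types


# Placeholder patching functions. These are stand-ins for illustration only;
# the real functions in Liger-Kernel patch Hugging Face model classes.
def apply_liger_kernel_to_gemma():
    pass


def apply_liger_kernel_to_gemma2():
    pass


def apply_liger_kernel_to_phi3():
    pass


# Hypothetical stand-in for the monkey_patch module inspected by the test.
monkey_patch = types.ModuleType("monkey_patch")
for _fn in (
    apply_liger_kernel_to_gemma,
    apply_liger_kernel_to_gemma2,
    apply_liger_kernel_to_phi3,
):
    setattr(monkey_patch, _fn.__name__, _fn)


def unmapped_patching_functions(module, mapping):
    """Return names of apply_liger_kernel_to_* functions missing from mapping."""
    patching_functions = {
        func
        for name, func in inspect.getmembers(module, inspect.isfunction)
        if name.startswith("apply_liger_kernel_to_")
    }
    missing = patching_functions - set(mapping.values())
    return sorted(func.__name__ for func in missing)


# Mapping without a "gemma2" entry, mirroring the state that failed on main.
stale_mapping = {
    "gemma": apply_liger_kernel_to_gemma,
    "phi3": apply_liger_kernel_to_phi3,
}
# Mapping with the "gemma2" entry added, as in the diff under review.
fixed_mapping = {**stale_mapping, "gemma2": apply_liger_kernel_to_gemma2}

missing_before = unmapped_patching_functions(monkey_patch, stale_mapping)
missing_after = unmapped_patching_functions(monkey_patch, fixed_mapping)
print(missing_before)
print(missing_after)
```

With the stale mapping the helper reports the `gemma2` patching function, matching the "Extra items in the left set" line in the log above; with the added entry nothing is reported.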
* Update supported models for Liger Kernel
  Add Mistral LCE, Gemma LCE, Gemma 2 without LCE (softcapping is not yet implemented for Gemma in Liger Kernel LCE forward), Phi3 without LCE
* move import to their appropriate conditions
* Integrate Phi3 LCE support linkedin/Liger-Kernel#103
---------
Co-authored-by: Wing Lian <wing.lian@gmail.com>
Summary
Add FusedLinearCrossEntropy support for Phi3. #98
Testing Done
make test to ensure correctness
make checkstyle to ensure code style
make test-convergence to ensure convergence