Add FusedLinearCrossEntropy support for Phi3 #103

Merged · merged 32 commits into main from tyler/fused-ce-phi3 on Aug 28, 2024

Commits (32)
0055dc8 Monkeypatch for Phi3 (tyler-romero, Aug 24, 2024)
859b5d5 checkstyle (tyler-romero, Aug 24, 2024)
b80b319 some cleanup (tyler-romero, Aug 24, 2024)
853b3f5 Test for LigerPhi3SwiGLUMLP (tyler-romero, Aug 24, 2024)
cb6f109 Update Readme (tyler-romero, Aug 24, 2024)
e11c2d6 Merge branch 'main' into tyler/monkeypatch-phi3 (tyler-romero, Aug 25, 2024)
ae9e060 Address PR nit (tyler-romero, Aug 25, 2024)
95b0ca2 Checkstyle (tyler-romero, Aug 25, 2024)
5baa8ee Correctly resolve test.utils dir for make test command (tyler-romero, Aug 25, 2024)
2162490 Merge branch 'main' into tyler/monkeypatch-phi3 (tyler-romero, Aug 26, 2024)
f001ff5 Bump transformers version (tyler-romero, Aug 26, 2024)
b617c77 Bump transformers version in README (tyler-romero, Aug 26, 2024)
2499b16 Add FusedLinerCrossEntropy support for Phi3 (tyler-romero, Aug 26, 2024)
546ae4c modify phi3 monkeypatch (tyler-romero, Aug 26, 2024)
43c8def Add convergence tests for phi3 and qwen2 (tyler-romero, Aug 26, 2024)
e8db468 Update README (tyler-romero, Aug 26, 2024)
53c1374 Add qwen2 to trainer integration (tyler-romero, Aug 26, 2024)
a407b0a Merge branch 'main' into tyler/fused-ce-phi3 (tyler-romero, Aug 27, 2024)
2ef2984 Merge branch 'main' into tyler/fused-ce-phi3 (tyler-romero, Aug 27, 2024)
54160df Merge branch 'main' into tyler/fused-ce-phi3 (tyler-romero, Aug 27, 2024)
fd05b62 checkstyle (tyler-romero, Aug 27, 2024)
cbf7358 Typo fix (tyler-romero, Aug 27, 2024)
c2fc797 Merge branch 'main' into tyler/fused-ce-phi3 (tyler-romero, Aug 27, 2024)
37aeb38 Merge branch 'main' into tyler/fused-ce-phi3 (tyler-romero, Aug 27, 2024)
d27770a checkstyle (tyler-romero, Aug 27, 2024)
6b4a147 Merge branch 'main' into tyler/fused-ce-phi3 (tyler-romero, Aug 27, 2024)
7ab60c4 fallback to torch native linear+CE when without label (tyler-romero, Aug 27, 2024)
8c7fd83 fallback to torch native linear+CE when without label (tyler-romero, Aug 27, 2024)
76c6a76 Merge branch 'main' into tyler/fused-ce-phi3 (tyler-romero, Aug 27, 2024)
6d6b01c Fix for broken tests on main (tyler-romero, Aug 27, 2024)
ca5e8a0 Merge branch 'main' into tyler/fused-ce-phi3 (lancerts, Aug 27, 2024)
dfdd40e Merge branch 'main' into tyler/fused-ce-phi3 (tyler-romero, Aug 27, 2024)
README.md: 2 changes (1 addition, 1 deletion)

@@ -193,7 +193,7 @@ loss.backward()
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
| Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
-| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
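As a usage sketch (not part of this diff): the patch APIs in this table can be called directly, or applied indirectly through the `AutoLigerKernelForCausalLM` wrapper, which looks up the model type in `MODEL_TYPE_TO_APPLY_LIGER_FN` (see monkey_patch.py below) and applies the matching patch before loading. The checkpoint name here is illustrative:

```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# The wrapper resolves the checkpoint's model type ("phi3") and calls
# apply_liger_kernel_to_phi3() before the weights are loaded.
# Checkpoint name is illustrative.
model = AutoLigerKernelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
```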



src/liger_kernel/transformers/model/phi3.py: 136 changes (136 additions, 0 deletions)
@@ -0,0 +1,136 @@
from typing import List, Optional, Tuple, Union

import torch
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.phi3.modeling_phi3 import (
    _CONFIG_FOR_DOC,
    PHI3_INPUTS_DOCSTRING,
)
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

from liger_kernel.transformers.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyLoss,
)


@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def lce_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Copy-paste of the Phi3 forward from transformers v4.44.2, but with torch cross entropy replaced by Liger fused linear cross entropy.

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, Phi3ForCausalLM

    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")

    >>> prompt = "This is an example script ."
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]

    loss = None
    logits = None

    if self.training and labels is not None:
        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # flatten tokens
        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
        shift_labels = shift_labels.view(-1)

        lce = LigerFusedLinearCrossEntropyLoss()
        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
    else:
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
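For intuition on the training branch above: `LigerFusedLinearCrossEntropyLoss` produces the same scalar as materializing the logits and applying standard cross entropy, but never allocates the full `(num_tokens, vocab_size)` logits tensor. A minimal unfused reference with illustrative shapes (not part of this diff):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: flattened batch*seq tokens, hidden size, vocab size.
num_tokens, hidden_size, vocab_size = 8, 64, 128
shift_hidden_states = torch.randn(num_tokens, hidden_size)
lm_head_weight = torch.randn(vocab_size, hidden_size)
shift_labels = torch.randint(0, vocab_size, (num_tokens,))

# Unfused reference: materializes a (num_tokens, vocab_size) logits tensor.
logits = shift_hidden_states @ lm_head_weight.T
reference_loss = F.cross_entropy(logits, shift_labels)

# The fused path in lce_forward computes the same value without ever
# storing the full logits tensor:
#   lce = LigerFusedLinearCrossEntropyLoss()
#   loss = lce(lm_head_weight, shift_hidden_states, shift_labels)
```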
src/liger_kernel/transformers/monkey_patch.py: 25 changes (19 additions, 6 deletions)

@@ -5,8 +5,9 @@
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.geglu import LigerGEGLUMLP
from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.llama import lce_forward
+from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -33,7 +34,7 @@ def apply_liger_kernel_to_llama(
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -55,7 +56,7 @@
    if cross_entropy:
        modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
-        modeling_llama.LlamaForCausalLM.forward = lce_forward
+        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward


def apply_liger_kernel_to_mistral(
@@ -72,7 +73,7 @@
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
        fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -206,7 +207,7 @@ def apply_liger_kernel_to_qwen2(
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -232,7 +233,8 @@ def apply_liger_kernel_to_qwen2(

def apply_liger_kernel_to_phi3(
    rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
) -> None:
@@ -242,9 +244,17 @@
    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
        fused_linear_cross_entropy (bool):
            Whether to apply Liger's fused linear cross entropy loss. Default is True.
            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
            If `fused_linear_cross_entropy` is True, the logits will not be materialized, which is more memory efficient.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
    """
    assert not (
        cross_entropy and fused_linear_cross_entropy
    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."

    from transformers.models.phi3 import modeling_phi3

    if rope:
@@ -255,11 +265,14 @@
        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
    if fused_linear_cross_entropy:
        modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward


# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
"gemma": apply_liger_kernel_to_gemma,
"gemma2": apply_liger_kernel_to_gemma2,

[Review comment] tyler-romero (Collaborator, Author): This was breaking one of the monkeypatch tests on main.

[Review comment] Collaborator: @tyler-romero what is the root cause? Is it still breaking on the current main?

[Review comment] tyler-romero (Collaborator, Author): Yes, as of now it's still broken on main:

```

> make test
python -m pytest --disable-warnings test/ --ignore=test/convergence
=========================================================================================== test session starts ===========================================================================================
platform linux -- Python 3.10.13, pytest-8.3.2, pluggy-1.5.0
rootdir: /home/tromero/workspace/Liger-Kernel
plugins: devtools-0.12.2
collected 141 items                                                                                                                                                                                       

test/transformers/test_auto_model.py .                                                                                                                                                              [  0%]
test/transformers/test_cross_entropy.py ........................................................ss                                                                                                  [ 41%]
test/transformers/test_fused_linear_cross_entropy.py ......                                                                                                                                         [ 46%]
test/transformers/test_geglu.py ........                                                                                                                                                            [ 51%]
test/transformers/test_monkey_patch.py ....F                                                                                                                                                        [ 55%]
test/transformers/test_rms_norm.py ................................                                                                                                                                 [ 78%]
test/transformers/test_rope.py ............                                                                                                                                                         [ 86%]
test/transformers/test_swiglu.py ................                                                                                                                                                   [ 97%]
test/transformers/test_trainer_integration.py .                                                                                                                                                     [ 98%]
test/triton/test_triton_monkey_patch.py ..                                                                                                                                                          [100%]

================================================================================================ FAILURES =================================================================================================
__________________________________________________________________________________ test_patching_apis_match_auto_mapping __________________________________________________________________________________

    def test_patching_apis_match_auto_mapping():
        # Test that all of the patching APIs present also have a corresponding entry in the auto mapping
        patching_functions = [
            func
            for name, func in inspect.getmembers(monkey_patch, inspect.isfunction)
            if name.startswith("apply_liger_kernel_to_")
        ]
    
>       assert set(patching_functions) == set(MODEL_TYPE_TO_APPLY_LIGER_FN.values())
E       assert {<function ap...86dbbe0>, ...} == {<function ap...70e7586db7f0>}
E         
E         Extra items in the left set:
E         <function apply_liger_kernel_to_gemma2 at 0x70e7586dba30>
E         Use -v to get more diff

test/transformers/test_monkey_patch.py:95: AssertionError
========================================================================================= short test summary info =========================================================================================
FAILED test/transformers/test_monkey_patch.py::test_patching_apis_match_auto_mapping - assert {<function ap...86dbbe0>, ...} == {<function ap...70e7586db7f0>}

```

[Review comment] tyler-romero (Collaborator, Author): The test is just checking for the presence of this function in the mapping.

[Review comment] Collaborator: I think simultaneous commits were merged, one adding this test and one adding a new model type. Are you able to fix the test? It's just making sure all the patching APIs are accounted for in the mapping (used with the AutoModel class).

[Review comment] tyler-romero (Collaborator, Author): Yes, the test is also fixed by this PR!

"llama": apply_liger_kernel_to_llama,
"mistral": apply_liger_kernel_to_mistral,
"mixtral": apply_liger_kernel_to_mixtral,
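A minimal usage sketch of the new patch function (not part of this diff): it must be called before the model is instantiated so the monkeypatched `forward` and modules take effect. This assumes `liger-kernel` and a recent `transformers` are installed, and the checkpoint name is illustrative:

```python
import torch
from transformers import AutoModelForCausalLM

from liger_kernel.transformers import apply_liger_kernel_to_phi3

# Defaults written out explicitly: the fused linear cross entropy path is on
# by default and is mutually exclusive with the plain cross entropy patch.
apply_liger_kernel_to_phi3(
    rope=True,
    cross_entropy=False,
    fused_linear_cross_entropy=True,
    rms_norm=True,
    swiglu=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-4k-instruct",  # illustrative checkpoint
    torch_dtype=torch.bfloat16,
)
```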
test/convergence/test_mini_models_no_logits.py: 28 changes (28 additions, 0 deletions)

@@ -14,13 +14,15 @@
from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mistral import MistralConfig, MistralForCausalLM
from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

from liger_kernel.transformers import (
    apply_liger_kernel_to_gemma,
    apply_liger_kernel_to_gemma2,
    apply_liger_kernel_to_llama,
    apply_liger_kernel_to_mistral,
    apply_liger_kernel_to_phi3,
    apply_liger_kernel_to_qwen2,
)

@@ -85,6 +87,30 @@
attn_implementation="sdpa", # default value, pytorch native attention
),
),
"mini_phi3": MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_phi3,
model_class=Phi3ForCausalLM,
mini_model_config=Phi3Config(
attention_dropout=0.0,
bos_token_id=1,
eos_token_id=2, # 32000
hidden_act="silu",
hidden_size=896, # 3072
initializer_range=0.02,
intermediate_size=4864, # 8192
max_position_embeddings=4096,
num_attention_heads=8, # 32
num_hidden_layers=4, # 32
num_key_value_heads=None, # defaults to num_attention_heads
rms_norm_eps=1e-5,
rope_theta=10000.0,
sliding_window=None,
tie_word_embeddings=False,
use_cache=True,
vocab_size=32064,
attn_implementation="eager",
),
),
"mini_mistral": MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_mistral,
model_class=MistralForCausalLM,
@@ -268,6 +294,8 @@ def run_mini_model(
("mini_llama3", 32, 1e-4, torch.bfloat16, 5e-3, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
# Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
test/transformers/test_monkey_patch.py: 1 change (1 addition, 0 deletions)

@@ -16,6 +16,7 @@ def test_import_from_root():
    from liger_kernel.transformers import (  # noqa: F401
        AutoLigerKernelForCausalLM,
        apply_liger_kernel_to_gemma,
        apply_liger_kernel_to_gemma2,
        apply_liger_kernel_to_llama,
        apply_liger_kernel_to_mistral,
        apply_liger_kernel_to_mixtral,