diff --git a/README.md b/README.md
index 6b7cc3921..da2d64580 100644
--- a/README.md
+++ b/README.md
@@ -193,7 +193,7 @@ loss.backward()
 | Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
 | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
 | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
-| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
+| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
diff --git a/src/liger_kernel/transformers/model/phi3.py b/src/liger_kernel/transformers/model/phi3.py
new file mode 100644
index 000000000..4cb7ec0ea
--- /dev/null
+++ b/src/liger_kernel/transformers/model/phi3.py
@@ -0,0 +1,136 @@
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch.nn import CrossEntropyLoss
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.models.phi3.modeling_phi3 import (
+    _CONFIG_FOR_DOC,
+    PHI3_INPUTS_DOCSTRING,
+)
+from transformers.utils import (
+    add_start_docstrings_to_model_forward,
+    replace_return_docstrings,
+)
+
+from liger_kernel.transformers.fused_linear_cross_entropy import (
+    LigerFusedLinearCrossEntropyLoss,
+)
+
+
+@add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING)
+@replace_return_docstrings(
+    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
+)
+def lce_forward(
+    self,
+    input_ids: torch.LongTensor = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    position_ids: Optional[torch.LongTensor] = None,
+    past_key_values: Optional[List[torch.FloatTensor]] = None,
+    inputs_embeds: Optional[torch.FloatTensor] = None,
+    labels: Optional[torch.LongTensor] = None,
+    use_cache: Optional[bool] = None,
+    output_attentions: Optional[bool] = None,
+    output_hidden_states: Optional[bool] = None,
+    return_dict: Optional[bool] = None,
+    cache_position: Optional[torch.LongTensor] = None,
+) -> Union[Tuple, CausalLMOutputWithPast]:
+    r"""
+    Copy paste phi3 forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
+
+
+    Args:
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+    Returns:
+
+    Example:
+
+    ```python
+    >>> from transformers import AutoTokenizer, Phi3ForCausalLM
+
+    >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+    >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct")
+
+    >>> prompt = "This is an example script ."
+    >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+    >>> # Generate
+    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum'
+    ```"""
+
+    output_attentions = (
+        output_attentions
+        if output_attentions is not None
+        else self.config.output_attentions
+    )
+    output_hidden_states = (
+        output_hidden_states
+        if output_hidden_states is not None
+        else self.config.output_hidden_states
+    )
+    return_dict = (
+        return_dict if return_dict is not None else self.config.use_return_dict
+    )
+
+    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+    outputs = self.model(
+        input_ids=input_ids,
+        attention_mask=attention_mask,
+        position_ids=position_ids,
+        past_key_values=past_key_values,
+        inputs_embeds=inputs_embeds,
+        use_cache=use_cache,
+        output_attentions=output_attentions,
+        output_hidden_states=output_hidden_states,
+        return_dict=return_dict,
+    )
+
+    hidden_states = outputs[0]
+
+    loss = None
+    logits = None
+
+    if self.training and labels is not None:
+        shift_hidden_states = hidden_states[..., :-1, :].contiguous()
+        shift_labels = labels[..., 1:].contiguous()
+
+        # flatten tokens
+        shift_hidden_states = shift_hidden_states.view(-1, self.config.hidden_size)
+        shift_labels = shift_labels.view(-1)
+
+        lce = LigerFusedLinearCrossEntropyLoss()
+        loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
+    else:
+        logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+    if not return_dict:
+        output = (logits,) + outputs[1:]
+        return (loss,) + output if loss is not None else output
+
+    return CausalLMOutputWithPast(
+        loss=loss,
+        logits=logits,
+        past_key_values=outputs.past_key_values,
+        hidden_states=outputs.hidden_states,
+        attentions=outputs.attentions,
+    )
diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py
index 5c6dc665f..10a30fcb5 100644
--- a/src/liger_kernel/transformers/monkey_patch.py
+++ b/src/liger_kernel/transformers/monkey_patch.py
@@ -5,8 +5,9 @@
 from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
 from liger_kernel.transformers.geglu import LigerGEGLUMLP
 from liger_kernel.transformers.model.gemma import lce_forward as gemma_lce_forward
-from liger_kernel.transformers.model.llama import lce_forward
+from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
 from liger_kernel.transformers.model.mistral import lce_forward as mistral_lce_forward
+from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
 from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
 from liger_kernel.transformers.rms_norm import LigerRMSNorm
 from liger_kernel.transformers.rope import liger_rotary_pos_emb
@@ -33,7 +34,7 @@ def apply_liger_kernel_to_llama(
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -55,7 +56,7 @@ def apply_liger_kernel_to_llama(
     if cross_entropy:
         modeling_llama.CrossEntropyLoss = LigerCrossEntropyLoss
     if fused_linear_cross_entropy:
-        modeling_llama.LlamaForCausalLM.forward = lce_forward
+        modeling_llama.LlamaForCausalLM.forward = llama_lce_forward
 
 
 def apply_liger_kernel_to_mistral(
@@ -72,7 +73,7 @@ def apply_liger_kernel_to_mistral(
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
         fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -206,7 +207,7 @@ def apply_liger_kernel_to_qwen2(
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
-            Whether to apply Liger's fused lienar cross entropy loss. Default is True.
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
@@ -232,7 +233,8 @@ def apply_liger_kernel_to_qwen2(
 
 def apply_liger_kernel_to_phi3(
     rope: bool = True,
-    cross_entropy: bool = True,
+    cross_entropy: bool = False,
+    fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
     swiglu: bool = True,
 ) -> None:
@@ -242,9 +244,17 @@ def apply_liger_kernel_to_phi3(
     Args:
         rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
+        fused_linear_cross_entropy (bool):
+            Whether to apply Liger's fused linear cross entropy loss. Default is True.
+            `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
+            If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
         swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
     """
+    assert not (
+        cross_entropy and fused_linear_cross_entropy
+    ), "cross_entropy and fused_linear_cross_entropy cannot both be True."
+
     from transformers.models.phi3 import modeling_phi3
 
     if rope:
@@ -255,11 +265,14 @@ def apply_liger_kernel_to_phi3(
         modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
     if cross_entropy:
         modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
+    if fused_linear_cross_entropy:
+        modeling_phi3.Phi3ForCausalLM.forward = phi3_lce_forward
 
 
 # Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
 MODEL_TYPE_TO_APPLY_LIGER_FN = {
     "gemma": apply_liger_kernel_to_gemma,
+    "gemma2": apply_liger_kernel_to_gemma2,
     "llama": apply_liger_kernel_to_llama,
     "mistral": apply_liger_kernel_to_mistral,
     "mixtral": apply_liger_kernel_to_mixtral,
diff --git a/test/convergence/test_mini_models_no_logits.py b/test/convergence/test_mini_models_no_logits.py
index 90d512a9d..964ba9a33 100644
--- a/test/convergence/test_mini_models_no_logits.py
+++ b/test/convergence/test_mini_models_no_logits.py
@@ -14,6 +14,7 @@
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
 from transformers.models.llama import LlamaConfig, LlamaForCausalLM
 from transformers.models.mistral import MistralConfig, MistralForCausalLM
+from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM
 from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM
 
 from liger_kernel.transformers import (
@@ -21,6 +22,7 @@
     apply_liger_kernel_to_gemma2,
     apply_liger_kernel_to_llama,
     apply_liger_kernel_to_mistral,
+    apply_liger_kernel_to_phi3,
     apply_liger_kernel_to_qwen2,
 )
 
@@ -85,6 +87,30 @@
             attn_implementation="sdpa",  # default value, pytorch native attention
         ),
     ),
+    "mini_phi3": MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_phi3,
+        model_class=Phi3ForCausalLM,
+        mini_model_config=Phi3Config(
+            attention_dropout=0.0,
+            bos_token_id=1,
+            eos_token_id=2,  # 32000
+            hidden_act="silu",
+            hidden_size=896,  # 3072
+            initializer_range=0.02,
+            intermediate_size=4864,  # 8192
+            max_position_embeddings=4096,
+            num_attention_heads=8,  # 32
+            num_hidden_layers=4,  # 32
+            num_key_value_heads=None,  # defaults to num_attention_heads
+            rms_norm_eps=1e-5,
+            rope_theta=10000.0,
+            sliding_window=None,
+            tie_word_embeddings=False,
+            use_cache=True,
+            vocab_size=32064,
+            attn_implementation="eager",
+        ),
+    ),
     "mini_mistral": MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_mistral,
         model_class=MistralForCausalLM,
@@ -268,6 +294,8 @@ def run_mini_model(
     ("mini_llama3", 32, 1e-4, torch.bfloat16, 5e-3, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
     ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
     ("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
+    ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
+    ("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
     ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
     ("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
     # Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match (casts are not done the same way)
diff --git a/test/transformers/test_monkey_patch.py b/test/transformers/test_monkey_patch.py
index a105b0881..0db1d258a 100644
--- a/test/transformers/test_monkey_patch.py
+++ b/test/transformers/test_monkey_patch.py
@@ -16,6 +16,7 @@ def test_import_from_root():
     from liger_kernel.transformers import (  # noqa: F401
         AutoLigerKernelForCausalLM,
         apply_liger_kernel_to_gemma,
+        apply_liger_kernel_to_gemma2,
         apply_liger_kernel_to_llama,
         apply_liger_kernel_to_mistral,
         apply_liger_kernel_to_mixtral,
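For reference, a minimal usage sketch of the new Phi3 patch path (illustration only, not part of the diff; the checkpoint name, `AutoModelForCausalLM` loader, and dtype are assumptions for the example). It follows the usual Liger pattern of calling the patch function before instantiating the model so the RoPE, RMSNorm, SwiGLU, and fused linear cross entropy replacements take effect:

```python
import torch
import transformers

from liger_kernel.transformers import apply_liger_kernel_to_phi3

# Patch the HF Phi3 module before the model is created; the defaults enable
# RoPE, RMSNorm, SwiGLU, and the fused linear cross entropy loss, so logits
# are not materialized during training.
apply_liger_kernel_to_phi3()

model = transformers.AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-3-mini-4k-instruct",
    torch_dtype=torch.bfloat16,
)
```

Passing `cross_entropy=True` while leaving `fused_linear_cross_entropy=True` (the default) now raises an `AssertionError`, per the guard added in `apply_liger_kernel_to_phi3`.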