Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Monkeypatch for Phi3 #76

Merged
merged 13 commits into from
Aug 27, 2024
1 change: 0 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ all: test test-convergence checkstyle
# Command to run pytest for correctness tests
test:
python -m pytest --disable-warnings test/ --ignore=test/convergence
tyler-romero marked this conversation as resolved.
Show resolved Hide resolved


# Command to run flake8 (code style check), isort (import ordering), and black (code formatting)
# Subsequent commands still run if the previous fails, but return failure at the end
Expand Down
26 changes: 14 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@



[![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly)
[![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn)
[![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly)
[![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn)

<img src="./docs/images/logo-banner.png">

Expand Down Expand Up @@ -33,8 +33,8 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |

> **Note:**
> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.

## Examples

Expand Down Expand Up @@ -72,15 +72,15 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and

- `torch >= 2.1.2`
- `triton >= 2.3.0`
- `transformers >= 4.40.1`
- `transformers >= 4.41.0`

> **Note:**
> Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).

To install the stable version:

```bash
$ pip install liger-kernel
$ pip install liger-kernel
```

To install the nightly version:
Expand Down Expand Up @@ -109,7 +109,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_llama
model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")

# Adding this line automatically monkey-patches the model with the optimized Liger kernels
apply_liger_kernel_to_llama()
apply_liger_kernel_to_llama()
```

### 2. Compose Your Own Model
Expand Down Expand Up @@ -161,6 +161,8 @@ loss.backward()
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
| Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the whitespace modifications here, this is the only change in this file

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks a lot for helping on the housekeeping! <3

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to implement for FusedLinearCrossEntropy in the next PR and also create a issue to track? Essentially add a method patch

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah happy to do that as well

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#98




### Kernels
Expand All @@ -175,11 +177,11 @@ loss.backward()
| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|

- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
, is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
Expand All @@ -188,7 +190,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$


<!-- TODO: be more specific about batch size -->
> **Note:**
> **Note:**
> Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.

## Note on ML Compiler
Expand All @@ -202,7 +204,7 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil
| Torch Compile | 3780 | 66.4 |
| Torch Compile + Liger Kernel | 3702 | 31.0 |

> **Note:**
> **Note:**
> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> 2. Tested on torch `2.5.0.dev20240731+cu118`

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
install_requires=[
"torch>=2.1.2",
"triton>=2.3.0",
"transformers>=4.40.1",
"transformers>=4.41.0",
],
extras_require={
"dev": [
Expand Down
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@
apply_liger_kernel_to_llama,
apply_liger_kernel_to_mistral,
apply_liger_kernel_to_mixtral,
apply_liger_kernel_to_phi3,
apply_liger_kernel_to_qwen2,
)
33 changes: 32 additions & 1 deletion src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
from liger_kernel.transformers.swiglu import (
LigerBlockSparseTop2MLP,
LigerPhi3SwiGLUMLP,
LigerSwiGLUMLP,
)


def apply_liger_kernel_to_llama(
Expand Down Expand Up @@ -181,3 +185,30 @@ def apply_liger_kernel_to_qwen2(
modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
if swiglu:
modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP


def apply_liger_kernel_to_phi3(
rope: bool = True,
cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.

Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
"""
from transformers.models.phi3 import modeling_phi3

if rope:
modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma
if rms_norm:
modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama
if swiglu:
modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
if cross_entropy:
modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
24 changes: 24 additions & 0 deletions src/liger_kernel/transformers/swiglu.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,27 @@ def __init__(self, config):
def forward(self, x):

return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))


class LigerPhi3SwiGLUMLP(nn.Module):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So Llama has its own MLP implementation here, but its named very generally. I went for a model-specific name here, but open to suggestions.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah. this is necessary

"""
Patch Phi3MLP to use LigerSiLUMulFunction
https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241
"""

def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_up_proj = nn.Linear(
self.hidden_size, 2 * self.intermediate_size, bias=False
)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_act not in ["silu", "swish"]:
raise ValueError(f"Activation function {config.hidden_act} not supported.")

def forward(self, x):
up_states = self.gate_up_proj(x)
gate, up_states = up_states.chunk(2, dim=-1)
return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
2 changes: 2 additions & 0 deletions src/liger_kernel/transformers/trainer_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
apply_liger_kernel_to_llama,
apply_liger_kernel_to_mistral,
apply_liger_kernel_to_mixtral,
apply_liger_kernel_to_phi3,
)

logger = logging.getLogger(__name__)
Expand All @@ -15,6 +16,7 @@
"llama": apply_liger_kernel_to_llama,
"mistral": apply_liger_kernel_to_mistral,
"mixtral": apply_liger_kernel_to_mixtral,
"phi3": apply_liger_kernel_to_phi3,
}


Expand Down
28 changes: 28 additions & 0 deletions test/convergence/test_mini_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mistral import MistralConfig, MistralForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

from liger_kernel.transformers import (
apply_liger_kernel_to_gemma,
apply_liger_kernel_to_llama,
apply_liger_kernel_to_mistral,
apply_liger_kernel_to_mixtral,
apply_liger_kernel_to_phi3,
apply_liger_kernel_to_qwen2,
)

Expand Down Expand Up @@ -176,6 +178,30 @@
attn_implementation="sdpa", # default value, pytorch native attention
),
),
"mini_phi3": MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_phi3,
model_class=Phi3ForCausalLM,
mini_model_config=Phi3Config(
attention_dropout=0.0,
bos_token_id=1,
eos_token_id=2, # 32000
hidden_act="silu",
hidden_size=896, # 3072
initializer_range=0.02,
intermediate_size=4864, # 8192
max_position_embeddings=4096,
num_attention_heads=8, # 32
num_hidden_layers=4, # 32
num_key_value_heads=None, # defaults to num_attention_heads
rms_norm_eps=1e-5,
rope_theta=10000.0,
sliding_window=None,
tie_word_embeddings=False,
use_cache=True,
vocab_size=32064,
attn_implementation="eager",
),
),
}


Expand Down Expand Up @@ -253,6 +279,8 @@ def run_mini_model(
("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
],
)
def test_mini_model(
Expand Down
Loading
Loading