Monkeypatch for Phi3 (#76)
## Summary
Add a new monkeypatch function to support patching Hugging Face's Phi3
implementation with Liger Kernels.

Phi3 has its own MLP implementation (`Phi3MLP`) so a
`LigerPhi3SwiGLUMLP` implementation that leverages
`LigerSiLUMulFunction` is provided as well.


## Testing Done
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

Convergence test added (and passing on my 4090) for a mini model based on
Phi3 patched with Liger kernels.
All tests passing.

## Questions for Discussion
Apparently Phi3 was only added in [transformers
v4.41](https://github.com/huggingface/transformers/releases/tag/v4.41.0),
but the lowest supported version of transformers in Liger-Kernel is
4.40.1. Additionally, only more recently has [`sdpa` been supported in
HF Phi3](huggingface/transformers#32457).
Thoughts? Should I leave the transformers dependency version as-is?
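
One way to reason about the version question is a small guard that compares a `transformers` version string against the first release that ships Phi3 (v4.41.0, per the release linked above). The sketch below is illustrative only: `parse_version`, `supports_phi3`, and `PHI3_MIN_VERSION` are hypothetical names, not part of Liger-Kernel or transformers, and the parser assumes plain dotted numeric versions (no `dev` or `rc` suffixes).

```python
# Hypothetical helpers for illustration; not part of Liger-Kernel or transformers.
PHI3_MIN_VERSION = (4, 41, 0)  # first transformers release with Phi3, per the link above


def parse_version(version: str) -> tuple:
    """Parse a plain dotted version such as '4.40.1' into a 3-tuple of ints."""
    parts = [int(part) for part in version.split(".")[:3]]
    parts += [0] * (3 - len(parts))  # pad '4.41' to (4, 41, 0)
    return tuple(parts)


def supports_phi3(transformers_version: str) -> bool:
    """Return True if the given transformers version is new enough for Phi3."""
    return parse_version(transformers_version) >= PHI3_MIN_VERSION


if __name__ == "__main__":
    for candidate in ("4.40.1", "4.41.0", "4.44.2"):
        print(candidate, supports_phi3(candidate))
```

A guard like this would let the Phi3 patch fail with a clear message on older installs, as an alternative to raising the minimum dependency for every user.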
tyler-romero authored Aug 27, 2024
1 parent a272dad commit ee2dacb
Showing 11 changed files with 204 additions and 43 deletions.
1 change: 0 additions & 1 deletion Makefile
@@ -6,7 +6,6 @@ all: test test-convergence checkstyle
# Command to run pytest for correctness tests
test:
python -m pytest --disable-warnings test/ --ignore=test/convergence


# Command to run flake8 (code style check), isort (import ordering), and black (code formatting)
# Subsequent commands still run if the previous fails, but return failure at the end
26 changes: 14 additions & 12 deletions README.md
@@ -2,8 +2,8 @@



[![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly)
[![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn)

<img src="./docs/images/logo-banner.png">

@@ -33,8 +33,8 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and
| ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) |

> **Note:**
> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.
## Examples

@@ -72,15 +72,15 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and

- `torch >= 2.1.2`
- `triton >= 2.3.0`
- `transformers >= 4.40.1`
- `transformers >= 4.41.0`

> **Note:**
> Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).
To install the stable version:

```bash
$ pip install liger-kernel
```

To install the nightly version:
@@ -109,7 +109,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_llama
model = transformers.AutoModelForCausalLM.from_pretrained("<some llama model>")

# Adding this line automatically monkey-patches the model with the optimized Liger kernels
apply_liger_kernel_to_llama()
```

### 2. Compose Your Own Model
@@ -161,6 +161,8 @@ loss.backward()
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss |
| Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |



### Kernels
@@ -175,11 +177,11 @@ loss.backward()
| FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|

- **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction.
- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction.
- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
$$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$
, is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction.
- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by
$$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$
, is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used.
- **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.).
@@ -188,7 +190,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$


<!-- TODO: be more specific about batch size -->
> **Note:**
> Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder.
## Note on ML Compiler
@@ -202,7 +204,7 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil
| Torch Compile | 3780 | 66.4 |
| Torch Compile + Liger Kernel | 3702 | 31.0 |

> **Note:**
> 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> 2. Tested on torch `2.5.0.dev20240731+cu118`
2 changes: 1 addition & 1 deletion setup.py
@@ -30,7 +30,7 @@
install_requires=[
"torch>=2.1.2",
"triton>=2.3.0",
"transformers>=4.40.1",
"transformers>=4.41.0",
],
extras_require={
"dev": [
1 change: 1 addition & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -3,5 +3,6 @@
    apply_liger_kernel_to_llama,
    apply_liger_kernel_to_mistral,
    apply_liger_kernel_to_mixtral,
    apply_liger_kernel_to_phi3,
    apply_liger_kernel_to_qwen2,
)
33 changes: 32 additions & 1 deletion src/liger_kernel/transformers/monkey_patch.py
@@ -7,7 +7,11 @@
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP
from liger_kernel.transformers.swiglu import (
    LigerBlockSparseTop2MLP,
    LigerPhi3SwiGLUMLP,
    LigerSwiGLUMLP,
)


def apply_liger_kernel_to_llama(
@@ -181,3 +185,30 @@ def apply_liger_kernel_to_qwen2(
        modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward
    if swiglu:
        modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP


def apply_liger_kernel_to_phi3(
    rope: bool = True,
    cross_entropy: bool = True,
    rms_norm: bool = True,
    swiglu: bool = True,
) -> None:
    """
    Apply Liger kernels to replace original implementation in HuggingFace Phi3 models.
    Args:
        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
        cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is True.
        rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True.
    """
    from transformers.models.phi3 import modeling_phi3

    if rope:
        modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb  # Same as Gemma
    if rms_norm:
        modeling_phi3.Phi3RMSNorm = LigerRMSNorm  # Same as Llama
    if swiglu:
        modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP
    if cross_entropy:
        modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss
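
The patch function above works by rebinding names on the modeling module, so only lookups that happen after the rebinding see the replacement. The following is a minimal, self-contained sketch of that mechanism using stand-ins: `fake_modeling`, `OriginalMLP`, `PatchedMLP`, `build_layer`, and `apply_patch` are all hypothetical and only imitate the pattern; they are not transformers or Liger-Kernel APIs.

```python
import types


class OriginalMLP:
    """Stand-in for a stock implementation such as Phi3MLP."""


class PatchedMLP:
    """Stand-in for a replacement such as LigerPhi3SwiGLUMLP."""


# Hypothetical stand-in for a modeling module like modeling_phi3.
fake_modeling = types.SimpleNamespace(MLP=OriginalMLP)


def build_layer():
    # The class is looked up on the module each time this runs,
    # so it picks up whatever is bound to the name at call time.
    return fake_modeling.MLP()


def apply_patch():
    # Mirrors the pattern `modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP`.
    fake_modeling.MLP = PatchedMLP


before = build_layer()  # built before the patch
apply_patch()
after = build_layer()   # built after the patch

print(type(before).__name__)  # OriginalMLP
print(type(after).__name__)   # PatchedMLP
```

In this sketch, objects built before the patch keep their original class, which suggests calling a patch function of this kind before the model is instantiated.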
24 changes: 24 additions & 0 deletions src/liger_kernel/transformers/swiglu.py
@@ -38,3 +38,27 @@ def __init__(self, config):
    def forward(self, x):

        return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x)))


class LigerPhi3SwiGLUMLP(nn.Module):
    """
    Patch Phi3MLP to use LigerSiLUMulFunction
    https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_up_proj = nn.Linear(
            self.hidden_size, 2 * self.intermediate_size, bias=False
        )
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        if config.hidden_act not in ["silu", "swish"]:
            raise ValueError(f"Activation function {config.hidden_act} not supported.")

    def forward(self, x):
        up_states = self.gate_up_proj(x)
        gate, up_states = up_states.chunk(2, dim=-1)
        return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states))
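
To make the data flow in `LigerPhi3SwiGLUMLP.forward` concrete, here is a rough pure-Python reference for what happens between the two linear layers: the fused `gate_up_proj` output is split into a gate half and an up half, and the result is `silu(gate) * up`. This is only a numerical sketch on plain lists; `silu`, `split_gate_up`, and `silu_mul` are hypothetical helper names, and the real `LigerSiLUMulFunction` is a Triton kernel operating on tensors.

```python
import math


def silu(value: float) -> float:
    """SiLU (swish) activation: x * sigmoid(x)."""
    return value / (1.0 + math.exp(-value))


def split_gate_up(fused: list) -> tuple:
    """Split a fused projection output into its gate half and up half,
    analogous to `chunk(2, dim=-1)` on the last dimension."""
    half = len(fused) // 2
    return fused[:half], fused[half:]


def silu_mul(gate: list, up: list) -> list:
    """Elementwise silu(gate) * up on two equal-length lists."""
    return [silu(g) * u for g, u in zip(gate, up)]


# Pretend this is one row of gate_up_proj output with intermediate_size = 2.
fused_output = [0.0, 1.0, 2.0, 3.0]
gate, up = split_gate_up(fused_output)
activated = silu_mul(gate, up)
print(gate, up, activated)
```

Because Phi3 stores the gate and up projections in a single weight matrix, the split has to happen before the activation, which is why the stock `LigerSwiGLUMLP` (with separate `gate_proj` and `up_proj`) could not be reused directly.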
2 changes: 2 additions & 0 deletions src/liger_kernel/transformers/trainer_integration.py
@@ -5,6 +5,7 @@
    apply_liger_kernel_to_llama,
    apply_liger_kernel_to_mistral,
    apply_liger_kernel_to_mixtral,
    apply_liger_kernel_to_phi3,
)

logger = logging.getLogger(__name__)
@@ -15,6 +16,7 @@
    "llama": apply_liger_kernel_to_llama,
    "mistral": apply_liger_kernel_to_mistral,
    "mixtral": apply_liger_kernel_to_mixtral,
    "phi3": apply_liger_kernel_to_phi3,
}


28 changes: 28 additions & 0 deletions test/convergence/test_mini_models.py
@@ -15,13 +15,15 @@
from transformers.models.llama import LlamaConfig, LlamaForCausalLM
from transformers.models.mistral import MistralConfig, MistralForCausalLM
from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM
from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM
from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM

from liger_kernel.transformers import (
    apply_liger_kernel_to_gemma,
    apply_liger_kernel_to_llama,
    apply_liger_kernel_to_mistral,
    apply_liger_kernel_to_mixtral,
    apply_liger_kernel_to_phi3,
    apply_liger_kernel_to_qwen2,
)

@@ -176,6 +178,30 @@
            attn_implementation="sdpa",  # default value, pytorch native attention
        ),
    ),
    "mini_phi3": MiniModelConfig(
        liger_kernel_patch_func=apply_liger_kernel_to_phi3,
        model_class=Phi3ForCausalLM,
        mini_model_config=Phi3Config(
            attention_dropout=0.0,
            bos_token_id=1,
            eos_token_id=2,  # 32000
            hidden_act="silu",
            hidden_size=896,  # 3072
            initializer_range=0.02,
            intermediate_size=4864,  # 8192
            max_position_embeddings=4096,
            num_attention_heads=8,  # 32
            num_hidden_layers=4,  # 32
            num_key_value_heads=None,  # defaults to num_attention_heads
            rms_norm_eps=1e-5,
            rope_theta=10000.0,
            sliding_window=None,
            tie_word_embeddings=False,
            use_cache=True,
            vocab_size=32064,
            attn_implementation="eager",
        ),
    ),
}


@@ -253,6 +279,8 @@ def run_mini_model(
        ("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
        ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
        ("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
        ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
        ("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5),
    ],
)
def test_mini_model(