From 0055dc8c2e4a8244424f41fea93e57835425292e Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Sat, 24 Aug 2024 15:05:53 -0700 Subject: [PATCH 01/10] Monkeypatch for Phi3 --- setup.py | 1 + src/liger_kernel/transformers/__init__.py | 1 + src/liger_kernel/transformers/monkey_patch.py | 33 ++++++++++++- src/liger_kernel/transformers/swiglu.py | 19 ++++++++ .../transformers/trainer_integration.py | 2 + test/convergence/test_mini_models.py | 48 +++++++++++++++---- test/transformers/test_trainer_integration.py | 6 ++- .../test_transformers_monkey_patch.py | 1 + 8 files changed, 99 insertions(+), 12 deletions(-) diff --git a/setup.py b/setup.py index 86eb42174..79d3316b7 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ "isort>=5.13.2", "pre-commit>=3.7.1", "torch-tb-profiler>=0.4.1", + "pytest>=8.3.2" ] }, ) diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 4c6148d4f..0d43bdbc5 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -4,4 +4,5 @@ apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_phi3, ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 31ea7f3c5..9fec84bf7 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -4,7 +4,7 @@ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb -from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerSwiGLUMLP +from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerPhi3SwiGLUMLP, LigerSwiGLUMLP def apply_liger_kernel_to_llama( @@ -167,3 +167,34 @@ def apply_liger_kernel_to_qwen2( modeling_qwen2.Qwen2ForCausalLM.forward = qwen2_lce_forward if swiglu: modeling_qwen2.Qwen2MLP = LigerSwiGLUMLP + + +def apply_liger_kernel_to_phi3( + rope: bool = True, + cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = True, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace Phi3 models. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused lienar cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True. + """ + from transformers.models.phi3 import modeling_phi3 + + if rope: + modeling_phi3.apply_rotary_pos_emb = liger_rotary_pos_emb # Same as Gemma + if rms_norm: + modeling_phi3.Phi3RMSNorm = LigerRMSNorm # Same as Llama + if swiglu: + modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP + if cross_entropy: + modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss \ No newline at end of file diff --git a/src/liger_kernel/transformers/swiglu.py b/src/liger_kernel/transformers/swiglu.py index ebf1f0c03..83e3515bb 100644 --- a/src/liger_kernel/transformers/swiglu.py +++ b/src/liger_kernel/transformers/swiglu.py @@ -38,3 +38,22 @@ def __init__(self, config): def forward(self, x): return self.w2(LigerSiLUMulFunction.apply(self.w1(x), self.w3(x))) + + +class LigerPhi3SwiGLUMLP(nn.Module): + """Patch the Phi3MLP to use Liger's LigerSiLUMulFunction""" + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + if config.hidden_act not in ["silu", "swish"]: + raise ValueError(f"Activation function {config.hidden_act} not supported.") + + def forward(self, x): + up_states = self.gate_up_proj(x) + gate, up_states = up_states.chunk(2, dim=-1) + return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states)) \ No newline at end of file diff --git a/src/liger_kernel/transformers/trainer_integration.py b/src/liger_kernel/transformers/trainer_integration.py index 4caf03173..b943404d5 100644 --- a/src/liger_kernel/transformers/trainer_integration.py +++ b/src/liger_kernel/transformers/trainer_integration.py @@ -5,6 +5,7 @@ apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, + apply_liger_kernel_to_phi3, ) logger = logging.getLogger(__name__) @@ -15,6 +16,7 @@ "llama": apply_liger_kernel_to_llama, "mistral": apply_liger_kernel_to_mistral, "mixtral": apply_liger_kernel_to_mixtral, + "phi3": apply_liger_kernel_to_phi3, } diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index ae96a5648..a0e9f1b3e 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -15,12 +15,14 @@ from transformers.models.mistral import MistralConfig, MistralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM +from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM from liger_kernel.transformers import ( apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_phi3, ) MINI_MODEL_SETUPS = { @@ -144,6 +146,30 @@ attn_implementation="sdpa", # default value, pytorch native attention ), ), + "mini_phi3": MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_phi3, + model_class=Phi3ForCausalLM, + mini_model_config=Phi3Config( + attention_dropout=0.0, + bos_token_id=1, + eos_token_id=2, # 32000 + hidden_act="silu", + hidden_size=896, # 3072 + initializer_range=0.02, + intermediate_size=4864, # 8192 + max_position_embeddings=4096, + num_attention_heads=8, # 32 + num_hidden_layers=4, # 32 + num_key_value_heads=None, # defaults to num_attention_heads + rms_norm_eps=1e-5, + rope_theta=10000.0, + sliding_window=None, + tie_word_embeddings=False, + use_cache=True, + vocab_size=32064, + attn_implementation="eager", + ), + ), } @@ -201,16 +227,18 @@ def run_mini_model( @pytest.mark.parametrize( "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", [ - ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 1e-4, 1e-5, 2e-3, 1e-5), - ("mini_llama3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), - # TODO: torch 2.5.0 nightly breaks mixtral test, but torch 2.3.0 works fine - # TODO: mixtral MoE structure makes the convergence flaky so disable the test for now. It needs high tol to pass. - # ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5), - # ("mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 2.0, 1e-5, 1e-2, 1e-5), - ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - ("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), - ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - ("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), + # ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 1e-4, 1e-5, 2e-3, 1e-5), + # ("mini_llama3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), + # # TODO: torch 2.5.0 nightly breaks mixtral test, but torch 2.3.0 works fine + # # TODO: mixtral MoE structure makes the convergence flaky so disable the test for now. It needs high tol to pass. + # # ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5), + # # ("mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 2.0, 1e-5, 1e-2, 1e-5), + # ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + # ("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), + # ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + # ("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), + ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), ], ) def test_mini_model( diff --git a/test/transformers/test_trainer_integration.py b/test/transformers/test_trainer_integration.py index a76f8f4a8..6056e5723 100644 --- a/test/transformers/test_trainer_integration.py +++ b/test/transformers/test_trainer_integration.py @@ -20,15 +20,19 @@ def test_apply_liger_kernel_only_supported_model_type_called(): mock_gemma = Mock() mock_llama = Mock() mock_mistral = Mock() + mock_mixtral = Mock() + mock_phi3 = Mock() with patch.dict( MODEL_TYPE_TO_APPLY_LIGER_FN, - {"gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral}, + {"gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral, "mixtral": mock_mixtral, "phi3": mock_phi3}, ): _apply_liger_kernel("llama") mock_llama.assert_called_once() mock_gemma.assert_not_called() mock_mistral.assert_not_called() + mock_mixtral.assert_not_called() + mock_phi3.assert_not_called() def test_apply_liger_kernel_passes_kwargs(): diff --git a/test/transformers/test_transformers_monkey_patch.py b/test/transformers/test_transformers_monkey_patch.py index 2af747da2..3c4d9666f 100644 --- a/test/transformers/test_transformers_monkey_patch.py +++ b/test/transformers/test_transformers_monkey_patch.py @@ -9,6 +9,7 @@ def test_import_from_root(): apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, apply_liger_kernel_to_qwen2, + apply_liger_kernel_to_phi3, ) except Exception: pytest.fail("Import kernel patch from root fails") From 859b5d5dbfeba764cd013bf7cb06e3b168b651c5 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Sat, 24 Aug 2024 15:10:39 -0700 Subject: [PATCH 02/10] checkstyle --- setup.py | 2 +- src/liger_kernel/transformers/__init__.py | 2 +- src/liger_kernel/transformers/monkey_patch.py | 8 ++++++-- src/liger_kernel/transformers/swiglu.py | 6 ++++-- test/convergence/test_mini_models.py | 4 ++-- test/transformers/test_trainer_integration.py | 8 +++++++- test/transformers/test_transformers_monkey_patch.py | 2 +- 7 files changed, 22 insertions(+), 10 deletions(-) diff --git a/setup.py b/setup.py index 79d3316b7..5aff2ae64 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ "isort>=5.13.2", "pre-commit>=3.7.1", "torch-tb-profiler>=0.4.1", - "pytest>=8.3.2" + "pytest>=8.3.2", ] }, ) diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 0d43bdbc5..836823acb 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -3,6 +3,6 @@ apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, - apply_liger_kernel_to_qwen2, apply_liger_kernel_to_phi3, + apply_liger_kernel_to_qwen2, ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 9fec84bf7..21893d1f8 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -4,7 +4,11 @@ from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward from liger_kernel.transformers.rms_norm import LigerRMSNorm from liger_kernel.transformers.rope import liger_rotary_pos_emb -from liger_kernel.transformers.swiglu import LigerBlockSparseTop2MLP, LigerPhi3SwiGLUMLP, LigerSwiGLUMLP +from liger_kernel.transformers.swiglu import ( + LigerBlockSparseTop2MLP, + LigerPhi3SwiGLUMLP, + LigerSwiGLUMLP, +) def apply_liger_kernel_to_llama( @@ -197,4 +201,4 @@ def apply_liger_kernel_to_phi3( if swiglu: modeling_phi3.Phi3MLP = LigerPhi3SwiGLUMLP if cross_entropy: - modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss \ No newline at end of file + modeling_phi3.CrossEntropyLoss = LigerCrossEntropyLoss diff --git a/src/liger_kernel/transformers/swiglu.py b/src/liger_kernel/transformers/swiglu.py index 83e3515bb..e0495a943 100644 --- a/src/liger_kernel/transformers/swiglu.py +++ b/src/liger_kernel/transformers/swiglu.py @@ -48,7 +48,9 @@ def __init__(self, config): self.config = config self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size - self.gate_up_proj = nn.Linear(self.hidden_size, 2 * self.intermediate_size, bias=False) + self.gate_up_proj = nn.Linear( + self.hidden_size, 2 * self.intermediate_size, bias=False + ) self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) if config.hidden_act not in ["silu", "swish"]: raise ValueError(f"Activation function {config.hidden_act} not supported.") @@ -56,4 +58,4 @@ def __init__(self, config): def forward(self, x): up_states = self.gate_up_proj(x) gate, up_states = up_states.chunk(2, dim=-1) - return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states)) \ No newline at end of file + return self.down_proj(LigerSiLUMulFunction.apply(gate, up_states)) diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index a0e9f1b3e..a9f3a9607 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -14,15 +14,15 @@ from transformers.models.llama import LlamaConfig, LlamaForCausalLM from transformers.models.mistral import MistralConfig, MistralForCausalLM from transformers.models.mixtral import MixtralConfig, MixtralForCausalLM -from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from transformers.models.phi3 import Phi3Config, Phi3ForCausalLM +from transformers.models.qwen2 import Qwen2Config, Qwen2ForCausalLM from liger_kernel.transformers import ( apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, - apply_liger_kernel_to_qwen2, apply_liger_kernel_to_phi3, + apply_liger_kernel_to_qwen2, ) MINI_MODEL_SETUPS = { diff --git a/test/transformers/test_trainer_integration.py b/test/transformers/test_trainer_integration.py index 6056e5723..397cc0dbd 100644 --- a/test/transformers/test_trainer_integration.py +++ b/test/transformers/test_trainer_integration.py @@ -25,7 +25,13 @@ def test_apply_liger_kernel_only_supported_model_type_called(): with patch.dict( MODEL_TYPE_TO_APPLY_LIGER_FN, - {"gemma": mock_gemma, "llama": mock_llama, "mistral": mock_mistral, "mixtral": mock_mixtral, "phi3": mock_phi3}, + { + "gemma": mock_gemma, + "llama": mock_llama, + "mistral": mock_mistral, + "mixtral": mock_mixtral, + "phi3": mock_phi3, + }, ): _apply_liger_kernel("llama") mock_llama.assert_called_once() diff --git a/test/transformers/test_transformers_monkey_patch.py b/test/transformers/test_transformers_monkey_patch.py index 3c4d9666f..9443d56c1 100644 --- a/test/transformers/test_transformers_monkey_patch.py +++ b/test/transformers/test_transformers_monkey_patch.py @@ -8,8 +8,8 @@ def test_import_from_root(): apply_liger_kernel_to_llama, apply_liger_kernel_to_mistral, apply_liger_kernel_to_mixtral, - apply_liger_kernel_to_qwen2, apply_liger_kernel_to_phi3, + apply_liger_kernel_to_qwen2, ) except Exception: pytest.fail("Import kernel patch from root fails") From b80b319f458b9030c80fc9966cce5695704c4549 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Sat, 24 Aug 2024 15:12:36 -0700 Subject: [PATCH 03/10] some cleanup --- src/liger_kernel/transformers/monkey_patch.py | 4 ---- test/convergence/test_mini_models.py | 20 +++++++++---------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 21893d1f8..0335dc54f 100644 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -185,10 +185,6 @@ def apply_liger_kernel_to_phi3( Args: rope (bool): Whether to apply Liger's rotary position embedding. Default is True. cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. - fused_linear_cross_entropy (bool): - Whether to apply Liger's fused lienar cross entropy loss. Default is True. - `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. - If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. swiglu (bool): Whether to apply Liger's SwiGLU Phi3MLP. Default is True. """ diff --git a/test/convergence/test_mini_models.py b/test/convergence/test_mini_models.py index a9f3a9607..b4289417e 100644 --- a/test/convergence/test_mini_models.py +++ b/test/convergence/test_mini_models.py @@ -227,16 +227,16 @@ def run_mini_model( @pytest.mark.parametrize( "model_name, num_steps, lr, dtype, loss_atol, loss_rtol, logits_atol, logits_rtol, param_atol, param_rtol", [ - # ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 1e-4, 1e-5, 2e-3, 1e-5), - # ("mini_llama3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), - # # TODO: torch 2.5.0 nightly breaks mixtral test, but torch 2.3.0 works fine - # # TODO: mixtral MoE structure makes the convergence flaky so disable the test for now. It needs high tol to pass. - # # ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5), - # # ("mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 2.0, 1e-5, 1e-2, 1e-5), - # ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - # ("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), - # ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), - # ("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), + ("mini_llama3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 1e-4, 1e-5, 2e-3, 1e-5), + ("mini_llama3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), + # TODO: torch 2.5.0 nightly breaks mixtral test, but torch 2.3.0 works fine + # TODO: mixtral MoE structure makes the convergence flaky so disable the test for now. It needs high tol to pass. + # ("mini_mixtral", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 8e-3, 1e-5), + # ("mini_mixtral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 2.0, 1e-5, 1e-2, 1e-5), + ("mini_mistral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_mistral", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), + ("mini_qwen2", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), + ("mini_qwen2", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), ("mini_phi3", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5), ("mini_phi3", 32, 1e-4, torch.bfloat16, 1e-8, 1e-5, 1e-1, 1e-5, 1e-2, 1e-5), ], From 853b3f50744f7ab5c585e6551f6ce2a7ba189375 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Sat, 24 Aug 2024 15:27:31 -0700 Subject: [PATCH 04/10] Test for LigerPhi3SwiGLUMLP --- src/liger_kernel/transformers/swiglu.py | 5 +- test/transformers/test_swiglu.py | 86 ++++++++++++++++++++++++- 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/transformers/swiglu.py b/src/liger_kernel/transformers/swiglu.py index e0495a943..42f4df106 100644 --- a/src/liger_kernel/transformers/swiglu.py +++ b/src/liger_kernel/transformers/swiglu.py @@ -41,7 +41,10 @@ def forward(self, x): class LigerPhi3SwiGLUMLP(nn.Module): - """Patch the Phi3MLP to use Liger's LigerSiLUMulFunction""" + """ + Patch Phi3MLP to use LigerSiLUMulFunction + https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/phi3/modeling_phi3.py#L241 + """ def __init__(self, config): super().__init__() diff --git a/test/transformers/test_swiglu.py b/test/transformers/test_swiglu.py index 14132c2a9..3cff32d46 100644 --- a/test/transformers/test_swiglu.py +++ b/test/transformers/test_swiglu.py @@ -2,14 +2,21 @@ import torch from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaMLP +from transformers.models.phi3.configuration_phi3 import Phi3Config +from transformers.models.phi3.modeling_phi3 import Phi3MLP -from liger_kernel.transformers.swiglu import LigerSwiGLUMLP +from liger_kernel.transformers.swiglu import LigerPhi3SwiGLUMLP, LigerSwiGLUMLP LLAMA_CONFIG = LlamaConfig( hidden_size=4096, intermediate_size=11008, hidden_act="silu", ) +PHI3_CONFIG = Phi3Config( + hidden_size=4096, + intermediate_size=11008, + hidden_act="silu", +) SLEEP_SECONDS = 0.1 @@ -33,7 +40,9 @@ (torch.bfloat16, 1e4, 1e-2), ], ) -def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol): +def test_correctness_llamamlp( + bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol +): _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) @@ -94,3 +103,76 @@ def test_correctness(bsz, seq_len, hidden_size, intermediate_size, dtype, atol, ) assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) is True + + +@pytest.mark.parametrize( + "bsz, seq_len, hidden_size, intermediate_size", + [ + (2, 2048, 4096, 11008), + (2, 2048, 2048, 4096), + # weird shapes + (9, 41, 341, 4231), + (6, 42, 256, 2048), + ], +) +@pytest.mark.parametrize( + "dtype, atol, rtol", + [ + # atol is for small values: they have more difference, so set atol higher + # rtol is for larger values: they are very close, so set rtol lower + (torch.float32, 1e-0, 1e-5), + # TODO: we should find a better way to tune this. 1e4 is too large apparently + (torch.bfloat16, 1e4, 1e-2), + ], +) +def test_correctness_phi3mlp( + bsz, seq_len, hidden_size, intermediate_size, dtype, atol, rtol +): + + _input = torch.randn(bsz, seq_len, hidden_size, device="cuda", dtype=dtype) + + x1 = _input.clone().requires_grad_(True) + x2 = _input.clone().requires_grad_(True) + + # initialize weights + GU = torch.randn(hidden_size, intermediate_size * 2, device="cuda", dtype=dtype) + D = torch.randn(intermediate_size, hidden_size, device="cuda", dtype=dtype) + + phi3_mlp = Phi3MLP(config=PHI3_CONFIG).to("cuda").to(dtype) + phi3_mlp.gate_up_proj.weight.data = GU.T + phi3_mlp.down_proj.weight.data = D.T + + liger_mlp = LigerPhi3SwiGLUMLP(config=PHI3_CONFIG).to("cuda").to(dtype) + liger_mlp.gate_up_proj.weight.data = GU.T + liger_mlp.down_proj.weight.data = D.T + + y1 = phi3_mlp(x1) + y2 = liger_mlp(x2) + + assert torch.allclose(y1, y2, atol=atol, rtol=rtol) is True + + dy = torch.randn_like(y1) + + y1.backward(dy.clone(), retain_graph=True) + y2.backward(dy.clone(), retain_graph=True) + + assert ( + torch.allclose( + phi3_mlp.gate_up_proj.weight.grad, + liger_mlp.gate_up_proj.weight.grad, + atol=atol, + rtol=rtol, + ) + is True + ) + assert ( + torch.allclose( + phi3_mlp.down_proj.weight.grad, + liger_mlp.down_proj.weight.grad, + atol=atol, + rtol=rtol, + ) + is True + ) + + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) is True From cb6f1096ee6747d80c1811cbfed1afab0305d116 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Sat, 24 Aug 2024 15:34:17 -0700 Subject: [PATCH 05/10] Update Readme --- README.md | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 9c11d1512..2b26ec5bf 100644 --- a/README.md +++ b/README.md @@ -2,8 +2,8 @@ -[![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly) -[![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn) +[![Downloads](https://static.pepy.tech/badge/liger-kernel)](https://pepy.tech/project/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel.svg)](https://badge.fury.io/py/liger-kernel) [![PyPI version](https://badge.fury.io/py/liger-kernel-nightly.svg)](https://badge.fury.io/py/liger-kernel-nightly) +[![](https://dcbadge.vercel.app/api/server/cudamode?style=flat)](https://discord.gg/CX2YmNmn) [Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [APIs](#apis) | [Structure](#structure) | [Contributing](#contributing) @@ -32,8 +32,8 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and | ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) | > **Note:** -> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s. -> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K. +> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s. +> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K. ## Examples @@ -79,7 +79,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and To install the stable version: ```bash -$ pip install liger-kernel +$ pip install liger-kernel ``` To install the nightly version: @@ -101,7 +101,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_llama model = transformers.AutoModelForCausalLM.from_pretrained("") # Adding this line automatically monkey-patches the model with the optimized Liger kernels -apply_liger_kernel_to_llama() +apply_liger_kernel_to_llama() ``` ### 2. Compose Your Own Model @@ -153,6 +153,8 @@ loss.backward() | Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss | | Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss | | Qwen2 | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| Phi3 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss | + ### Kernels @@ -167,11 +169,11 @@ loss.backward() | FusedLinearCrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`| - **RMSNorm**: [RMSNorm](https://arxiv.org/pdf/1910.07467), which normalizes activations using their root mean square, is implemented by fusing the normalization and scaling steps into a single Triton kernel, and achieves ~3X speedup with ~3X peak memory reduction. -- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction. -- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by +- **RoPE**: [Rotary Positional Embedding](https://arxiv.org/pdf/2104.09864) is implemented by fusing the query and key embedding rotary into a single kernel with inplace replacement, and achieves ~3X speedup with ~3X peak memory reduction. +- **SwiGLU**: [Swish Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by $$\text{SwiGLU}(x)=\text{Swish}_{\beta}(xW+b)\otimes(xV+c)$$ , is implemented by fusing the elementwise multiplication (denoted by $\otimes$) into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. -- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by +- **GeGLU**: [GELU Gated Linear Units](https://arxiv.org/pdf/2002.05202), given by $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$ , is implemented by fusing the elementwise multiplication into a single kernel with inplace replacement, and achieves parity speed with ~1.5X peak memory reduction. Note that the [tanh approximation form of GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html) is used. - **CrossEntropy**: [Cross entropy loss](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html) is implemented by computing both the loss and gradient in the forward pass with inplace replacement of input to reduce the peak memory by avoiding simultaneous materialization of both input logits and gradient. It achieves >2X speedup and >4X memory reduction for common vocab sizes (e.g., 32K, 128K, etc.). @@ -180,7 +182,7 @@ $$\text{GeGLU}(x)=\text{GELU}(xW+b)\otimes(xV+c)$$ -> **Note:** +> **Note:** > Reported speedups and memory reductions are with respect to the LLaMA 3-8B Hugging Face layer implementations. All models use 4K hidden size and 4K sequence length and are evaluated based on memory usage and wall time for the forward+backward pass on a single NVIDIA A100 80G GPU using small batch sizes. Liger kernels exhibit more efficient scaling to larger batch sizes, detailed further in the [Benchmark](./benchmark) folder. ## Note on ML Compiler @@ -194,7 +196,7 @@ Since Liger Kernel is 100% Triton-based, it works seamlessly with [`torch.compil | Torch Compile | 3780 | 66.4 | | Torch Compile + Liger Kernel | 3702 | 31.0 | -> **Note:** +> **Note:** > 1. Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Seq Len = 4096, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s. > 2. Tested on torch `2.5.0.dev20240731+cu118` From ae9e06028722981ecb2b517326b886f2acccd7e1 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Sun, 25 Aug 2024 13:48:19 -0700 Subject: [PATCH 06/10] Address PR nit --- test/transformers/test_swiglu.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/test/transformers/test_swiglu.py b/test/transformers/test_swiglu.py index 3cff32d46..8fe2e1edf 100644 --- a/test/transformers/test_swiglu.py +++ b/test/transformers/test_swiglu.py @@ -67,7 +67,7 @@ def test_correctness_llamamlp( y1 = llama_mlp(x1) y2 = liger_mlp(x2) - assert torch.allclose(y1, y2, atol=atol, rtol=rtol) is True + assert torch.allclose(y1, y2, atol=atol, rtol=rtol) dy = torch.randn_like(y1) @@ -81,7 +81,6 @@ def test_correctness_llamamlp( atol=atol, rtol=rtol, ) - is True ) assert ( torch.allclose( @@ -90,7 +89,6 @@ def test_correctness_llamamlp( atol=atol, rtol=rtol, ) - is True ) assert ( torch.allclose( @@ -99,10 +97,9 @@ def test_correctness_llamamlp( atol=atol, rtol=rtol, ) - is True ) - assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) is True + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) @pytest.mark.parametrize( @@ -149,7 +146,7 @@ def test_correctness_phi3mlp( y1 = phi3_mlp(x1) y2 = liger_mlp(x2) - assert torch.allclose(y1, y2, atol=atol, rtol=rtol) is True + assert torch.allclose(y1, y2, atol=atol, rtol=rtol) dy = torch.randn_like(y1) @@ -163,7 +160,6 @@ def test_correctness_phi3mlp( atol=atol, rtol=rtol, ) - is True ) assert ( torch.allclose( @@ -172,7 +168,6 @@ def test_correctness_phi3mlp( atol=atol, rtol=rtol, ) - is True ) - assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) is True + assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) From 95b0ca2f8a7b0e7614959144cf2dc498d498695a Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Sun, 25 Aug 2024 13:49:12 -0700 Subject: [PATCH 07/10] Checkstyle --- test/transformers/test_swiglu.py | 60 +++++++++++++------------------- 1 file changed, 25 insertions(+), 35 deletions(-) diff --git a/test/transformers/test_swiglu.py b/test/transformers/test_swiglu.py index 8fe2e1edf..0b8ef3d45 100644 --- a/test/transformers/test_swiglu.py +++ b/test/transformers/test_swiglu.py @@ -74,29 +74,23 @@ def test_correctness_llamamlp( y1.backward(dy.clone(), retain_graph=True) y2.backward(dy.clone(), retain_graph=True) - assert ( - torch.allclose( - llama_mlp.gate_proj.weight.grad, - liger_mlp.gate_proj.weight.grad, - atol=atol, - rtol=rtol, - ) + assert torch.allclose( + llama_mlp.gate_proj.weight.grad, + liger_mlp.gate_proj.weight.grad, + atol=atol, + rtol=rtol, ) - assert ( - torch.allclose( - llama_mlp.up_proj.weight.grad, - liger_mlp.up_proj.weight.grad, - atol=atol, - rtol=rtol, - ) + assert torch.allclose( + llama_mlp.up_proj.weight.grad, + liger_mlp.up_proj.weight.grad, + atol=atol, + rtol=rtol, ) - assert ( - torch.allclose( - llama_mlp.down_proj.weight.grad, - liger_mlp.down_proj.weight.grad, - atol=atol, - rtol=rtol, - ) + assert torch.allclose( + llama_mlp.down_proj.weight.grad, + liger_mlp.down_proj.weight.grad, + atol=atol, + rtol=rtol, ) assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) @@ -153,21 +147,17 @@ def test_correctness_phi3mlp( y1.backward(dy.clone(), retain_graph=True) y2.backward(dy.clone(), retain_graph=True) - assert ( - torch.allclose( - phi3_mlp.gate_up_proj.weight.grad, - liger_mlp.gate_up_proj.weight.grad, - atol=atol, - rtol=rtol, - ) + assert torch.allclose( + phi3_mlp.gate_up_proj.weight.grad, + liger_mlp.gate_up_proj.weight.grad, + atol=atol, + rtol=rtol, ) - assert ( - torch.allclose( - phi3_mlp.down_proj.weight.grad, - liger_mlp.down_proj.weight.grad, - atol=atol, - rtol=rtol, - ) + assert torch.allclose( + phi3_mlp.down_proj.weight.grad, + liger_mlp.down_proj.weight.grad, + atol=atol, + rtol=rtol, ) assert torch.allclose(x1.grad, x2.grad, atol=atol, rtol=rtol) From 5baa8ee6b3c5155953d2a1763d08c9a3e40f4418 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Sun, 25 Aug 2024 14:29:43 -0700 Subject: [PATCH 08/10] Correctly resolve test.utils dir for make test command --- Makefile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index ae86fd0e7..7a4f16ce1 100644 --- a/Makefile +++ b/Makefile @@ -2,8 +2,8 @@ # Command to run pytest for correctness tests test: - pytest --disable-warnings test/ --ignore=test/convergence - + python -m pytest --disable-warnings test/ --ignore=test/convergence + # Command to run flake8 (code style check), isort (import ordering), and black (code formatting) # Subsequent commands still run if the previous fails, but return failure at the end @@ -18,4 +18,4 @@ checkstyle: # Command to run pytest for convergence tests # We have to explicitly set HF_DATASETS_OFFLINE=1, or dataset will silently try to send metrics and timeout (80s) https://github.com/huggingface/datasets/blob/37a603679f451826cfafd8aae00738b01dcb9d58/src/datasets/load.py#L286 test-convergence: - HF_DATASETS_OFFLINE=1 pytest --disable-warnings test/convergence \ No newline at end of file + HF_DATASETS_OFFLINE=1 python -m pytest --disable-warnings test/convergence \ No newline at end of file From f001ff5cfa817d76db9515e659ea745023328057 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Mon, 26 Aug 2024 10:07:54 -0700 Subject: [PATCH 09/10] Bump transformers version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d678220c7..ca101d2b3 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ install_requires=[ "torch>=2.1.2", "triton>=2.3.0", - "transformers>=4.40.1", + "transformers>=4.41.0", ], extras_require={ "dev": [ From b617c7730af7006f364126b56d8752826239a3c8 Mon Sep 17 00:00:00 2001 From: Tyler Romero Date: Mon, 26 Aug 2024 10:17:42 -0700 Subject: [PATCH 10/10] Bump transformers version in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f53c53abd..cc5fb83ce 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ With one line of code, Liger Kernel can increase throughput by more than 20% and - `torch >= 2.1.2` - `triton >= 2.3.0` -- `transformers >= 4.40.1` +- `transformers >= 4.41.0` > **Note:** > Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).