Updated Trainer's liger-kernel integration to call correct patching API #33502

Merged: 2 commits, Sep 17, 2024
13 changes: 6 additions & 7 deletions src/transformers/trainer.py
@@ -468,19 +468,18 @@ def __init__(

         if self.args.use_liger_kernel:
             if is_liger_kernel_available():
-                from liger_kernel.transformers.trainer_integration import _apply_liger_kernel
+                from liger_kernel.transformers import _apply_liger_kernel_to_instance

-                model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
-                if model_type:
-                    # Monkey patch the model with liger kernels. Use the default kernel configurations.
-                    _apply_liger_kernel(model_type=model_type)
+                if isinstance(model, PreTrainedModel):
+                    # Patch the model with liger kernels. Use the default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model)
                 else:
                     logger.warning(
-                        "The model does not have a valid `model_type` specified. No liger kernels will be applied."
+                        "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
                     )
             else:
                 raise ImportError(
-                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.1.0 is not available. "
+                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
                     "Please install it with `pip install liger-kernel`"
                 )
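For context on the API this PR switches to, here is a minimal sketch of how the patched code path is exercised from user code. It assumes liger-kernel >= 0.3.0 is installed; the checkpoint id and output directory below are illustrative placeholders, not part of this PR.

```python
# Illustrative sketch (not part of this PR): enabling Liger kernel patching via the Trainer.
# Assumes `pip install "liger-kernel>=0.3.0"` and a model that subclasses PreTrainedModel.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("your-org/your-causal-lm")  # placeholder checkpoint

args = TrainingArguments(
    output_dir="./out",     # placeholder output directory
    use_liger_kernel=True,  # Trainer.__init__ now calls _apply_liger_kernel_to_instance(model=model)
)
trainer = Trainer(model=model, args=args)  # the model instance is patched in place here
```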
2 changes: 1 addition & 1 deletion src/transformers/utils/import_utils.py
@@ -1187,7 +1187,7 @@ def is_liger_kernel_available():
     if not _liger_kernel_available:
         return False

-    return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.1.0")
+    return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0")


 # docstyle-ignore
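Side note on the version gate above: a standalone sketch of the same `importlib.metadata` plus `packaging.version` pattern. The helper name and the example version strings are chosen here for illustration only.

```python
# Standalone illustration of the gating pattern used by is_liger_kernel_available():
# read the installed distribution's version and compare it against a minimum release.
import importlib.metadata

from packaging import version


def meets_minimum(dist_name: str, minimum: str) -> bool:
    try:
        installed = importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return False  # distribution is not installed at all
    return version.parse(installed) >= version.parse(minimum)


# After this PR the Trainer effectively requires liger_kernel >= 0.3.0
print(meets_minimum("liger_kernel", "0.3.0"))
```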
30 changes: 18 additions & 12 deletions tests/trainer/test_trainer.py
@@ -1344,22 +1344,28 @@ def test_get_eval_dataloader_with_persistent_workers(self):

     @require_liger_kernel
     def test_use_liger_kernel_patching(self):
-        # Test that the model code actually gets patched with Liger kernel
-        from liger_kernel.transformers.rms_norm import LigerRMSNorm
+        # Ensure any monkey patching is cleaned up for subsequent tests
+        with patch("transformers.models.llama.modeling_llama"):
+            from liger_kernel.transformers import LigerRMSNorm, liger_rotary_pos_emb

-        from transformers.models.llama import modeling_llama
+            from transformers.models.llama import modeling_llama

-        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
-        tiny_llama = LlamaForCausalLM(config)
+            config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+            tiny_llama = LlamaForCausalLM(config)

-        args = TrainingArguments(
-            "./test",
-            use_liger_kernel=True,
-        )
-        Trainer(tiny_llama, args)
+            # Spot check that modeling code and model instance variables are not yet patched
+            self.assertNotEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
+            self.assertFalse(isinstance(tiny_llama.model.norm, LigerRMSNorm))
+
+            args = TrainingArguments(
+                "./test",
+                use_liger_kernel=True,
+            )
+            Trainer(tiny_llama, args)

-        # Check that one of the Llama model layers has been correctly patched with Liger kernel
-        self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
+            # Spot check that modeling code and model instance variables are patched
+            self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
+            self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))

     @require_liger_kernel
     @require_torch_gpu
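On the test change: the body is wrapped in `with patch("transformers.models.llama.modeling_llama")` so the module-level monkey patching is undone for subsequent tests. Below is a minimal standalone sketch of that restore-on-exit behavior, using a made-up module and `patch.object` as a simplified stand-in for the string-target form used in the test.

```python
# Sketch of the cleanup guarantee that unittest.mock.patch provides:
# whatever gets patched inside the `with` block is restored when the block exits.
import types
from unittest.mock import patch


def original():
    return "original"


demo = types.ModuleType("demo")  # hypothetical stand-in for modeling_llama
demo.apply_rotary_pos_emb = original

with patch.object(demo, "apply_rotary_pos_emb", lambda *a, **k: "patched"):
    assert demo.apply_rotary_pos_emb() == "patched"  # patched while the block is active

assert demo.apply_rotary_pos_emb is original  # restored once the block exits
```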