
Commit 2af35fb

Clean up model preparation (#4577)
1 parent cbd90d4 commit 2af35fb

5 files changed: +147 -139 lines changed


tests/test_sft_trainer.py

Lines changed: 36 additions & 78 deletions
@@ -1600,97 +1600,55 @@ def test_prompt_tuning(self):
 
     @require_peft
     @require_bitsandbytes
-    def test_peft_model_with_quantization(self):
-        """SFTTrainer should not freeze layers of existing PeftModel.
-
-        This test simulates a realistic QLoRA scenario where a quantized base model is first converted to a PeftModel,
-        then passed to SFTTrainer. The issue was that prepare_model_for_kbit_training would freeze all parameters
-        including the LoRA adapters, making training impossible.
-        """
+    def test_peft_with_quantization(self):
         # Get the base model
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
-        model = AutoModelForCausalLM.from_pretrained(model_id)
 
-        # Simulate a realistic QLoRA setup by mocking quantization attributes
-        # This mimics what happens when loading a model with load_in_4bit=True
-        model.is_loaded_in_4bit = True
-        model.is_loaded_in_8bit = False
-
-        # Verify that this triggers the is_qlora condition
-        is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
-        assert is_qlora, "Model should be detected as QLoRA (quantized)"
-
-        # Create LoRA configuration suitable for QLoRA
-        lora_config = LoraConfig(
-            task_type=TaskType.CAUSAL_LM,
-            target_modules=["q_proj", "v_proj"],
-            r=16,
-            lora_alpha=32,
-            lora_dropout=0.1,
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
         )
-
-        # Convert the quantized model to a PeftModel (typical QLoRA workflow)
-        peft_model = get_peft_model(model, lora_config)
-
-        # Verify the quantization attributes are preserved on the PeftModel
-        assert getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag"
+        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
 
         # Get the dataset
         dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
 
-        # Analyze parameters before SFTTrainer initialization
-        trainable_params_before = []
-        base_params_before = []
-        lora_params_before = []
-
-        for name, param in peft_model.named_parameters():
-            if param.requires_grad:
-                trainable_params_before.append(name)
-                if "lora" in name.lower():
-                    lora_params_before.append(name)
-                else:
-                    base_params_before.append(name)
-
-        # Ensure we have the expected parameter distribution for QLoRA
-        assert len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially"
-        assert len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters"
-        assert len(base_params_before) == 0, "Base model parameters should already be frozen in PeftModel"
-
         # Initialize the trainer with the already configured PeftModel
-        training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", max_steps=1)
-        trainer = SFTTrainer(model=peft_model, args=training_args, train_dataset=dataset)
-
-        # Analyze parameters after SFTTrainer initialization
-        trainable_params_after = []
-        lora_params_after = []
-
-        for name, param in trainer.model.named_parameters():
-            if param.requires_grad:
-                trainable_params_after.append(name)
-                if "lora" in name.lower():
-                    lora_params_after.append(name)
+        training_args = SFTConfig(output_dir=self.tmp_dir, learning_rate=0.1, report_to="none")
+        trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset, peft_config=LoraConfig())
 
-        # LoRA parameters should remain trainable
-        assert len(trainable_params_after) > 0, (
-            f"PeftModel should still have trainable parameters after SFTTrainer initialization. "
-            f"Found {len(trainable_params_after)} trainable params. "
-            f"This test fails without the fix for issue #3926."
-        )
+        # Save initial parameters to check they change during training
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
 
-        assert len(lora_params_after) > 0, (
-            f"LoRA adapter parameters should remain trainable. "
-            f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original."
-        )
+        trainer.train()
 
-        # Ensure the parameter counts are preserved (no additional freezing occurred)
-        assert len(trainable_params_before) == len(trainable_params_after), (
-            "Number of trainable parameters should not change after SFTTrainer initialization"
-        )
+        # Check that training completed successfully
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+        assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None
 
-        # Verify that all original LoRA parameters are still trainable
-        assert set(lora_params_before) == set(lora_params_after), (
-            "All original LoRA parameters should remain trainable after SFTTrainer initialization"
-        )
+        # Check the peft params have changed and the base model params have not changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            # In bitsandbytes, bias parameters are automatically cast to the input dtype during the forward pass if
+            # their dtype doesn’t match. This causes the module to change unexpectedly during the first forward pass of
+            # the training. To handle this, we cast these specific bias parameters to float32 before comparison.
+            # https://github.com/bitsandbytes-foundation/bitsandbytes/blob/45553f7392e524eacf400b132cfe01261f6477be/bitsandbytes/nn/modules.py#L518
+            # We still need to investigate why the compute dtype ends up being different than for these parameters.
+            if n in [
+                "base_model.model.model.layers.1.self_attn.k_proj.bias",
+                "base_model.model.model.layers.1.self_attn.q_proj.base_layer.bias",
+                "base_model.model.model.layers.1.self_attn.v_proj.base_layer.bias",
+            ]:
+                param = param.float()
+
+            if "lora" not in n:  # We expect the base model parameters to be the same
+                assert torch.allclose(param, new_param), f"Parameter {n} has changed"
+            elif "lora" in n:  # We expect the peft parameters to be different
+                assert not torch.allclose(param, new_param), f"Parameter {n} has not changed"
+            else:
+                raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}")
 
     @require_peft
     def test_prompt_tuning_peft_model(self):

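For context, the scenario the removed test covered (a quantized base model wrapped into a PeftModel before being handed to SFTTrainer) corresponds roughly to the sketch below. It reuses the same tiny test model and dataset as the suite; the LoRA hyperparameters and output directory are illustrative, and running it requires bitsandbytes on supported hardware.

```python
# Sketch of the pre-wrapped QLoRA workflow (illustrative, not part of this commit):
# quantize the base model, wrap it with get_peft_model, then pass the PeftModel
# directly to SFTTrainer. The LoRA adapters must remain trainable afterwards.
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", quantization_config=quantization_config
)
peft_model = get_peft_model(
    model,
    LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["q_proj", "v_proj"], r=16, lora_alpha=32),
)

dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train")
trainer = SFTTrainer(
    model=peft_model,
    args=SFTConfig(output_dir="sft-qlora-out", max_steps=1, report_to="none"),
    train_dataset=dataset,
)
trainer.train()

# The LoRA parameters should still require gradients; the base parameters stay frozen.
assert any(p.requires_grad and "lora" in n for n, p in trainer.model.named_parameters())
```
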
trl/trainer/grpo_trainer.py

Lines changed: 42 additions & 30 deletions
@@ -28,15 +28,13 @@
 import pandas as pd
 import torch
 import torch.utils.data
-import transformers
 from accelerate import logging
 from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
 from datasets import Dataset, IterableDataset
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import DataLoader, Sampler
 from transformers import (
-    AutoConfig,
     AutoModelForSequenceClassification,
     AutoProcessor,
     AutoTokenizer,
@@ -61,13 +59,14 @@
 from ..extras.profiling import profiling_context, profiling_decorator
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_liger_kernel_available, is_vllm_available
-from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
+from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
 from ..models.utils import _ForwardRedirection
 from .base_trainer import BaseTrainer
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
 from .utils import (
     RepeatSampler,
+    create_model_from_path,
     disable_dropout_in_model,
     ensure_master_addr_port,
     entropy_from_logits,
@@ -87,7 +86,7 @@
 
 
 if is_peft_available():
-    from peft import PeftConfig, PeftModel
+    from peft import PeftConfig, PeftModel, get_peft_model
 
 if is_liger_kernel_available():
     from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
@@ -254,28 +253,14 @@ def __init__(
             model_name = model_name.split("/")[-1]
             args = GRPOConfig(f"{model_name}-GRPO")
 
-        # Models
-        # Trained model
-        model_init_kwargs = args.model_init_kwargs or {}
+        # Model
         if isinstance(model, str):
-            model_id = model
-            dtype = model_init_kwargs.get("dtype", "auto")
-            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
-                pass  # dtype is already a torch.dtype or "auto" or None
-            elif isinstance(dtype, str):  # it's a str, but not "auto"
-                dtype = getattr(torch, dtype)
-                model_init_kwargs["dtype"] = dtype
-            else:
-                raise ValueError(
-                    "Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
-                    f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
-                )
-            model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            model_init_kwargs = args.model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            model = create_model_from_path(model, **model_init_kwargs)
         else:
-            model_id = get_config_model_id(model.config)
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
@@ -290,9 +275,6 @@ def __init__(
             else inspect.signature(model.get_base_model().forward).parameters.keys()
         )
 
-        if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
-            model = prepare_peft_model(model, peft_config, args)
-
         # Processing class
         if processing_class is None:
             processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left")
@@ -312,12 +294,40 @@ def __init__(
         self.pad_token_id = tokenizer.pad_token_id
         self.eos_token_id = tokenizer.eos_token_id
 
+        if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
+            # If the model is already a PeftModel, we need to merge and unload it.
+            # Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
+            model = model.merge_and_unload()
+
+        # Create PEFT model
+        if peft_config is not None:
+            model = get_peft_model(model, peft_config)
+
+        # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
+        # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
+        if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
+            model.enable_input_require_grads()
+
+        # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
+        # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by
+        # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for
+        # quantized models. See: https://github.com/huggingface/peft/issues/2889
+        # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do
+        if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
+            for param in model.parameters():
+                if param.requires_grad:
+                    param.data = param.data.to(torch.bfloat16)
+
         # Reward functions
         if not isinstance(reward_funcs, list):
             reward_funcs = [reward_funcs]
         self.reward_func_names = []
         for i, reward_func in enumerate(reward_funcs):
             if isinstance(reward_func, str):
+                model_init_kwargs = args.model_init_kwargs or {}
+                # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+                if args.distributed_state.distributed_type == "DEEPSPEED":
+                    model_init_kwargs["device_map"] = None
                 reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                     reward_func, num_labels=1, **model_init_kwargs
                 )
@@ -476,9 +486,11 @@ def __init__(
             self.ref_model = None
         else:
             # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            model_init_kwargs = args.model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if self.args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs)
 
         # Disable dropout in the models
         if args.disable_dropout:

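Both the policy model and the reference model are now built through the create_model_from_path helper imported from .utils; its implementation is not part of this diff. A sketch of what such a helper might look like, inferred from the call sites above and from the inline code this commit removes, is shown below. The real function in trl/trainer/utils.py may differ, for example by resolving the concrete architecture from the model config or by suppressing from_pretrained warnings.

```python
# Hypothetical sketch of a create_model_from_path-style helper, inferred from its call
# sites in this commit. It is NOT the actual TRL implementation.
import torch
from transformers import AutoModelForCausalLM


def create_model_from_path(model_id, architecture=AutoModelForCausalLM, **model_init_kwargs):
    # Normalize `dtype`: accept a torch.dtype, "auto", None, or a name like "bfloat16".
    dtype = model_init_kwargs.get("dtype", "auto")
    if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
        pass  # already usable as-is
    elif isinstance(dtype, str):
        model_init_kwargs["dtype"] = getattr(torch, dtype)
    else:
        raise ValueError(f"Invalid `dtype`: expected 'auto' or a torch dtype name, got {dtype}.")
    # Default to device_map="auto"; the trainers override it with None for DeepSpeed
    # before calling the helper.
    model_init_kwargs.setdefault("device_map", "auto")
    return architecture.from_pretrained(model_id, **model_init_kwargs)
```

GRPOTrainer calls the helper with only a path, while RewardTrainer (below) additionally passes AutoModelForSequenceClassification as the architecture.
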
trl/trainer/reward_trainer.py

Lines changed: 32 additions & 23 deletions
@@ -25,7 +25,6 @@
 
 import torch
 import torch.nn as nn
-import transformers
 from accelerate import PartialState
 from accelerate.logging import get_logger
 from datasets import Dataset, IterableDataset
@@ -42,14 +41,14 @@
 from transformers.utils import is_peft_available
 
 from ..data_utils import is_conversational
-from ..models import clone_chat_template, get_act_offloading_ctx_manager, prepare_peft_model
+from ..models import clone_chat_template, get_act_offloading_ctx_manager
 from .base_trainer import BaseTrainer
 from .reward_config import RewardConfig
-from .utils import disable_dropout_in_model, get_config_model_id, pad, remove_none_values
+from .utils import create_model_from_path, disable_dropout_in_model, get_config_model_id, pad, remove_none_values
 
 
 if is_peft_available():
-    from peft import PeftConfig, PeftModel
+    from peft import PeftConfig, PeftModel, get_peft_model
 
 
 logger = get_logger(__name__)
@@ -279,24 +278,13 @@ def __init__(
             args = RewardConfig(f"{model_name}-Reward")
 
         # Model
-        model_init_kwargs = args.model_init_kwargs or {}
         if isinstance(model, str):
-            model_id = model
-            dtype = model_init_kwargs.get("dtype", "auto")
-            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
-                pass  # dtype is already a torch.dtype or "auto" or None
-            elif isinstance(dtype, str) and dtype in ["bfloat16", "float16", "float32"]:
-                model_init_kwargs["dtype"] = getattr(torch, dtype)
-            else:
-                raise ValueError(
-                    "Invalid `dtype` passed to `RewardConfig`. Expected either 'auto' or a string representing "
-                    f"a valid `torch.dtype` (e.g., 'float32'), but got {dtype}."
-                )
-            model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
-            with suppress_from_pretrained_warning(transformers.modeling_utils.logger):
-                model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1, **model_init_kwargs)
+            model_init_kwargs = args.model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            model = create_model_from_path(model, AutoModelForSequenceClassification, **model_init_kwargs)
         else:
-            model_id = get_config_model_id(model.config)
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `RewardConfig`, but your model is already instantiated. "
@@ -305,7 +293,7 @@ def __init__(
 
         # Processing class
         if processing_class is None:
-            processing_class = AutoTokenizer.from_pretrained(model_id)
+            processing_class = AutoTokenizer.from_pretrained(get_config_model_id(model.config))
 
         # Handle pad token for processors or tokenizers
         if args.eos_token is not None:
@@ -356,8 +344,29 @@ def __init__(
             else:
                 peft_config.modules_to_save.append("lm_head")
 
-        if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
-            model = prepare_peft_model(model, peft_config, args)
+        if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
+            # If the model is already a PeftModel, we need to merge and unload it.
+            # Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
+            model = model.merge_and_unload()
+
+        # Create PEFT model
+        if peft_config is not None:
+            model = get_peft_model(model, peft_config)
+
+        # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
+        # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
+        if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
+            model.enable_input_require_grads()
+
+        # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
+        # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by
+        # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for
+        # quantized models. See: https://github.com/huggingface/peft/issues/2889
+        # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do
+        if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
+            for param in model.parameters():
+                if param.requires_grad:
+                    param.data = param.data.to(torch.bfloat16)
 
         # Disable dropout in the model
         if args.disable_dropout:

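The PEFT-preparation block introduced above is duplicated between grpo_trainer.py and reward_trainer.py as the replacement for the shared prepare_peft_model call. Condensed into a standalone function purely for illustration (the commit itself keeps this logic inline in each trainer), the flow is:

```python
# Hypothetical consolidation of the inline PEFT preparation added in this commit;
# this helper does not exist in TRL and is shown only to summarize the flow.
import torch
from peft import PeftModel, get_peft_model


def prepare_peft(model, peft_config, gradient_checkpointing):
    if isinstance(model, PeftModel) and peft_config is not None:
        # An existing adapter is merged back into the base weights before the new
        # peft_config is applied, instead of stacking adapters.
        model = model.merge_and_unload()
    if peft_config is not None:
        model = get_peft_model(model, peft_config)
    if isinstance(model, PeftModel) and gradient_checkpointing:
        # Gradient checkpointing with PEFT needs input gradients enabled; see the
        # transformers issue referenced in the diff comments.
        model.enable_input_require_grads()
    # QLoRA: cast trainable adapter weights to bfloat16, following the recommendation
    # of the QLoRA paper (see the comments and links in the diff above).
    if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(torch.bfloat16)
    return model
```
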
0 commit comments
