Skip to content

Conversation

@qgallouedec
Copy link
Member

@qgallouedec qgallouedec commented Nov 26, 2025

Summary

Through an in-depth investigation, I found that

  1. much of the model-preparation logic in the codebase is outdated, unnecessary, or incorrect.
  2. In addition, model preparation currently varies across trainers, with no unified approach, which implies maintenance headaches.

Goals of this PR

  1. Standardize model preparation for all stable trainers (SFT, GRPO, Reward, excluding DPO which is currently being refactored)
  2. Provide a correct, up-to-date, and well-documented model-preparation pipeline, derived from a thorough review of all cases covered in the referenced script.

Script used that covers the various cases

import numpy as np
import torch
import pandas as pd
from datasets import Dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, Trainer, TrainingArguments

data = np.random.randint(0, 1000, (16, 64)).tolist()
dataset = Dataset.from_dict({"input_ids": data, "labels": data})

def human_readable_int(value: int) -> str:
    """Short human-readable numbers (1.2M, 340k...)."""
    if value >= 1_000_000:
        return f"{value / 1_000_000:.1f}M"
    if value >= 1_000:
        return f"{value / 1_000:.0f}k"
    return str(value)


def run_scenario(name, dtype, gc, quantized, use_lora):
    model_kwargs = {"device_map": "auto"}
    if quantized:
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
    else:
        model_kwargs["dtype"] = dtype

    model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", **model_kwargs)

    if use_lora:
        model = get_peft_model(model, LoraConfig())
        model.enable_input_require_grads()  # ideally not needed, but see https://github.com/huggingface/transformers/issues/42489

    trainer = Trainer(
        model=model,
        args=TrainingArguments(gradient_checkpointing=gc),
        train_dataset=dataset,
    )
    trainer.train()

    model = trainer.model
    params = list(model.named_parameters())
    total_params = sum(p.numel() for _, p in params)
    trainable_params = sum(p.numel() for _, p in params if p.requires_grad)
    sample_dtype = params[0][1].dtype if params else "n/a"
    quant_method = getattr(model, "quantization_method", "")
    actual_quant = "4bit" if getattr(model, "is_loaded_in_4bit", False) else ("8bit" if getattr(model, "is_loaded_in_8bit", False) else (quant_method or "none"))
    is_lora_model = isinstance(model, PeftModel)
    observed = f"gc:{'on' if model.is_gradient_checkpointing else 'off'}, lora:{'yes' if is_lora_model else 'no'}, quant:{actual_quant}, dtype:{sample_dtype}"
    return {
        "Scenario": name,
        "Observed": observed,
        "Trainable": f"{human_readable_int(trainable_params)} / {human_readable_int(total_params)}",
    }


def main():
    scenarios = [
        {"name": "FFT",         "dtype": "auto",        "gc": False, "quantized": False, "use_lora": False},
        {"name": "FFT + FP16",  "dtype": torch.float16, "gc": False, "quantized": False, "use_lora": False},
        {"name": "FFT + GC",    "dtype": "auto",        "gc": True,  "quantized": False, "use_lora": False},
        {"name": "LoRA",        "dtype": "auto",        "gc": False, "quantized": False, "use_lora": True},
        {"name": "LoRA + GC",   "dtype": "auto",        "gc": True,  "quantized": False, "use_lora": True},
        {"name": "Q-LoRA",      "dtype": "auto",        "gc": False, "quantized": True,  "use_lora": True},
        {"name": "Q-LoRA + GC", "dtype": "auto",        "gc": True,  "quantized": True,  "use_lora": True},
    ]
    results = [run_scenario(**scenario) for scenario in scenarios]
    df = pd.DataFrame(results)
    print(df.to_markdown(index=False))


if __name__ == "__main__":
    main()
Scenario Observed Trainable
FFT gc:off, lora:no, quant:none, dtype:torch.bfloat16 596.0M / 596.0M
FFT + FP16 gc:off, lora:no, quant:none, dtype:torch.float16 596.0M / 596.0M
FFT + GC gc:on, lora:no, quant:none, dtype:torch.bfloat16 596.0M / 596.0M
LoRA gc:off, lora:yes, quant:none, dtype:torch.bfloat16 1.1M / 597.2M
LoRA + GC gc:on, lora:yes, quant:none, dtype:torch.bfloat16 1.1M / 597.2M
Q-LoRA gc:off, lora:yes, quant:4bit, dtype:torch.float16 1.1M / 377.0M
Q-LoRA + GC gc:on, lora:yes, quant:4bit, dtype:torch.float16 1.1M / 377.0M

@qgallouedec qgallouedec marked this pull request as ready for review November 29, 2025 00:54
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.


@require_peft
@require_bitsandbytes
def test_peft_model_with_quantization(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we just align the test with the other tests, to make maintenance easier

TrainingArguments,
is_comet_available,
)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to good practices, we shouldn't import this method, but I suggest that we make a special case, it's just for type hint.

from ..extras.vllm_client import VLLMClient
from ..import_utils import is_liger_kernel_available, is_vllm_available
from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we drop prepare_peft_model:

This part is now done directly the the trainer init:

trl/trl/models/utils.py

Lines 560 to 561 in 0726977

if isinstance(model, PeftModel) and peft_config is not None:
model = model.merge_and_unload()

The logic below (which I find quite hard to read) is intended to enable gradient checkpointing, with a few exceptions for QLoRA. After investigation, this behavior is already correctly handled by PEFT and Transformers, so this custom logic is no longer necessary. It is likely a leftover from a period when native support was incomplete, although it’s difficult to be certain. This is also a good reminder of the importance of adding comments whenever code is not self-explanatory.

trl/trl/models/utils.py

Lines 563 to 584 in 0726977

# Handle quantized models (QLoRA)
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
is_sharded_qlora = False
if getattr(model, "is_loaded_in_4bit", False):
# Check if model is sharded (FSDP/DS-Zero3)
for _, param in model.named_parameters():
if param.__class__.__name__ == "Params4bit":
is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
break
# Prepare model for kbit training if needed
if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
model = prepare_model_for_kbit_training(
model,
use_gradient_checkpointing=args.gradient_checkpointing,
gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
)
# Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
args.gradient_checkpointing = False
elif args.gradient_checkpointing:
model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs)

It’s not obvious from the current code (again: missing comments), but autocast_adapter_dtype=False is intended to force the adapter dtype to bfloat16 when using a quantized model, however, this behavior doesn’t seem to be functional at the moment. See here
This logic has now been moved into the trainers’ initialization, which is in my opinion clearer

trl/trl/models/utils.py

Lines 586 to 599 in 0726977

# Create PEFT model
if peft_config is not None:
if (
version.parse(peft.__version__) >= version.parse("0.12") # autocast_adapter_dtype introduced in 0.12
and getattr(model, "is_loaded_in_4bit", False)
and is_sharded_qlora
):
model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
else:
model = get_peft_model(model, peft_config)
# Handle bf16 casting for 4-bit models
if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
peft_module_casting_to_bf16(model)

@qgallouedec qgallouedec changed the title [WIP] Clean up model preparation Clean up model preparation Dec 4, 2025
@qgallouedec qgallouedec merged commit 2af35fb into main Dec 5, 2025
10 of 11 checks passed
@qgallouedec qgallouedec deleted the better-model-prepare branch December 5, 2025 04:23
qgallouedec added a commit to neha222222/trl that referenced this pull request Dec 5, 2025
commit f278d03
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Fri Dec 5 19:34:42 2025 +0100

    Remove no longer applicable warning once BCO was moved to experimental (huggingface#4628)

commit e7071bf
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Fri Dec 5 10:07:16 2025 -0700

    Add logos as assets (huggingface#4627)

    Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>

commit 794d87f
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Fri Dec 5 08:45:20 2025 +0100

    Add missing experimental autodoc classes to docs (huggingface#4618)

commit bc7888d
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Fri Dec 5 07:48:33 2025 +0100

    Raise FutureWarning for trainer moved to experimental (huggingface#4620)

commit fce5dfd
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Fri Dec 5 07:34:04 2025 +0100

    Raise warnings at 2nd stack level (huggingface#4621)

commit c5da8ec
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Fri Dec 5 07:33:04 2025 +0100

    Silence experimental warning during docs build (huggingface#4623)

commit 2af35fb
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Thu Dec 4 21:23:41 2025 -0700

    Clean up model preparation  (huggingface#4577)

commit cbd90d4
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Thu Dec 4 20:05:43 2025 +0100

    Remove deprecated batched formatting in GOLDTrainer (huggingface#4622)

commit 903b57d
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Thu Dec 4 19:16:00 2025 +0100

    Update ministral notebooks with official bf16 ckpt (huggingface#4626)

commit 9266135
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Thu Dec 4 19:01:46 2025 +0100

    Fix link to OpenEnv blog in docs (huggingface#4625)

commit 495381d
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Thu Dec 4 11:32:34 2025 +0100

    Fix README style (huggingface#4619)

commit ddb65e8
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Wed Dec 3 21:20:40 2025 +0100

    Add experimental imports to docs (huggingface#4616)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit 5fab472
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Wed Dec 3 17:38:16 2025 +0100

    Replace arXiv paper links with HF links (huggingface#4613)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit a3c1dfb
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Wed Dec 3 17:28:45 2025 +0100

    Add ministral 3 free notebooks (huggingface#4614)

commit 560fd3d
Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date:   Wed Dec 3 10:12:20 2025 +0000

    [GRPOTrainer]: Add SAPO Loss (huggingface#4600)

commit 814d4af
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Tue Dec 2 15:52:51 2025 -0700

    Move MergeModelCallback to experimental (huggingface#4608)

    Co-authored-by: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>

commit 2a81076
Author: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>
Date:   Tue Dec 2 15:07:11 2025 +0100

    Fixed OpenEnv example scripts (huggingface#4610)

commit de343cd
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Tue Dec 2 07:32:22 2025 +0100

    Remove deprecations for 0.26 release (huggingface#4607)

commit 07b4a84
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Mon Dec 1 12:55:24 2025 -0700

    Silence experimental warnings when imported in the stable (huggingface#4606)

commit c55ef4b
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Mon Dec 1 12:40:42 2025 -0700

    Update How-to guides (huggingface#4604)

commit c686d7d
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Mon Dec 1 20:34:31 2025 +0100

    Raise FutureWarning for classes moved to experimental (huggingface#4605)

commit c7d172b
Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com>
Date:   Mon Dec 1 01:47:22 2025 -0800

    docs: Expand speeding up training guide with acceleration methods (huggingface#4428)

    Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>

commit f1dfef0
Author: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com>
Date:   Mon Dec 1 01:39:08 2025 -0800

    docs: Expand training customization examples (huggingface#4427)

    Co-authored-by: Sergio Paniego Blanco <sergiopaniegoblanco@gmail.com>

commit eb76389
Author: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
Date:   Sun Nov 30 16:45:21 2025 +0100

    [GRPO] Sequence-level TIS & MIS (huggingface#4530)

commit 0726977
Author: xuanduy04 <65279552+xuanduy04@users.noreply.github.com>
Date:   Fri Nov 28 23:56:22 2025 +0700

    docs: Add Beyond the 80/20 Rule (2506.01939) to Paper Index (huggingface#4580)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit 9731d08
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Fri Nov 28 17:43:38 2025 +0100

    Revert "Hotfix CI with dev dependencies: xfail test_prepare_inputs_for_generation" (huggingface#4587)

    Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>

commit 84a0bbc
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Fri Nov 28 16:13:56 2025 +0100

    Fix 'generation_config' AttributeError (huggingface#4596)

commit f67c3f2
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Fri Nov 28 15:46:02 2025 +0100

    Remove module-level imports of extra deps in experimental.judges (huggingface#4598)

commit cb5fdf9
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Thu Nov 27 11:08:26 2025 +0100

    Add missing require_bitsandbytes marker to CI tests (huggingface#4586)

commit 4a3b584
Author: juejuezi <juejuezi.git@foxmail.com>
Date:   Thu Nov 27 00:11:56 2025 +0800

    fix: use shift_labels for metrics when using CP or SP (huggingface#4579)

    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit d2e4315
Author: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com>
Date:   Wed Nov 26 15:40:15 2025 +0100

    Revert hotfix Fall back to config.text_config._name_or_path (huggingface#4581)

commit 357e331
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Nov 26 04:55:46 2025 -0700

    Move tests for GSPOTokenTrainer to experimental (huggingface#4572)

commit a59f2cf
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Nov 26 04:50:44 2025 -0700

    Move `WinRateCallback` to experimental (huggingface#4558)

    Co-authored-by: Behrooz Azarkhalili <80390531+behroozazarkhalili@users.noreply.github.com>
    Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

commit cf431db
Author: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Date:   Wed Nov 26 04:11:04 2025 -0700

    Fix PPO example (huggingface#4556)

commit cac9f1d
Author: Pramodith Ballapuram <16939722+pramodith@users.noreply.github.com>
Date:   Tue Nov 25 21:27:58 2025 +0000

    Fix Replay Buffer docs. (huggingface#4574)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants