Clean up model preparation #4577
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
@require_peft
@require_bitsandbytes
def test_peft_model_with_quantization(self):
```
Here we just align the test with the other tests, to make maintenance easier.
```python
    TrainingArguments,
    is_comet_available,
)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
```
According to good practice, we shouldn't import this private class, but I suggest we make a special case here: it's only used for type hints.
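For illustration, a minimal sketch of the type-hint-only usage (the helper function is hypothetical; only the import mirrors the diff):

```python
from transformers import AutoModelForCausalLM, PreTrainedModel
from transformers.models.auto.auto_factory import _BaseAutoModelClass


# Hypothetical helper: every Auto* class (AutoModelForCausalLM,
# AutoModelForSeq2SeqLM, ...) subclasses _BaseAutoModelClass, so it is the
# natural annotation for "some Auto class" without importing them all.
def load_model(
    model_id: str,
    auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
) -> PreTrainedModel:
    return auto_cls.from_pretrained(model_id)
```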
```diff
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_liger_kernel_available, is_vllm_available
-from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
+from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
```
We drop prepare_peft_model: this part is now done directly in the trainer init:
Lines 560 to 561 in 0726977
```python
if isinstance(model, PeftModel) and peft_config is not None:
    model = model.merge_and_unload()
```
The logic below (which I find quite hard to read) is intended to enable gradient checkpointing, with a few exceptions for QLoRA. After investigation, it turns out this behavior is already correctly handled by PEFT and Transformers, so this custom logic is no longer necessary. It is likely a leftover from a period when native support was incomplete, although it’s difficult to be certain. This is also a good reminder of the importance of adding comments whenever code is not self-explanatory.
Lines 563 to 584 in 0726977
```python
# Handle quantized models (QLoRA)
is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False)
is_sharded_qlora = False
if getattr(model, "is_loaded_in_4bit", False):
    # Check if model is sharded (FSDP/DS-Zero3)
    for _, param in model.named_parameters():
        if param.__class__.__name__ == "Params4bit":
            is_sharded_qlora = param.data.device.type in {"cpu", "meta"}
            break
# Prepare model for kbit training if needed
if is_qlora and not is_sharded_qlora and not isinstance(model, PeftModel):
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=args.gradient_checkpointing,
        gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs or {},
    )
    # Disable gradient checkpointing as it's handled by prepare_model_for_kbit_training
    args.gradient_checkpointing = False
elif args.gradient_checkpointing:
    model = enable_gradient_checkpointing(model, args.gradient_checkpointing_kwargs)
```
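For context, here is a minimal sketch of the native path that makes the block above redundant, assuming recent peft/transformers versions (the model name is illustrative):

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments

# Load a 4-bit (QLoRA-style) base model.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
# No manual prepare_model_for_kbit_training call: per this PR's finding,
# PEFT and the Trainer handle the quantized-model details natively.
model = get_peft_model(model, LoraConfig())

args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,  # the Trainer enables it on the model itself
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```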
It’s not obvious from the current code (again: missing comments), but autocast_adapter_dtype=False is intended to force the adapter dtype to bfloat16 when using a quantized model; however, this behavior doesn’t seem to be functional at the moment. See here.
This logic has now been moved into the trainers’ initialization, which is in my opinion clearer:
Lines 586 to 599 in 0726977
```python
# Create PEFT model
if peft_config is not None:
    if (
        version.parse(peft.__version__) >= version.parse("0.12")  # autocast_adapter_dtype introduced in 0.12
        and getattr(model, "is_loaded_in_4bit", False)
        and is_sharded_qlora
    ):
        model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
    else:
        model = get_peft_model(model, peft_config)

    # Handle bf16 casting for 4-bit models
    if args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora:
        peft_module_casting_to_bf16(model)
```
Summary
Through an in-depth investigation, I found that most of the custom model-preparation logic in prepare_peft_model duplicates behavior that is already handled natively by PEFT and Transformers.
Goals of this PR:
- Drop prepare_peft_model and create the PEFT model directly in the trainers’ init.
- Remove the custom gradient checkpointing / kbit-preparation logic that PEFT and Transformers already handle.
- Align test_peft_model_with_quantization with the other tests to ease maintenance.
Script used to cover the various cases:
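The original script isn't reproduced above; below is a hypothetical sketch of the kind of case matrix it would cover (the dataset, model, and config values are assumptions):

```python
import itertools

from datasets import load_dataset
from peft import LoraConfig
from transformers import BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/tldr", split="train[:16]")

# Sweep the cases touched by this PR: plain / PEFT / QLoRA, each with and
# without gradient checkpointing.
for use_peft, use_4bit, grad_ckpt in itertools.product([False, True], repeat=3):
    if use_4bit and not use_peft:
        continue  # 4-bit without an adapter is not a training case here
    args = SFTConfig(
        output_dir="tmp",
        max_steps=2,
        gradient_checkpointing=grad_ckpt,
        report_to="none",
        model_init_kwargs=(
            {"quantization_config": BitsAndBytesConfig(load_in_4bit=True)} if use_4bit else None
        ),
    )
    trainer = SFTTrainer(
        model="Qwen/Qwen2.5-0.5B",
        args=args,
        train_dataset=dataset,
        peft_config=LoraConfig() if use_peft else None,
    )
    trainer.train()
```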