 import pandas as pd
 import torch
 import torch.utils.data
-import transformers
 from accelerate import logging
 from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
 from datasets import Dataset, IterableDataset
 from torch import nn
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.utils.data import DataLoader, Sampler
 from transformers import (
-    AutoConfig,
     AutoModelForSequenceClassification,
     AutoProcessor,
     AutoTokenizer,
…
 from ..extras.profiling import profiling_context, profiling_decorator
 from ..extras.vllm_client import VLLMClient
 from ..import_utils import is_liger_kernel_available, is_vllm_available
-from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation
+from ..models import prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
 from ..models.utils import _ForwardRedirection
 from .base_trainer import BaseTrainer
 from .callbacks import SyncRefModelCallback
 from .grpo_config import GRPOConfig
 from .utils import (
     RepeatSampler,
+    create_model_from_path,
     disable_dropout_in_model,
     ensure_master_addr_port,
     entropy_from_logits,
…

 if is_peft_available():
-    from peft import PeftConfig, PeftModel
+    from peft import PeftConfig, PeftModel, get_peft_model

 if is_liger_kernel_available():
     from liger_kernel.chunked_loss import LigerFusedLinearGRPOLoss
@@ -254,28 +253,14 @@ def __init__(
             model_name = model_name.split("/")[-1]
             args = GRPOConfig(f"{model_name}-GRPO")

-        # Models
-        # Trained model
-        model_init_kwargs = args.model_init_kwargs or {}
+        # Model
         if isinstance(model, str):
-            model_id = model
-            dtype = model_init_kwargs.get("dtype", "auto")
-            if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
-                pass  # dtype is already a torch.dtype or "auto" or None
-            elif isinstance(dtype, str):  # it's a str, but not "auto"
-                dtype = getattr(torch, dtype)
-                model_init_kwargs["dtype"] = dtype
-            else:
-                raise ValueError(
-                    "Invalid `dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
-                    f"a `torch.dtype` (e.g., 'float32'), but got {dtype}."
-                )
-            model_init_kwargs["device_map"] = model_init_kwargs.get("device_map", "auto")
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            model_init_kwargs = args.model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            model = create_model_from_path(model, **model_init_kwargs)
         else:
-            model_id = get_config_model_id(model.config)
             if args.model_init_kwargs is not None:
                 logger.warning(
                     "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
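
Note: the new `create_model_from_path` helper absorbs the dtype and `device_map` handling deleted above. Its actual implementation in `trl.trainer.utils` is not shown in this diff; the following is a minimal sketch reconstructed from the removed inline code, and may differ from the real helper:

import torch
import transformers
from transformers import AutoConfig

def create_model_from_path(model_id, **model_init_kwargs):
    # Assumed behavior: normalize `dtype` exactly as the removed inline code did
    dtype = model_init_kwargs.get("dtype", "auto")
    if isinstance(dtype, torch.dtype) or dtype == "auto" or dtype is None:
        pass  # already usable as-is
    elif isinstance(dtype, str):  # a string such as "float32"
        model_init_kwargs["dtype"] = getattr(torch, dtype)
    else:
        raise ValueError(f"Invalid `dtype`: {dtype}")
    # Preserve an explicit device_map=None (DeepSpeed branch); default to "auto" otherwise
    model_init_kwargs.setdefault("device_map", "auto")
    # Resolve the concrete architecture class from the config, then load the weights
    config = AutoConfig.from_pretrained(model_id)
    architecture = getattr(transformers, config.architectures[0])
    return architecture.from_pretrained(model_id, **model_init_kwargs)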
@@ -290,9 +275,6 @@ def __init__(
             else inspect.signature(model.get_base_model().forward).parameters.keys()
         )

-        if peft_config is not None or (is_peft_available() and isinstance(model, PeftModel)):
-            model = prepare_peft_model(model, peft_config, args)
-
         # Processing class
         if processing_class is None:
             processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left")
@@ -312,12 +294,40 @@ def __init__(
         self.pad_token_id = tokenizer.pad_token_id
         self.eos_token_id = tokenizer.eos_token_id

+        if is_peft_available() and isinstance(model, PeftModel) and peft_config is not None:
+            # If the model is already a PeftModel, we need to merge and unload it.
+            # Further information: https://huggingface.co/docs/trl/dpo_trainer#reference-model-considerations-with-peft
+            model = model.merge_and_unload()
+
+        # Create PEFT model
+        if peft_config is not None:
+            model = get_peft_model(model, peft_config)
+
+        # When using gradient checkpointing with PEFT, we need to enable input gradients. transformers.Trainer normally
+        # handles this, but a bug currently prevents it; see https://github.com/huggingface/transformers/issues/42489
+        if is_peft_available() and isinstance(model, PeftModel) and args.gradient_checkpointing:
+            model.enable_input_require_grads()
+
+        # When using QLoRA, the PEFT adapter weights are converted to bf16 to follow the recommendations from the
+        # original paper (see https://huggingface.co/papers/2305.14314, paragraph 3). Normally, this can be done by
+        # passing `autocast_adapter_dtype=False` to `get_peft_model`, but this option is not yet supported for
+        # quantized models. See: https://github.com/huggingface/peft/issues/2889
+        # Non-quantized models do not have the `is_loaded_in_{8,4}bit` attributes, whereas quantized models do.
+        if getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False):
+            for param in model.parameters():
+                if param.requires_grad:
+                    param.data = param.data.to(torch.bfloat16)
+
         # Reward functions
         if not isinstance(reward_funcs, list):
             reward_funcs = [reward_funcs]
         self.reward_func_names = []
         for i, reward_func in enumerate(reward_funcs):
             if isinstance(reward_func, str):
+                model_init_kwargs = args.model_init_kwargs or {}
+                # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+                if args.distributed_state.distributed_type == "DEEPSPEED":
+                    model_init_kwargs["device_map"] = None
                 reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
                     reward_func, num_labels=1, **model_init_kwargs
                 )
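
To make the QLoRA upcast in this hunk concrete: after `get_peft_model` wraps a 4-bit base model, only the LoRA adapter weights are trainable, and they default to float32 until the loop above casts them. A standalone sketch under assumed setup (placeholder model ID; requires a CUDA device and bitsandbytes):

import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-0.5B-Instruct",  # placeholder model ID
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))
for param in model.parameters():
    if param.requires_grad:  # only the adapter weights require grad
        param.data = param.data.to(torch.bfloat16)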
@@ -476,9 +486,11 @@ def __init__(
             self.ref_model = None
         else:
             # For deepspeed, fsdp or non-distributed models, create a reference model from scratch
-            config = AutoConfig.from_pretrained(model_id)
-            architecture = getattr(transformers, config.architectures[0])
-            self.ref_model = architecture.from_pretrained(model_id, **model_init_kwargs)
+            model_init_kwargs = args.model_init_kwargs or {}
+            # Special case for DeepSpeed: requires device_map=None ("auto" fails)
+            if self.args.distributed_state.distributed_type == "DEEPSPEED":
+                model_init_kwargs["device_map"] = None
+            self.ref_model = create_model_from_path(get_config_model_id(self.model.config), **model_init_kwargs)

         # Disable dropout in the models
         if args.disable_dropout:
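
Taken together, the constructor paths touched by this commit can be exercised end to end. A hedged usage sketch (the reward-model ID is a placeholder and the dataset is a toy one): a string `model` goes through `create_model_from_path`, a string reward function is loaded as a single-label sequence classifier, and `peft_config` is routed through the `get_peft_model` block above.

from datasets import Dataset
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer

dataset = Dataset.from_dict({"prompt": ["Write a haiku about rain."]})
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",       # str path -> create_model_from_path
    reward_funcs="my-org/my-reward-model",  # placeholder ID -> AutoModelForSequenceClassification, num_labels=1
    args=GRPOConfig(output_dir="grpo-output"),
    train_dataset=dataset,
    peft_config=LoraConfig(task_type="CAUSAL_LM"),  # routed through get_peft_model above
)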