Skip to content

Commit

Permalink
don't use get_peft_model if model is already peft (#857)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur authored Oct 11, 2023
1 parent dd9b8f4 commit f7707fd
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions trl/trainer/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,13 @@ def __init__(
"PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it to use the PEFT models"
)
elif is_peft_available() and peft_config is not None:
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
if not isinstance(model, PeftModel):
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
model = prepare_model_for_kbit_training(
model, use_gradient_checkpointing=args.gradient_checkpointing
)

model = get_peft_model(model, peft_config)
model = get_peft_model(model, peft_config)

if is_peft_available() and callbacks is None and isinstance(model, PeftModel):
callbacks = [PeftSavingCallback()]
Expand Down

0 comments on commit f7707fd

Please sign in to comment.