From 1610baf0dfee7c0f1e63ab1b1c452dab733c4c33 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Wed, 6 Sep 2023 10:24:55 +0200 Subject: [PATCH] check correctly for condition (#668) --- trl/trainer/sft_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index cd21a8df83..ddc7e2b8c3 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -131,6 +131,8 @@ def __init__( "You passed a `DataCollatorForCompletionOnlyLM` to the SFTTrainer. This is not compatible with the `packing` argument." ) + supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) + if is_peft_available() and peft_config is not None: if not isinstance(peft_config, PeftConfig): raise ValueError( @@ -151,7 +153,7 @@ def __init__( if callbacks is None: callbacks = [PeftSavingCallback] - elif not isinstance(model, (PreTrainedModel, PeftModel)): + elif not isinstance(model, supported_classes): model = AutoModelForCausalLM.from_pretrained(model) if tokenizer is None: