From e4ed7a3a5aa0f1e1b4f78317b3c7b25e5bf597f4 Mon Sep 17 00:00:00 2001
From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Date: Thu, 23 May 2024 18:34:22 +0530
Subject: [PATCH] do not upcast adapters when using FSDP+QLoRA (#1654)

---
 trl/trainer/sft_trainer.py | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 883c51560c..8862baa811 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -237,7 +237,14 @@ def make_inputs_require_grad(module, input, output):
 
                         model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
-                model = get_peft_model(model, peft_config)
+                if (
+                    "autocast_adapter_dtype" in list(inspect.signature(get_peft_model).parameters)
+                    and getattr(model, "is_loaded_in_4bit", False)
+                    and is_sharded_qlora
+                ):
+                    model = get_peft_model(model, peft_config, autocast_adapter_dtype=False)
+                else:
+                    model = get_peft_model(model, peft_config)
                 if (
                     args is not None
                     and args.bf16