diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index d101935301..11c08c1083 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -172,27 +172,28 @@ def __init__(
                 )
 
             if not isinstance(model, PeftModel):
+                _support_gc_kwargs = hasattr(
+                    args, "gradient_checkpointing_kwargs"
+                ) and "gradient_checkpointing_kwargs" in list(
+                    inspect.signature(prepare_model_for_kbit_training).parameters
+                )
+                gradient_checkpointing_kwargs = getattr(args, "gradient_checkpointing_kwargs", None) or {}
                 if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
-                    _support_gc_kwargs = hasattr(
-                        args, "gradient_checkpointing_kwargs"
-                    ) and "gradient_checkpointing_kwargs" in list(
-                        inspect.signature(prepare_model_for_kbit_training).parameters
-                    )
-
                     preprare_model_kwargs = {
                         "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
                     }
 
                     if _support_gc_kwargs:
-                        preprare_model_kwargs["gradient_checkpointing_kwargs"] = getattr(
-                            args, "gradient_checkpointing_kwargs", None
-                        )
+                        preprare_model_kwargs["gradient_checkpointing_kwargs"] = gradient_checkpointing_kwargs
 
                     model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
 
                     if args is not None:
                         args = dataclasses.replace(args, gradient_checkpointing=False)
-                elif getattr(args, "gradient_checkpointing", False):
+                elif getattr(args, "gradient_checkpointing", False) and (
+                    "use_reentrant" not in gradient_checkpointing_kwargs
+                    or gradient_checkpointing_kwargs["use_reentrant"]
+                ):
                     # For backward compatibility with older versions of transformers
                     if hasattr(model, "enable_input_require_grads"):
                         model.enable_input_require_grads()
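
For context, a hedged usage sketch of the code path this hunk changes: the new elif condition only forces enable_input_require_grads() when reentrant checkpointing is in effect (the default when "use_reentrant" is absent), so passing use_reentrant=False skips that call. The model id, dataset, and LoRA config below are illustrative assumptions, not part of this patch; it assumes transformers >= 4.35, where TrainingArguments gained gradient_checkpointing_kwargs.

# Illustrative sketch only; versions of trl/peft/datasets contemporary with this patch are assumed.
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

dataset = load_dataset("imdb", split="train")  # any dataset with a "text" column works

args = TrainingArguments(
    output_dir="opt-350m-sft",
    gradient_checkpointing=True,
    # Non-reentrant checkpointing: the patched elif evaluates to False,
    # so SFTTrainer no longer calls model.enable_input_require_grads() on this path.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = SFTTrainer(
    model="facebook/opt-350m",      # hypothetical base model for the example
    args=args,
    train_dataset=dataset,
    dataset_text_field="text",
    peft_config=LoraConfig(task_type="CAUSAL_LM"),
)
trainer.train()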