From 076b02af3b92db9fbd484de38816116e3eab1fc5 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 6 Dec 2023 16:36:11 +0000
Subject: [PATCH 1/2] fix sfttrainer when args is None

---
 trl/trainer/sft_trainer.py | 21 ++++++++-------------
 1 file changed, 8 insertions(+), 13 deletions(-)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 5e70114f96..d35d9177b8 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -175,24 +175,19 @@ def __init__(
                         inspect.signature(prepare_model_for_kbit_training).parameters
                     )
 
-                    preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
+                    preprare_model_kwargs = {
+                        "use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
+                    }
 
                     if _support_gc_kwargs:
-                        preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
+                        preprare_model_kwargs["gradient_checkpointing_kwargs"] = getattr(
+                            args, "gradient_checkpointing_kwargs", None
+                        )
 
                     model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)
 
-                    args = dataclasses.replace(args, gradient_checkpointing=False)
-                elif getattr(args, "gradient_checkpointing", False):
-                    # For backward compatibility with older versions of transformers
-                    if hasattr(model, "enable_input_require_grads"):
-                        model.enable_input_require_grads()
-                    else:
-
-                        def make_inputs_require_grad(module, input, output):
-                            output.requires_grad_(True)
-
-                        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
+                    if args is not None:
+                        args = dataclasses.replace(args, gradient_checkpointing=False)
 
                 model = get_peft_model(model, peft_config)
 

From a3d7d47a749831573c565817069cd052d775b0b8 Mon Sep 17 00:00:00 2001
From: younesbelkada
Date: Wed, 6 Dec 2023 16:39:19 +0000
Subject: [PATCH 2/2] oops

---
 trl/trainer/sft_trainer.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index d35d9177b8..b5d126912d 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -188,6 +188,16 @@ def __init__(
 
                     if args is not None:
                         args = dataclasses.replace(args, gradient_checkpointing=False)
+                elif getattr(args, "gradient_checkpointing", False):
+                    # For backward compatibility with older versions of transformers
+                    if hasattr(model, "enable_input_require_grads"):
+                        model.enable_input_require_grads()
+                    else:
+
+                        def make_inputs_require_grad(module, input, output):
+                            output.requires_grad_(True)
+
+                        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
 
                 model = get_peft_model(model, peft_config)
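
Note on the fix: the patches replace direct attribute access (args.gradient_checkpointing) with getattr(...) fallbacks and guard dataclasses.replace behind "if args is not None", so constructing SFTTrainer without training arguments no longer raises AttributeError/TypeError. Below is a minimal standalone sketch of that pattern; TrainingArgs and build_prepare_kwargs are made-up names for illustration only, not the real trl or transformers classes.

    import dataclasses
    from typing import Optional

    # Hypothetical stand-in for a training-arguments dataclass (not the real class).
    @dataclasses.dataclass
    class TrainingArgs:
        gradient_checkpointing: bool = True
        gradient_checkpointing_kwargs: Optional[dict] = None

    def build_prepare_kwargs(args):
        # Mirrors the patched logic: getattr tolerates args being None
        # instead of raising AttributeError on args.gradient_checkpointing.
        kwargs = {"use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)}
        kwargs["gradient_checkpointing_kwargs"] = getattr(args, "gradient_checkpointing_kwargs", None)
        # dataclasses.replace(None, ...) would fail, hence the explicit guard.
        if args is not None:
            args = dataclasses.replace(args, gradient_checkpointing=False)
        return kwargs, args

    print(build_prepare_kwargs(TrainingArgs()))  # gradient checkpointing picked up, then disabled on args
    print(build_prepare_kwargs(None))            # falls back to defaults, args stays None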