Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SFTTrainer] Fix Trainer when args is None #1064

Merged
merged 3 commits into from
Dec 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions trl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,19 @@ def __init__(
inspect.signature(prepare_model_for_kbit_training).parameters
)

preprare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}
preprare_model_kwargs = {
"use_gradient_checkpointing": getattr(args, "gradient_checkpointing", False)
}

if _support_gc_kwargs:
preprare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs
preprare_model_kwargs["gradient_checkpointing_kwargs"] = getattr(
args, "gradient_checkpointing_kwargs", None
)

model = prepare_model_for_kbit_training(model, **preprare_model_kwargs)

args = dataclasses.replace(args, gradient_checkpointing=False)
if args is not None:
args = dataclasses.replace(args, gradient_checkpointing=False)
elif getattr(args, "gradient_checkpointing", False):
# For backward compatibility with older versions of transformers
if hasattr(model, "enable_input_require_grads"):
Expand Down
Loading