
Gradient Checkpointing Causes Issues in Reward Modeling #694

Closed
BaleChen opened this issue Aug 25, 2023 · 3 comments

@BaleChen

Hi, I'm reporting an issue about gradient checkpointing and my current workaround.

Context: reward modeling with LlamaForSequenceClassification (7B) and PEFT LoRA.
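The snippet below assumes a model and tokenizer are already loaded; roughly like this (the exact checkpoint and loading flags are not shown in this report, so this is only a sketch):

```python
# Sketch of the assumed setup (checkpoint name and loading flags are
# illustrative, not from the original report). Loading in 8-bit is what
# makes RewardTrainer call prepare_model_for_int8_training internally.
from transformers import AutoTokenizer, LlamaForSequenceClassification

model_name = "meta-llama/Llama-2-7b-hf"  # assumption: any 7B Llama checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama has no pad token by default

model = LlamaForSequenceClassification.from_pretrained(
    model_name,
    num_labels=1,       # reward model outputs a single scalar score
    load_in_8bit=True,  # triggers the int8 preparation path
)
model.config.pad_token_id = tokenizer.pad_token_id
```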

Issue
If I run it with the following code:

training_args = TrainingArguments(
    output_dir=script_args.output_dir,

    per_device_train_batch_size=script_args.per_device_batch_size,
    per_device_eval_batch_size=script_args.per_device_batch_size,
    gradient_accumulation_steps=script_args.gradient_accumulation_steps,

    optim="adamw_torch",
    learning_rate=script_args.learning_rate,
    lr_scheduler_type=script_args.lr_scheduler_type,
    warmup_steps=script_args.num_warmup_steps,

    fp16=not script_args.no_fp16,
    bf16=script_args.bf16,

    remove_unused_columns=False,
    evaluation_strategy="steps" if script_args.eval_split != "none" else "no",
    save_strategy="steps" if not script_args.debug else "no",
    max_steps=script_args.max_steps if not script_args.debug else 5,
    eval_steps=script_args.eval_freq if not script_args.debug else 5,
    logging_steps=script_args.logging_steps if not script_args.debug else 1,
    save_steps=script_args.save_freq,
    weight_decay=script_args.weight_decay,

    run_name=script_args.exp_name,
    report_to="wandb" if script_args.log_with == "wandb" else None,
)

# Step 4: Define the LoraConfig
if script_args.use_peft:
    peft_config = LoraConfig(r=script_args.lora_rank, lora_alpha=32, bias="none", task_type="SEQ_CLS", modules_to_save=["scores"])
else:
    peft_config = None

# Step 5: Define the Trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=eval_data,
    peft_config=peft_config,
    max_length=script_args.seq_length,
)
trainer.train()

it fails with the following error:
RuntimeError: Expected to mark a variable ready only once. This error is caused by one of the following reasons: 1) Use of a module parameter outside the `forward` function. Please make sure model parameters are not shared across multiple concurrent forward-backward passes. or try to use _set_static_graph() as a workaround if this module graph does not change during training loop. 2) Reused parameters in multiple reentrant backward passes. For example, if you use multiple `checkpoint` functions to wrap the same part of your model, it would result in the same set of parameters been used by different reentrant backward passes multiple times, and hence marking a variable ready multiple times. DDP does not support such use cases in default. You can try to use _set_static_graph() as a workaround if your module graph does not change over iterations. Parameter at index 127 has been marked as ready twice.

My workaround:

I looked through some issue reports and found that the error is related to gradient checkpointing. In trainer/reward_trainer.py, line 108, `prepare_model_for_int8_training` enables gradient checkpointing automatically (see peft/utils/other.py, lines 58-70).

So I switched it off by adding `use_gradient_checkpointing=False` to the `prepare_model_for_int8_training` call, as sketched below.
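Concretely, the change inside trainer/reward_trainer.py looks like this (a sketch of my edit, not the exact diff):

```python
# Sketch of the workaround applied inside trl/trainer/reward_trainer.py.
# peft's prepare_model_for_int8_training enables gradient checkpointing
# by default; turning it off avoids the DDP "marked as ready twice" error.
from peft import prepare_model_for_int8_training

model = prepare_model_for_int8_training(
    model,
    use_gradient_checkpointing=False,  # default is True
)
```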

Can you let me know if there's any error in my implementation? Otherwise, I think we probably need to add a gradient checkpointing kwarg to the reward trainer to avoid the same issue (see the sketch below).
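One possible shape for that kwarg, so the trainer simply follows the `TrainingArguments` setting (a sketch, where `args` is the trainer's `TrainingArguments`; this mirrors the fix in the commits referenced below):

```python
# Sketch: follow the user's TrainingArguments choice instead of
# unconditionally enabling gradient checkpointing.
model = prepare_model_for_int8_training(
    model,
    use_gradient_checkpointing=args.gradient_checkpointing,
)
```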

@mnoukhov
Contributor

This seems like a duplicate of #480

@BaleChen
Author

Oh yes, sorry for the duplicate, and thanks for the reference.

mnoukhov added a commit to mnoukhov/trl that referenced this issue Aug 30, 2023

update to `prepare_model_for_kbit_training` from the deprecated
`prepare_model_for_int8_training`, and add
`use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice

is also the workaround for huggingface#694
@lvwerra
Member

lvwerra commented Aug 31, 2023

Closing this in favour of the other issue.

lvwerra closed this as completed Aug 31, 2023
mnoukhov added a commit to mnoukhov/trl that referenced this issue Sep 2, 2023, with the same commit message as above.
younesbelkada pushed a commit that referenced this issue Sep 12, 2023

* update to `prepare_model_for_kbit_training` from the deprecated
`prepare_model_for_int8_training`, and add
`use_gradient_checkpointing=args.gradient_checkpointing` to
automatically follow the gradient checkpointing choice; this is also
the workaround for #694

* workaround for gradient checkpointing issue

Calling model.gradient_checkpointing_enable() twice causes issues.
This workaround calls it in prepare_model_for_kbit_training and then
sets the arg to False to make sure it isn't called again in the
Hugging Face Trainer inner loop.

Also changes the stack_llama_2 SFT trainer to use the correct device
map for DDP training so that you can test this issue.
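As a sketch, the guard described in that commit message looks roughly like this (names approximate, not the exact TRL diff):

```python
# Approximate sketch of the merged workaround, not the exact TRL diff.
from peft import prepare_model_for_kbit_training

if getattr(model, "is_loaded_in_8bit", False):
    model = prepare_model_for_kbit_training(
        model,
        use_gradient_checkpointing=args.gradient_checkpointing,
    )
    # gradient_checkpointing_enable() has now been called once (inside
    # prepare_model_for_kbit_training); clear the flag so the Hugging
    # Face Trainer doesn't call it a second time in its inner loop.
    args.gradient_checkpointing = False
```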
kushal-tri pushed a commit to kushalarora/trl that referenced this issue Sep 19, 2023, with the same commit message as the merged commit above.
lapp0 pushed a commit to lapp0/trl that referenced this issue May 10, 2024, with the same commit message as the merged commit above.