
[🐯+GRPO] Support FSDP + Fix bug when using LigerGRPO with DDP #3260


Merged
49 commits merged into huggingface:main on Apr 30, 2025

Conversation

@shivam15s (Contributor) commented Apr 8, 2025

What does this PR do?

This PR aims to do two things:

  1. The recent integration of LigerGRPO had a bug: when using DDP and performing a forward pass through a submodule of the unwrapped model, the necessary hooks weren't registered, so model weights across GPUs fell out of sync. To fix this, the PR introduces a Forward Redirection mechanism, a workaround that ensures hooks are registered correctly (compatible with both DDP and FSDP) while still enabling the custom forward pass that Liger requires (see the sketch after this list).
  2. Add support for FSDP to the GRPO Trainer. We leverage summon_full_params to make model.generate work with FSDP (also sketched below).
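
A minimal sketch of the Forward Redirection idea (illustrative only; the class and method names here are assumptions, not necessarily the code in this PR). To run a method of the unwrapped module while still firing the DDP/FSDP communication hooks, the unwrapped module's forward is temporarily swapped for the target method and the call is made through the wrapper:

import torch.nn as nn

class _ForwardRedirection:
    """Run `method` on `original` by calling through `wrapper`, so the
    DDP/FSDP pre- and post-forward hooks still fire."""

    def __call__(self, wrapper: nn.Module, original: nn.Module, method, *args, **kwargs):
        original_forward = original.forward

        def wrapped_forward(*fargs, **fkwargs):
            # Put the real forward back before running `method`, so any
            # nested `original(...)` calls behave normally.
            original.forward = original_forward
            return method(*fargs, **fkwargs)

        original.forward = wrapped_forward
        # Calling the wrapper triggers its DDP/FSDP hooks, then it
        # delegates to `original.forward`, i.e. our redirection above.
        return wrapper(*args, **kwargs)

And a sketch of how generation can work under FSDP (the helper name is hypothetical; summon_full_params is the real torch FSDP context manager):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def generate_under_fsdp(fsdp_model, inputs, **generation_kwargs):
    # FSDP keeps only a shard of each parameter on each rank, so
    # generate() cannot run on the sharded module directly.
    # summon_full_params temporarily all-gathers the full weights.
    with FSDP.summon_full_params(fsdp_model):
        return fsdp_model.module.generate(**inputs, **generation_kwargs)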

Experiment Script: https://gist.github.com/shivam15s/08a9bccd0d72dd0d29bdb912cb9885be

DDP: Liger (blue) vs. non-Liger (black)
[image]

FSDP: Liger (green) vs. non-Liger (purple)
[image]

Known Limitations with FSDP (can add support in subsequent PR(s))

  1. sync_ref_model is not currently supported
  2. create_reference_model is not currently supported

Benchmarking:
Distribution strategy: DDP
7 policy workers, 1 vLLM worker (8× H100)
[image]

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif kashif self-assigned this Apr 8, 2025
@kashif (Collaborator) commented Apr 11, 2025

testing using:

import torch
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
import torch.distributed as dist
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import TrainerCallback
import os
# from torch.distributed.fsdp import FSDPConfig, AutoWrapPolicy
# dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
dataset = load_dataset("trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness", split="train")
# only keep the prompt column
dataset = dataset.map(lambda x: {"prompt": x["prompt"]}, remove_columns=dataset.column_names)

training_args = GRPOConfig(
    output_dir="./scratch_dir",
    learning_rate=0.001,  # increase the learning rate to speed up the test
    per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
    num_generations=3,  # reduce the number of generations to reduce memory usage
    report_to=["tensorboard"],
    max_completion_length=256,  # reduce the completion length to reduce memory usage
    logging_steps=1,
    save_strategy="no",
    max_steps=50,
    use_liger_loss=True,
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)

class ProfCallback(TrainerCallback):
    def __init__(self, prof):
        self.prof = prof

    def on_step_end(self, args, state, control, **kwargs):
        self.prof.step()

# Create directory for profiling outputs
os.makedirs("profiling_results", exist_ok=True)

# Define profiling context manager
def train_with_profiling(enable_profiling=True):
    if enable_profiling:
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
            with_flops=True,
            on_trace_ready=torch.profiler.tensorboard_trace_handler("profiling_results") if trainer.accelerator.is_main_process else None,
            schedule=torch.profiler.schedule(
                wait=1,
                warmup=1,
                active=2,
                repeat=1),
        ) as prof:
            trainer.add_callback(ProfCallback(prof))
            trainer.train()
        # Print profiling results summary
        # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
    else:
        trainer.train()

# trainer.train()
train_with_profiling(enable_profiling=False)

# destroy process group
if dist.is_initialized():
    dist.destroy_process_group()

@kashif kashif marked this pull request as ready for review April 11, 2025 20:20

@kashif kashif mentioned this pull request Apr 29, 2025
@@ -407,7 +410,7 @@ def __init__(
         if self.beta == 0.0:
             # If beta is 0.0, the reference model is not needed
             self.ref_model = None
-        elif is_deepspeed_zero3_enabled():
+        elif is_deepspeed_zero3_enabled() or args.fsdp_config is not None:
Collaborator:

args.fsdp_config defaults to

{'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}

when FSDP is not enabled. Probably want args.fsdp instead?
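
A minimal illustration of the concern (a sketch assuming standard transformers TrainingArguments semantics):

# args.fsdp_config is populated with the defaults above even when FSDP is
# off, so `is not None` is always true. args.fsdp, by contrast, ends up as
# an empty list unless FSDP was requested, so its truthiness is a usable
# signal.
is_fsdp_enabled = bool(args.fsdp)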

Collaborator:

checking

@LeonEricsson (Collaborator) commented Apr 30, 2025

Ran with:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def main():
    # Load dataset
    train_dataset = load_dataset("trl-lib/tldr", split="train[:128]")

    def reward_len(completions, **kwargs):
        return [-abs(20 - len(completion)) for completion in completions]

    # Train model
    training_args = GRPOConfig(
        output_dir="./output",
        logging_steps=10,
        bf16=True,
        max_prompt_length=250,
        max_completion_length=250,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_generations=2,
        num_train_epochs=1,
        do_eval=True,
        optim="paged_adamw_8bit",
        max_steps=10,
        report_to="none",
    )

    print(training_args.fsdp_config)

    trainer = GRPOTrainer(
        args=training_args,
        model="Qwen/Qwen2.5-0.5B-Instruct",
        train_dataset=train_dataset,
        eval_dataset=train_dataset,
        reward_funcs=reward_len,
    )

    trainer.train()


if __name__ == "__main__":
    main()

@lewtun (Member) left a comment:

Thanks for the nice PR @shivam15s! LGTM with a small change on how we determine whether FSDP is enabled for the ref model.

kashif and others added 2 commits April 30, 2025 11:23
@kashif (Collaborator) commented Apr 30, 2025

@LeonEricsson I have integrated the commit from @jglaser here

@kashif kashif merged commit 09b669f into huggingface:main Apr 30, 2025
9 checks passed
@LeonEricsson LeonEricsson mentioned this pull request May 9, 2025
@thepowerfuldeez (Contributor):
Are there plans to support FSDP2?
