[🐯+GRPO] Support FSDP + Fix bug when using LigerGRPO with DDP #3260
Testing using:

```python
import torch
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer
import torch.distributed as dist
from torch.profiler import profile, record_function, ProfilerActivity
from transformers import TrainerCallback
import os
# from torch.distributed.fsdp import FSDPConfig, AutoWrapPolicy

# dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
dataset = load_dataset("trl-lib/ultrafeedback-gpt-3.5-turbo-helpfulness", split="train")
# only keep the prompt column
dataset = dataset.map(lambda x: {"prompt": x["prompt"]}, remove_columns=dataset.column_names)

training_args = GRPOConfig(
    output_dir="./scratch_dir",
    learning_rate=0.001,  # increase the learning rate to speed up the test
    per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
    num_generations=3,  # reduce the number of generations to reduce memory usage
    report_to=["tensorboard"],
    max_completion_length=256,  # reduce the completion length to reduce memory usage
    logging_steps=1,
    save_strategy="no",
    max_steps=50,
    use_liger_loss=True,
)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)


class ProfCallback(TrainerCallback):
    def __init__(self, prof):
        self.prof = prof

    def on_step_end(self, args, state, control, **kwargs):
        self.prof.step()


# Create directory for profiling outputs
os.makedirs("profiling_results", exist_ok=True)


# Define profiling context manager
def train_with_profiling(enable_profiling=True):
    if enable_profiling:
        with profile(
            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
            record_shapes=True,
            profile_memory=True,
            with_stack=True,
            with_flops=True,
            on_trace_ready=torch.profiler.tensorboard_trace_handler("profiling_results")
            if trainer.accelerator.is_main_process
            else None,
            schedule=torch.profiler.schedule(wait=1, warmup=1, active=2, repeat=1),
        ) as prof:
            trainer.add_callback(ProfCallback(prof))
            trainer.train()
            # Print profiling results summary
            # print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
    else:
        trainer.train()


# trainer.train()
train_with_profiling(enable_profiling=False)

# destroy process group
if dist.is_initialized():
    dist.destroy_process_group()
```
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
@@ -407,7 +410,7 @@ def __init__(
         if self.beta == 0.0:
             # If beta is 0.0, the reference model is not needed
             self.ref_model = None
-        elif is_deepspeed_zero3_enabled():
+        elif is_deepspeed_zero3_enabled() or args.fsdp_config is not None:
```
`args.fsdp_config` defaults to `{'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}` when FSDP is not enabled. Probably want `args.fsdp` instead?
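A minimal sketch of the suggested check, assuming the standard `TrainingArguments.fsdp` field (which stays empty unless FSDP options are passed); this illustrates the review suggestion and is not the merged code:

```python
from transformers import TrainingArguments


def fsdp_enabled(args: TrainingArguments) -> bool:
    """Hypothetical helper: True only when FSDP was actually requested.

    `args.fsdp_config` is filled with defaults even without FSDP, so the check
    uses `args.fsdp`, which stays empty unless FSDP options were passed.
    """
    return bool(args.fsdp)


# Plain (DDP-style) args report FSDP as disabled; FSDP args report it as enabled.
ddp_args = TrainingArguments(output_dir="./out")
fsdp_args = TrainingArguments(output_dir="./out", fsdp="full_shard auto_wrap")
print(fsdp_enabled(ddp_args), fsdp_enabled(fsdp_args))  # False True
```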
checking
Ran with:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def main():
    # Load dataset
    train_dataset = load_dataset("trl-lib/tldr", split="train[:128]")

    def reward_len(completions, **kwargs):
        return [-abs(20 - len(completion)) for completion in completions]

    # Train model
    training_args = GRPOConfig(
        output_dir="./output",
        logging_steps=10,
        bf16=True,
        max_prompt_length=250,
        max_completion_length=250,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        num_generations=2,
        num_train_epochs=1,
        do_eval=True,
        optim="paged_adamw_8bit",
        max_steps=10,
        report_to="none",
    )
    print(training_args.fsdp_config)

    trainer = GRPOTrainer(
        args=training_args,
        model="Qwen/Qwen2.5-0.5B-Instruct",
        train_dataset=train_dataset,
        eval_dataset=train_dataset,
        reward_funcs=reward_len,
    )
    trainer.train()


if __name__ == "__main__":
    main()
```
Thanks for the nice PR @shivam15s! LGTM with a small change on how we determine if FSDP is enabled for the ref model.
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
@LeonEricsson I have integrated the commit from @jglaser here.
Are there plans to support FSDP2?
What does this PR do?
This PR aims to do two things:
1. Add FSDP support for the Liger GRPO loss.
2. Fix a bug when using LigerGRPO with DDP.
Experiment Script: https://gist.github.com/shivam15s/08a9bccd0d72dd0d29bdb912cb9885be
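For reference, a minimal sketch of the kind of setup this PR enables (hypothetical, not the exact experiment script): the Liger GRPO loss (`use_liger_loss=True`) combined with FSDP sharding of the policy via the standard `TrainingArguments` FSDP options, launched with `accelerate launch` or `torchrun`.

```python
# Hypothetical sketch, not the exact benchmark script: Liger GRPO loss + FSDP.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-lib/tldr", split="train[:128]")


def reward_len(completions, **kwargs):
    # Toy reward: prefer completions close to 20 characters.
    return [-abs(20 - len(completion)) for completion in completions]


training_args = GRPOConfig(
    output_dir="./grpo_liger_fsdp",
    use_liger_loss=True,          # chunked Liger GRPO loss (this PR)
    fsdp="full_shard auto_wrap",  # standard TrainingArguments FSDP options
    per_device_train_batch_size=2,
    num_generations=2,
    max_completion_length=128,
    report_to="none",
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()  # e.g. `accelerate launch --num_processes 8 this_script.py`
```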
DDP: Liger (blue) vs non-Liger (black)

FSDP: Liger (green) vs non-Liger (purple)

Known Limitations with FSDP (can add support in subsequent PR(s))
Benchmarking:

Dist Strategy: DDP
7 policy workers, 1 vllm worker (8 h100)
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.