[SFT] Support context parallelism for SFT #132

Merged: 68 commits, Jan 27, 2025

Commits (68)
88878e4
support skip template apply; directly read prompt as str for sft dataset
xingyaoww Jan 15, 2025
d5f0fad
add initial lora sft support
Jiayi-Pan Jan 15, 2025
66ff737
add ring_attn utils
Jiayi-Pan Jan 15, 2025
dabfdab
fix config serialization
Jiayi-Pan Jan 15, 2025
902ddbe
improve hyper-params
Jiayi-Pan Jan 15, 2025
6a43cb8
Merge commit '902ddbe6c26623f9e9e511e55abe9f8676707ff2' into dev
xingyaoww Jan 15, 2025
b478b04
minor ux tweaks
Jiayi-Pan Jan 20, 2025
c7aaed6
support label smoothing
xingyaoww Jan 22, 2025
ed7377c
init
Jiayi-Pan Jan 22, 2025
5b4a800
add sft example
Jiayi-Pan Jan 22, 2025
d123e3b
enable prompt+response filter
xingyaoww Jan 23, 2025
a486ed4
Merge commit '5b4a80000effa8b19b6e2b1e8ea8a63e6fa0274b' into dev
xingyaoww Jan 23, 2025
94ef011
add label smoothing to sp
xingyaoww Jan 23, 2025
843a4e1
add wip
xingyaoww Jan 23, 2025
596494e
Merge commit '884a7273123c3996c0970fce90b2ace55e251203' into dev
xingyaoww Jan 23, 2025
775b546
fix backward shape mismatch
xingyaoww Jan 23, 2025
37e672a
get seq parallel working!
xingyaoww Jan 23, 2025
e2cc036
clean up debug print and fix debug using the latest signature
xingyaoww Jan 23, 2025
6194d1e
rm filter prompt and response now we supported 128k SFT
xingyaoww Jan 23, 2025
eb59cc1
remove extra debug
xingyaoww Jan 23, 2025
ef38f9a
add tqdm to show training progress
xingyaoww Jan 23, 2025
17d521f
Merge upstream/main into feature/lora
openhands-agent Jan 24, 2025
4b75406
chore: remove unused ring_attn_utils.py and add peft dependency
openhands-agent Jan 24, 2025
537eb7e
fix: address PR comments - add default LoRA config, fix script name, …
openhands-agent Jan 24, 2025
8866aba
refactor: use is_lora parameter instead of checking peft_config directly
openhands-agent Jan 24, 2025
77baf5c
refactor: simplify fsdp wrap policy for LoRA using is_lora flag
openhands-agent Jan 24, 2025
31f44ad
revert: remove unnecessary LoRA changes from fsdp_workers.py
openhands-agent Jan 24, 2025
decfe61
merge: upstream/main into feature/lora
openhands-agent Jan 24, 2025
e4e7af3
Merge commit 'decfe61082a9c96116cad4aef9534aea4a8e3b00' into dev
xingyaoww Jan 24, 2025
92cd230
ci: add workflow for LoRA testing
openhands-agent Jan 24, 2025
4772756
Add peft to pyproject.toml dependencies
openhands-agent Jan 25, 2025
3f27c0d
Run formatter
openhands-agent Jan 25, 2025
0130411
fix peft script
xingyaoww Jan 25, 2025
e272ce6
Merge commit '22e93114e3122c622399554a0605a4926a77f27e' into dev
xingyaoww Jan 25, 2025
127e9ef
Merge commit '01304116170ea4e42cd0996e5eae647b81923b94' into dev
xingyaoww Jan 25, 2025
d40cd66
revert some unnecessary changes
xingyaoww Jan 25, 2025
bbfa712
revert uncessary changes
xingyaoww Jan 25, 2025
41590a7
tweak debug into assert
xingyaoww Jan 25, 2025
f9f1c37
add a bunch of timers
xingyaoww Jan 25, 2025
b3ae176
address comments
xingyaoww Jan 25, 2025
596fcb7
remove total epochs
xingyaoww Jan 25, 2025
84dfcf4
Revert "add a bunch of timers"
xingyaoww Jan 25, 2025
9bd522d
refactor: move debug functionality to tests/sft/test_trainer.py
openhands-agent Jan 25, 2025
04e6d05
Merge commit '596fcb7e42bd53309e482f717095983d4af2866b' into xw/cp-sft
xingyaoww Jan 25, 2025
d2e76d0
feat: add test script for sequence parallelism with size 2
openhands-agent Jan 25, 2025
dd293e4
chore: remove outdated comment from run_qwen_05_sp2.sh
openhands-agent Jan 25, 2025
5086eba
chore: remove unnecessary debug flag from run_qwen_05_sp2.sh
openhands-agent Jan 25, 2025
ba27b63
chore: remove sft.sh
openhands-agent Jan 25, 2025
24ad93f
Merge branch 'main' into xw/cp-sft
xingyaoww Jan 25, 2025
93e6f2d
remove unused
xingyaoww Jan 25, 2025
d5d8da6
Add license header to test_trainer.py
openhands-agent Jan 25, 2025
6934a70
Fix code formatting
openhands-agent Jan 25, 2025
68cfd0a
update scripts and default config
xingyaoww Jan 25, 2025
720ee5c
cleanup debug print
xingyaoww Jan 25, 2025
e0919e6
fix microbsz
xingyaoww Jan 25, 2025
f8c1de0
remove label smoothing
xingyaoww Jan 25, 2025
5a02740
add sequence parallism to CI
xingyaoww Jan 25, 2025
4adb707
add ci to check the loss diff between sp and default implementation
xingyaoww Jan 25, 2025
4e38c46
remove uncessary rope scaling
xingyaoww Jan 25, 2025
b08abcd
Address review comments:
openhands-agent Jan 26, 2025
79d4a3f
Merge loss computation implementations to reduce code duplication
openhands-agent Jan 26, 2025
818e5aa
Apply code formatting
openhands-agent Jan 26, 2025
d665b9c
Add missing nullcontext import
openhands-agent Jan 26, 2025
e05d273
cleanup
xingyaoww Jan 26, 2025
b38db7b
adjust order for cleaner git diff
xingyaoww Jan 26, 2025
d7d2683
make sure the test works
xingyaoww Jan 26, 2025
c5e0f1b
format script
xingyaoww Jan 27, 2025
6f90fb7
Merge branch 'main' into xw/cp-sft
xingyaoww Jan 27, 2025
10 changes: 9 additions & 1 deletion .github/workflows/e2e_sft.yml
@@ -43,4 +43,12 @@ jobs:
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with rmpad using function rm
        run: |
          ray stop --force
          bash tests/sft/run_sft.sh
          bash tests/sft/run_sft.sh
      - name: Running gsm8k e2e training tests on 8 L20 GPUs with sequence parallism
        run: |
          ray stop --force
          bash examples/sft/gsm8k/run_qwen_05_sp2.sh 8 $HOME/ckpts/
      - name: Check loss difference between sequence parallel vs. default implementation
        run: |
          ray stop --force
          bash tests/sft/run_sft_sp_loss_match.sh
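
The two new steps above exercise the sequence-parallel path end to end in CI. For reference, a sketch of reproducing them locally, assuming 8 GPUs, Ray installed, and the gsm8k parquet files prepared under $HOME/data as in the scripts below:

# Sketch: run the new CI checks by hand (GPU count and checkpoint path are assumptions).
ray stop --force
bash examples/sft/gsm8k/run_qwen_05_sp2.sh 8 $HOME/ckpts/   # SFT with ulysses_sequence_parallel_size=2

ray stop --force
bash tests/sft/run_sft_sp_loss_match.sh                     # compare SP+rmpad loss against the default forward pass
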
32 changes: 32 additions & 0 deletions examples/sft/gsm8k/run_qwen_05_sp2.sh
@@ -0,0 +1,32 @@
set -x

if [ "$#" -lt 2 ]; then
echo "Usage: run_qwen_05_sp2.sh <nproc_per_node> <save_path> [other_configs...]"
exit 1
fi

nproc_per_node=$1
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2

torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
-m verl.trainer.fsdp_sft_trainer \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.prompt_key=extra_info \
data.response_key=extra_info \
optim.lr=1e-4 \
+data.prompt_dict_keys=['question'] \
+data.response_dict_keys=['answer'] \
data.micro_batch_size=4 \
model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
trainer.default_local_dir=$save_path \
trainer.project_name=gsm8k-sft \
trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2 \
trainer.logger=['console'] \
trainer.total_training_steps=1 \
trainer.default_hdfs_dir=null $@ \
ulysses_sequence_parallel_size=2 \
use_remove_padding=true
24 changes: 24 additions & 0 deletions tests/sft/run_sft_sp_loss_match.sh
@@ -0,0 +1,24 @@
# Tested with 2 & 4 GPUs

set -x

torchrun --standalone --nnodes=1 --nproc_per_node=8 \
    tests/sft/test_sp_loss_match.py \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.prompt_key=extra_info \
    data.response_key=extra_info \
    +data.prompt_dict_keys=['question'] \
    +data.response_dict_keys=['answer'] \
    data.micro_batch_size=32 \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    ulysses_sequence_parallel_size=2 \
    use_remove_padding=True \
    trainer.default_local_dir=$HOME/ckpts/ \
    trainer.project_name=qwen2.5-sft \
    trainer.experiment_name=gsm8k-sft-gemma-2b-it \
    trainer.total_training_steps=1 \
    trainer.logger=['console'] \
    trainer.default_hdfs_dir=null $@

rm -rf $HOME/ckpts/
128 changes: 128 additions & 0 deletions tests/sft/test_sp_loss_match.py
@@ -0,0 +1,128 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.distributed
from tensordict import TensorDict
from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer
from torch.distributed.device_mesh import init_device_mesh
from verl.utils.distributed import initialize_global_process_group


def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4):
    """Test consistency between original forward pass and SP+rmpad forward passes.

    Args:
        trainer: The FSDPSFTTrainer instance to test
        total_steps: Number of steps to test (default: 4)
    """
    if trainer.device_mesh.get_rank() == 0:
        print("\nStarting debug comparison between original and SP+rmpad forward passes...")
        print(f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}")
        print(f"Remove padding: {trainer.use_remove_padding}\n")

    steps_remaining = total_steps

    for epoch in range(1):  # Just one epoch for testing
        trainer.train_sampler.set_epoch(epoch=epoch)
        for data in trainer.train_dataloader:
            data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda()
            trainer.fsdp_model.train()
            micro_batches = data.split(trainer.config.data.micro_batch_size)

            for idx, micro_batch in enumerate(micro_batches):
                if trainer.device_mesh.get_rank() == 0:
                    print(f"\nProcessing micro batch {idx + 1}/{len(micro_batches)}")

                # Compute losses using both methods
                # Disable SP and rmpad
                trainer.use_remove_padding = False
                old_sp = trainer.config.ulysses_sequence_parallel_size
                trainer.config.ulysses_sequence_parallel_size = 1
                loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)

                # Do SP and rmpad
                trainer.config.ulysses_sequence_parallel_size = old_sp
                trainer.use_remove_padding = True
                loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False)

                # Collect losses across all ranks
                loss_ref_all = loss_ref.clone()
                loss_sp_all = loss_sp.clone()
                torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG)
                torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG)

                # Calculate relative difference of averaged losses
                rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8)

                if trainer.device_mesh.get_rank() == 0:
                    print("\nComparison Results (Averaged across ranks):")
                    print(f"Reference Loss: {loss_ref_all.item():.6f}")
                    print(f"SP+rmpad Loss: {loss_sp_all.item():.6f}")
                    print(f"Relative Difference: {rel_diff.item():.6f}")

                    assert rel_diff.item() < 1e-2, "Significant difference detected between averaged losses!"
                    print("Loss difference is within the acceptable range.")

                steps_remaining -= 1
                if steps_remaining == 0:
                    break
            if steps_remaining == 0:
                break
        break

    if trainer.device_mesh.get_rank() == 0:
        print("\nDebug comparison completed successfully.")


def create_trainer(config):
    """Create and initialize a trainer instance with the given config.

    Args:
        config: Configuration object with training parameters

    Returns:
        FSDPSFTTrainer: Initialized trainer instance
    """
    local_rank, rank, world_size = initialize_global_process_group()

    device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',))

    dp_size = world_size // config.ulysses_sequence_parallel_size
    ulysses_device_mesh = init_device_mesh(device_type='cuda',
                                           mesh_shape=(dp_size, config.ulysses_sequence_parallel_size),
                                           mesh_dim_names=('dp', 'sp'))

    return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh)


def main(config):
    """Main function to run trainer tests.

    Args:
        config: Configuration object with training parameters
    """
    trainer = create_trainer(config)
    test_trainer_forward_consistency(trainer)


if __name__ == '__main__':
    import hydra
    from omegaconf import DictConfig

    @hydra.main(config_path="../../verl/trainer/config", config_name="sft_trainer")
    def hydra_entry(cfg: DictConfig) -> None:
        main(cfg)

    hydra_entry()
5 changes: 3 additions & 2 deletions verl/trainer/config/sft_trainer.yaml
@@ -22,14 +22,15 @@ model:
  trust_remote_code: False
  lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32)
  lora_alpha: 16 # LoRA scaling factor
  target_modules: [q_proj, v_proj] # Target modules for LoRA adaptation
  target_modules: all-linear # Target modules for LoRA adaptation
optim:
  lr: 1e-5
  betas: [0.9, 0.95]
  weight_decay: 0.01
  warmup_steps_ratio: 0.1
  clip_grad: 1.0

ulysses_sequence_parallel_size: 1
use_remove_padding: False
trainer:
  default_local_dir: /tmp/sft_model
  default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here
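
Both new config keys default to the non-parallel code path. Below is a minimal command-line sketch of enabling them through Hydra overrides, modeled on the gsm8k example script earlier in this PR; the data paths, model, and GPU count are assumptions carried over from that script, and nproc_per_node must be divisible by ulysses_sequence_parallel_size since create_trainer above splits the world size into a (dp, sp) device mesh.

# Sketch only: enable sequence parallelism and remove-padding via Hydra overrides.
# Paths and model choice mirror the example script and are assumptions here.
torchrun --standalone --nnodes=1 --nproc_per_node=8 \
    -m verl.trainer.fsdp_sft_trainer \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.prompt_key=extra_info \
    data.response_key=extra_info \
    +data.prompt_dict_keys=['question'] \
    +data.response_dict_keys=['answer'] \
    model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \
    trainer.default_local_dir=$HOME/ckpts/ \
    trainer.logger=['console'] \
    trainer.default_hdfs_dir=null \
    ulysses_sequence_parallel_size=2 \
    use_remove_padding=True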