Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[megatron] feat: support qwen2 megatron backend #261

Merged
merged 8 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions .github/workflows/e2e_gsm8k_megatron.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,14 @@ jobs:
- name: Prepare gsm8k dataset
run: |
python3 examples/data_preprocess/gsm8k.py
- name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron
- name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron (Deepseek)
run: |
ray stop --force
[ ! -d "$HOME/Megatron-LM" ] && git clone -b core_v0.4.0_verl https://github.com/eric-haibin-lin/Megatron-LM $HOME/Megatron-LM
export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM
bash tests/e2e/run_deepseek_megatron.sh
bash tests/e2e/run_deepseek_megatron.sh
- name: Running gsm8k e2e training tests on 8 L20 GPUs with Megatron (Qwen)
run: |
ray stop --force
export PYTHONPATH=$PYTHONPATH:$HOME/Megatron-LM
bash tests/e2e/run_qwen_megatron.sh
42 changes: 42 additions & 0 deletions examples/ppo_trainer/run_qwen2-7b_math_gsm8k_megatron.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
set -x
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could u later also add a section to https://github.com/volcengine/verl/blob/main/docs/experiment/ppo.rst , add a new table for MATH, and include the logs, command, and test score in next PR


export VLLM_ATTENTION_BACKEND=XFORMERS

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet

train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.val_batch_size=6312 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.optim.lr=1e-5 \
critic.model.path=Qwen/Qwen2-7B-Instruct \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_megatron_math_gsm8k_examples' \
trainer.experiment_name='qwen2_7b_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=100 $@
41 changes: 41 additions & 0 deletions tests/e2e/run_qwen_megatron.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
set -x

# the config file used: verl/trainer/main_ppo/config/ppo_megatron_trainer.yaml

huggingface-cli download Qwen/Qwen2.5-0.5B

python3 -m verl.trainer.main_ppo --config-path=config \
--config-name='ppo_megatron_trainer.yaml'\
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=1024 \
data.val_batch_size=1312 \
data.max_prompt_length=512 \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B \
actor_rollout_ref.actor.optim.lr=2e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=2 \
critic.optim.lr=2e-5 \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size_per_gpu=4 \
critic.megatron.tensor_model_parallel_size=2 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.critic_warmup=0 \
trainer.logger=['console'] \
trainer.project_name='verl_megatron_gsm8k_examples' \
trainer.experiment_name='qwen2_5_0b5_function_rm' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=-1 \
trainer.test_freq=1 \
trainer.total_epochs=15 \
trainer.total_training_steps=3 $@
5 changes: 3 additions & 2 deletions verl/models/llama/megatron/modeling_llama_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,15 +513,16 @@ def forward(self,

class ParallelLlamaForCausalLMRmPadPP(nn.Module):

def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process):
def __init__(self, config: LlamaConfig, megatron_config: ModelParallelConfig, pre_process, post_process,
share_embeddings_and_output_weights):
super().__init__()
self.config = config
self.megatron_config = megatron_config
self.model = ParallelLlamaModelRmPadPP(config,
megatron_config=megatron_config,
pre_process=pre_process,
post_process=post_process)
self.share_embeddings_and_output_weights = None # workaround, megatron requires this attr
self.share_embeddings_and_output_weights = share_embeddings_and_output_weights
self.vocab_size = config.vocab_size
self.pre_process = pre_process
self.post_process = post_process
Expand Down
13 changes: 13 additions & 0 deletions verl/models/qwen2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
24 changes: 24 additions & 0 deletions verl/models/qwen2/megatron/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling_qwen2_megatron import (
# original model with megatron
ParallelQwen2Model,
ParallelQwen2ForCausalLM,
# rmpad with megatron
ParallelQwen2ForCausalLMRmPad,
ParallelQwen2ForValueRmPad,
# rmpad with megatron and pipeline parallelism
ParallelQwen2ForCausalLMRmPadPP,
ParallelQwen2ForValueRmPadPP)
13 changes: 13 additions & 0 deletions verl/models/qwen2/megatron/checkpoint_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading
Loading