From e53dcdb9a379a76a769f8aca0d3bebd1f0c51fe2 Mon Sep 17 00:00:00 2001
From: BearBiscuit <55008898+BearBiscuit05@users.noreply.github.com>
Date: Mon, 24 Feb 2025 12:25:18 +0800
Subject: [PATCH] [fix] Improve the params template for generation (#351)

fix the issue[#331](https://github.com/volcengine/verl/issues/331)
---
 .github/workflows/vllm.yml          |  5 +++++
 tests/generation/run_gen_qwen05.sh  | 30 ++++++++++++++++++++++++++
 verl/trainer/config/generation.yaml | 33 ++++++++++++++++++++++++++++-
 verl/trainer/main_generation.py     |  2 +-
 4 files changed, 68 insertions(+), 2 deletions(-)
 create mode 100755 tests/generation/run_gen_qwen05.sh

diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml
index 1f40adca..6bf304b5 100644
--- a/.github/workflows/vllm.yml
+++ b/.github/workflows/vllm.yml
@@ -51,3 +51,8 @@ jobs:
           pip3 install --upgrade vllm
           cd tests/rollout
           torchrun --standalone --nnodes=1 --nproc_per_node=4 $(which pytest) -s test_vllm_spmd.py
+      - name: Run QWen 0.5B generation test
+        run: |
+          cd tests/generation
+          bash ./run_gen_qwen05.sh 4 $HOME/data/gen/qwen_05_gen_test.parquet
+          rm -rf $HOME/data/gen/qwen_05_gen_test.parquet
diff --git a/tests/generation/run_gen_qwen05.sh b/tests/generation/run_gen_qwen05.sh
new file mode 100755
index 00000000..d660559e
--- /dev/null
+++ b/tests/generation/run_gen_qwen05.sh
@@ -0,0 +1,30 @@
+# Tested with 1 & 4 GPUs
+set -x
+
+if [ "$#" -lt 2 ]; then
+    echo "Usage: run_gen_qwen05.sh <nproc_per_node> <save_path> [other_configs...]"
+    exit 1
+fi
+
+nproc_per_node=$1
+save_path=$2
+
+# Shift the arguments so $@ refers to the rest
+shift 2
+
+python3 -m verl.trainer.main_generation \
+    trainer.nnodes=1 \
+    trainer.n_gpus_per_node=$nproc_per_node \
+    data.path=$HOME/data/gsm8k/test.parquet \
+    data.prompt_key=prompt \
+    data.n_samples=1 \
+    data.output_path=$save_path \
+    model.path=Qwen/Qwen2.5-0.5B-Instruct \
+    +model.trust_remote_code=True \
+    rollout.temperature=1.0 \
+    rollout.top_k=50 \
+    rollout.top_p=0.7 \
+    rollout.prompt_length=2048 \
+    rollout.response_length=1024 \
+    rollout.tensor_model_parallel_size=2 \
+    rollout.gpu_memory_utilization=0.8
diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml
index 27d92116..14cd2e5d 100644
--- a/verl/trainer/config/generation.yaml
+++ b/verl/trainer/config/generation.yaml
@@ -32,4 +32,35 @@ rollout:
   log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
   log_prob_micro_batch_size_per_gpu: 8
   # for hf rollout
-  do_sample: True
\ No newline at end of file
+  do_sample: True
+  disable_log_stats: True
+  enable_chunked_prefill: True
+  n: 1
+actor:
+  strategy: fsdp # This is for backward-compatibility
+  ppo_mini_batch_size: 256
+  ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
+  ppo_micro_batch_size_per_gpu: null
+  use_dynamic_bsz: False
+  ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length}
+  grad_clip: 1.0
+  clip_ratio: 0.2
+  entropy_coeff: 0.001
+  use_kl_loss: False # True for GRPO
+  kl_loss_coef: 0.001 # for grpo
+  kl_loss_type: low_var_kl # for grpo
+  ppo_epochs: 1
+  shuffle: False
+  ulysses_sequence_parallel_size: 1 # sp size
+  optim:
+    lr: 1e-6
+    lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
+    min_lr_ratio: null # only useful for warmup with cosine
+    warmup_style: constant # select from constant/cosine
+    total_training_steps: -1 # must be override by program
+  fsdp_config:
+    wrap_policy:
+      min_num_params: 0
+    param_offload: False
+    optimizer_offload: False
+    fsdp_size: -1
\ No newline at end of file
diff --git a/verl/trainer/main_generation.py b/verl/trainer/main_generation.py
index 8c3bd923..044c6e4d 100644
--- a/verl/trainer/main_generation.py
+++ b/verl/trainer/main_generation.py
@@ -59,7 +59,7 @@ def main(config):
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
-    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='rollout')
+    ray_cls_with_init = RayClassWithInitArgs(cls=ray.remote(ActorRolloutRefWorker), config=config, role='actor_rollout')
     resource_pool = RayResourcePool(process_on_nodes=[config.trainer.n_gpus_per_node] * config.trainer.nnodes)
     wg = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=ray_cls_with_init)
     wg.init_model()