Hangs during vllm rollout, no error message #12

Open · Vamix opened this issue Nov 12, 2024 · 3 comments

Vamix commented Nov 12, 2024

Hi veRL team, thanks for open-sourcing this great framework. I have successfully run PPO training of Qwen2-7B on 2 nodes, so I believe there is no problem with my environment. But I ran into an issue when trying to run PPO training of the Qwen2.5-32B model on 8 nodes.

The config is https://github.com/volcengine/verl/blob/main/examples/ppo_trainer/run_qwen2.5-32b.sh. First, I found that the default setting triggered OOM, so I changed trainer.nnodes to 8. Then, when running the PPO training on 8 nodes, it gets stuck during vllm rollout. Even with vllm debug flags turned on, there is no error message; the last output I can see is

INFO 11-12 22:04:23 metrics.py:406] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 103.0 tokens/s, Running: 2 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 11.0%, CPU KV cache usage: 0.0%.

Have you ever run into this hang issue? I hope you can share some suggestions for debugging. Thanks a lot!

PeterSH6 (Collaborator) commented

Hi @Vamix, glad to hear that you successfully ran Qwen2-7B on 2 nodes.

For the hanging issue, we haven't met this before.
Could you check which line of code is stuck by launching the qwen2.5-32b script with the CUDA_LAUNCH_BLOCKING=1 environment variable set and using py-spy dump --pid [one-of-ray-process] to inspect the inner vllm process?

From the vllm debug message, I guess it's stuck at the prefill stage (as Running: 2 reqs).
Another way to debug is to swap the qwen2.5-32b model for another model of similar size.

Let's see if these debugging methods could help.

Vamix commented Nov 13, 2024

Hi @PeterSH6, thanks for your reply. I found two problems in the code that may lead to the hanging issue:

(1) When config.rollout.log_prob_micro_batch_size is smaller than world_size.
It is quite obscure that config.rollout.log_prob_micro_batch_size is divided by device_mesh.shape[0] (code), and the result is then used to split a tensor dict (code). With integer division, if config.rollout.log_prob_micro_batch_size is smaller than world_size it becomes 0, and batch.split(micro_batch_size) hangs when micro_batch_size = 0 (there is a while loop in batch.split that never terminates when micro_batch_size = 0; please refer to code).
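
For illustration, a minimal sketch of how the integer division can silently produce 0 and make a chunked split loop forever (hypothetical helper and made-up numbers, not verl's actual batch.split):

def split_batch(batch, micro_batch_size):
    chunks, start = [], 0
    while start < len(batch):                       # loop condition never changes...
        chunks.append(batch[start:start + micro_batch_size])
        start += micro_batch_size                   # ...because start never advances when this is 0
    return chunks

log_prob_micro_batch_size = 32                      # illustrative config value
world_size = 64                                     # illustrative FSDP world size
micro_batch_size = log_prob_micro_batch_size // world_size   # integer division -> 0
assert micro_batch_size > 0, "micro batch size must be at least 1 per rank"
# split_batch(list(range(1024)), micro_batch_size)  # would hang without the assert above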

(2) When the evaluation dataset size is not divisible by the world_size and tensor parallelism is enabled for vllm.
verl all-gathers data within tensor parallel groups (code), but after the evaluation dataset is partitioned, some ranks may hold a different number of samples because the dataset size is not evenly divisible. The program then hangs at torch.distributed.all_gather().
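
A hypothetical sketch of the failure mode (illustrative names, not verl's code): torch.distributed.all_gather assumes every rank contributes a tensor of the same shape, so a rank holding one fewer eval sample than its peers blocks the whole group.

import torch
import torch.distributed as dist

def gather_eval_outputs(local_outputs: torch.Tensor, tp_group=None) -> torch.Tensor:
    world_size = dist.get_world_size(tp_group)
    buffers = [torch.empty_like(local_outputs) for _ in range(world_size)]
    # Blocks forever (or errors) if local_outputs has a different shape on some rank.
    dist.all_gather(buffers, local_outputs, group=tp_group)
    return torch.cat(buffers, dim=0)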

I hope you can add more checks to the code that raise errors when some hyper-parameters are not set correctly. More instructions on how to set the hyper-parameters would also be appreciated (e.g., which hyper-parameter should be divisible by which).

However, even after fixing the above issues, I still cannot run the distributed training with tensor parallel rollout successfully. I'm now facing a new CUDA error:

:task_name:main_task
Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::WorkerDict.actor_rollout_generate_sequences() (pid=6843, ip=10.51.4.32, actor_id=e2e75b1e6165ace12db813b606000000, repr=<single_controller.ray.base.WorkerDict object at 0x7f147e1bc6a0>)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/vamix/verl/verl/trainer/ppo/rollout/vllm_rollout/vllm_rollout.py", line 180, in generate_sequences
    output = self.inference_engine.generate(
  File "/usr/local/lib/python3.10/dist-packages/vllm/utils.py", line 895, in inner
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/llm.py", line 330, in generate
    outputs = self._run_engine(use_tqdm=use_tqdm)
  File "/home/vamix/verl/verl/third_party/vllm/vllm_v_0_5_4/llm.py", line 183, in _run_engine
    step_outputs = self.llm_engine.step()
  File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 919, in step
    output = self.model_executor.execute_model(
  File "/home/vamix/verl/verl/third_party/vllm/vllm_v_0_5_4/spmd_gpu_executor.py", line 157, in execute_model
    all_outputs = self.worker.execute_model(execute_model_req=execute_model_req)
  File "/home/vamix/verl/verl/third_party/vllm/vllm_v_0_5_4/worker.py", line 261, in execute_model
    return self.model_runner.execute_model(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 1363, in execute_model
    hidden_or_intermediate_states = model_executable(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2.py", line 360, in forward
    hidden_states = self.model(input_ids, positions, kv_caches,
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2.py", line 276, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2.py", line 210, in forward
    hidden_states = self.self_attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/qwen2.py", line 157, in forward
    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/vllm/attention/layer.py", line 98, in forward
    return self.impl.forward(query,
  File "/usr/local/lib/python3.10/dist-packages/vllm/attention/backends/flash_attn.py", line 556, in forward
    output[num_prefill_tokens:] = flash_attn_with_kvcache(
  File "/usr/local/lib/python3.10/dist-packages/vllm_flash_attn/flash_attn_interface.py", line 1295, in flash_attn_with_kvcache
    out, softmax_lse = flash_attn_cuda.fwd_kvcache(
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

ray::WorkerDict.actor_rollout_generate_sequences() (pid=6843, ip=10.51.4.32, actor_id=e2e75b1e6165ace12db813b606000000, repr=<single_controller.ray.base.WorkerDict object at 0x7f147e1bc6a0>)
  File "/home/vamix/verl/single_controller/ray/base.py", line 395, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/home/vamix/verl/single_controller/base/decorator.py", line 404, in inner
    return func(*args, **kwargs)
  File "/home/vamix/verl/verl/trainer/ppo/workers/fsdp_workers.py", line 358, in generate_sequences
    with self.sharding_manager:
  File "/home/vamix/verl/verl/trainer/ppo/hybrid_engine/fsdp_vllm.py", line 82, in __exit__
    torch.cuda.empty_cache()
  File "/usr/local/lib/python3.10/dist-packages/torch/cuda/memory.py", line 170, in empty_cache
    torch._C._cuda_emptyCache()
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

I suspect it is still related to some wrongly set hyper-parameters; my script is as below:

set -x

gsm8k_train_path=/home/vamix/datasets/gsm8k/train.parquet
gsm8k_test_path=/home/vamix/datasets/gsm8k/test.parquet
math_train_path=/home/vamix/datasets/math/train.parquet
math_test_path=/home/vamix/datasets/math/test.parquet

train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

python3 -m verl.trainer.main_ppo \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.train_batch_size=1024 \
    data.val_batch_size=6144 \
    data.max_prompt_length=1024 \
    data.max_response_length=1024 \
    actor_rollout_ref.model.path=/home/vamix/models/huggingface.co/Qwen/Qwen2.5-32B-Instruct \
    actor_rollout_ref.model.enable_gradient_checkpointing=False \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=512 \
    actor_rollout_ref.actor.ppo_micro_batch_size=64 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.grad_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \
    actor_rollout_ref.ref.log_prob_micro_batch_size=64 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.path=/home/vamix/models/huggingface.co/Qwen/Qwen2.5-32B-Instruct \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size=64 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.grad_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.kl_ctrl.kl_coef=0.0001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console'] \
    trainer.project_name='verl_example' \
    trainer.experiment_name='Qwen2.5-32B-Instruct_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=8 \
    trainer.save_freq=-1 \
    trainer.test_freq=10 \
    trainer.total_epochs=15

Could you give some suggestions for debugging this? Thanks a lot!

PeterSH6 (Collaborator) commented Nov 14, 2024

Hi @Vamix, thanks for your constructive advice!

When config.rollout.log_prob_micro_batch_size is smaller than world_size.
It is quite obscure that config.rollout.log_prob_micro_batch_size is divided by device_mesh.shape[0] (code), and the result is then used to split a tensor dict (code).

The reason log_prob_micro_batch_size is divided by device_mesh.shape[0] (i.e., the FSDP world_size) is that we consider the xx_micro_batch_size options from the perspective of the single controller. Therefore, when it is passed to the FSDPWorkerGroup, it is partitioned based on the world_size.
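
As a sketch of the intended semantics and of the kind of assertion that would catch problem (1) (the helper name and numbers are illustrative, not verl's actual code):

def per_rank_micro_batch_size(global_micro_batch_size: int, world_size: int) -> int:
    # The single controller specifies a global micro batch size; each FSDP rank
    # then processes an equal slice of it.
    assert global_micro_batch_size % world_size == 0, (
        f"micro batch size {global_micro_batch_size} must be divisible by "
        f"the FSDP world size {world_size}"
    )
    return global_micro_batch_size // world_size

per_rank_micro_batch_size(64, 8)     # -> 8 samples per GPU
# per_rank_micro_batch_size(64, 128) # raises instead of silently becoming 0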

You are right that this may be confusing for users. We will discuss this issue and possibly add more tutorials/assertions for better usability.

(2) When the evaluation dataset size is not divisible by the world_size and tensor parallelism is enabled for vllm.
verl all-gathers data within tensor parallel groups (code), but after the evaluation dataset is partitioned, some ranks may hold a different number of samples because the dataset size is not evenly divisible. The program then hangs at torch.distributed.all_gather().

Great findings! This could be a defect; we may add some dummy samples to pad the eval dataset to an evenly divisible size. Do you have a quick workaround you could contribute to verl directly?
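
A rough sketch of that padding idea (hypothetical helper, not verl's API): pad the eval set with dummy samples so every rank receives the same number of items, then drop the padded outputs after gathering.

def pad_to_multiple(samples: list, world_size: int) -> list:
    remainder = len(samples) % world_size
    if remainder == 0:
        return samples
    # Repeat the last sample as a dummy; its outputs are discarded after the gather.
    return samples + [samples[-1]] * (world_size - remainder)

padded = pad_to_multiple(list(range(10)), world_size=4)   # 10 -> 12 samples
assert len(padded) % 4 == 0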

I hope you can add more checks to the code that raise errors when some hyper-parameters are not set correctly. More instructions on how to set the hyper-parameters would also be appreciated (e.g., which hyper-parameter should be divisible by which).

Will do so. Thanks for your suggestion.

CUDA illegal memory error

We also encounter this issue at random times when using qwen2.5 (other models seem to be fine). We found that it may be related to an internal bug in flash_attn or vLLM. See vllm-project/vllm#5687 and vllm-project/vllm#5376.

Can you try a different vLLM attention backend: export VLLM_ATTENTION_BACKEND=XFORMERS?
(If you are submitting tasks to a Ray cluster, you should add this environment variable to runtime_env.yaml.)
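
For example, one way to propagate that variable to every Ray worker is through the runtime_env (a sketch; verl's own launch path may set it differently):

import ray

# Every task and actor started under this runtime_env inherits the variable.
ray.init(runtime_env={"env_vars": {"VLLM_ATTENTION_BACKEND": "XFORMERS"}})

# The runtime_env.yaml equivalent carries the same mapping:
#   env_vars:
#     VLLM_ATTENTION_BACKEND: XFORMERS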
