Skip to content

Commit d22d309

Browse files
yuki-97 and soodoshll
authored and committed
fix: fix async vllm nccl fail on dsv3 tp16pp2 and non-colocated on single node (NVIDIA-NeMo#898)
Signed-off-by: Yuki Huang <yukih@nvidia.com>
Signed-off-by: Qidong Su <qidongs@nvidia.com>
1 parent 062c398 commit d22d309

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -202,9 +202,14 @@ def _patched_maybe_force_spawn():
202202
logger.info("Successfully patched vllm.utils._maybe_force_spawn.")
203203

204204
def _patch_vllm_init_workers_ray():
205-
# Patch the vLLM ray_distributed_executor.py file to pass custom runtime_env in _init_workers_ray call.
206-
# This allows passing custom py_executable to worker initialization.
205+
"""Patch the vLLM ray_distributed_executor.py file.
207206
207+
1. Pass custom runtime_env in _init_workers_ray call.
208+
- This allows passing custom py_executable to worker initialization.
209+
2. Add NCCL_CUMEM_ENABLE and NCCL_NVLS_ENABLE to vLLM ADDITIONAL_ENV_VARS.
210+
- This is a workaround to fix async vllm in some scenarios.
211+
- See https://github.com/NVIDIA-NeMo/RL/pull/898 for more details.
212+
"""
208213
try:
209214
import vllm.executor.ray_distributed_executor as ray_executor_module
210215

@@ -213,26 +218,36 @@ def _patch_vllm_init_workers_ray():
213218
with open(file_to_patch, "r") as f:
214219
content = f.read()
215220

216-
old_line = "self._init_workers_ray(placement_group)"
217-
new_line = f'self._init_workers_ray(placement_group, runtime_env={{"py_executable": "{self.py_executable}"}})'
221+
old_lines = [
222+
"self._init_workers_ray(placement_group)",
223+
'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}',
224+
]
218225

219-
if new_line in content:
220-
return
226+
new_lines = [
227+
f'self._init_workers_ray(placement_group, runtime_env={{"py_executable": "{self.py_executable}"}})',
228+
'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "NCCL_CUMEM_ENABLE", "NCCL_NVLS_ENABLE"}',
229+
]
221230

222-
if old_line not in content:
223-
return
231+
need_replace = False
232+
for old_line, new_line in zip(old_lines, new_lines):
233+
if new_line in content or old_line not in content:
234+
continue
235+
content = content.replace(old_line, new_line)
236+
need_replace = True
224237

225-
patched_content = content.replace(old_line, new_line)
238+
if not need_replace:
239+
return
226240

227241
# Write back the patched content
228242
with open(file_to_patch, "w") as f:
229-
f.write(patched_content)
243+
f.write(content)
230244

231245
except (ImportError, FileNotFoundError, PermissionError):
232246
# Allow failures gracefully
233247
pass
234248

235249
_patch_vllm_init_workers_ray()
250+
logger.info("Successfully patched vllm _init_workers_ray.")
236251

237252
except (ImportError, AttributeError):
238253
# vllm not installed or has a different structure, skipping patch.

0 commit comments

Comments (0)