@@ -202,9 +202,14 @@ def _patched_maybe_force_spawn():
202202 logger .info ("Successfully patched vllm.utils._maybe_force_spawn." )
203203
204204 def _patch_vllm_init_workers_ray ():
205- # Patch the vLLM ray_distributed_executor.py file to pass custom runtime_env in _init_workers_ray call.
206- # This allows passing custom py_executable to worker initialization.
205+ """Patch the vLLM ray_distributed_executor.py file.
207206
207+ 1. Pass custom runtime_env in _init_workers_ray call.
208+ - This allows passing custom py_executable to worker initialization.
209+ 2. Add NCCL_CUMEM_ENABLE and NCCL_NVLS_ENABLE to vLLM ADDITIONAL_ENV_VARS.
210+ - This is a workaround to fix async vllm in some scenarios.
211+ - See https://github.com/NVIDIA-NeMo/RL/pull/898 for more details.
212+ """
208213 try :
209214 import vllm .executor .ray_distributed_executor as ray_executor_module
210215
@@ -213,26 +218,36 @@ def _patch_vllm_init_workers_ray():
213218 with open (file_to_patch , "r" ) as f :
214219 content = f .read ()
215220
216- old_line = "self._init_workers_ray(placement_group)"
217- new_line = f'self._init_workers_ray(placement_group, runtime_env={{"py_executable": "{ self .py_executable } "}})'
221+ old_lines = [
222+ "self._init_workers_ray(placement_group)" ,
223+ 'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}' ,
224+ ]
218225
219- if new_line in content :
220- return
226+ new_lines = [
227+ f'self._init_workers_ray(placement_group, runtime_env={{"py_executable": "{ self .py_executable } "}})' ,
228+ 'ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "NCCL_CUMEM_ENABLE", "NCCL_NVLS_ENABLE"}' ,
229+ ]
221230
222- if old_line not in content :
223- return
231+ need_replace = False
232+ for old_line , new_line in zip (old_lines , new_lines ):
233+ if new_line in content or old_line not in content :
234+ continue
235+ content = content .replace (old_line , new_line )
236+ need_replace = True
224237
225- patched_content = content .replace (old_line , new_line )
238+ if not need_replace :
239+ return
226240
227241 # Write back the patched content
228242 with open (file_to_patch , "w" ) as f :
229- f .write (patched_content )
243+ f .write (content )
230244
231245 except (ImportError , FileNotFoundError , PermissionError ):
232246 # Allow failures gracefully
233247 pass
234248
235249 _patch_vllm_init_workers_ray ()
250+ logger .info ("Successfully patched vllm _init_workers_ray." )
236251
237252 except (ImportError , AttributeError ):
238253 # vllm not installed or has a different structure, skipping patch.
0 commit comments