
Conversation

@HollowMan6 (Contributor) commented Oct 21, 2025

Purpose

Currently, vLLM startup will fail if we keep the existing compile cache while switching whether or not the RAY_EXPERIMENTAL_NOSET_* environment variables are set. These variables control whether Ray manipulates the visible-device environment variable (CUDA_VISIBLE_DEVICES in the Nvidia GPU case).

  • If a cache produced with RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES set is reused by vLLM instances that do not set those env vars, we get AcceleratorError: CUDA error: invalid device ordinal (please find the stack trace below).
  • If a cache produced without RAY_EXPERIMENTAL_NOSET_* is reused by vLLM instances that do set those env vars, all the vLLM instances end up using device (GPU) 0.
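
To make the mismatch concrete, here is a purely illustrative sketch (not vLLM or Ray code; the rank value and the branch logic are assumptions) of how the same worker resolves to a different device index depending on whether Ray rewrote CUDA_VISIBLE_DEVICES; that index is what ends up recorded in the cached compiled artifacts (see the triton_meta["device"] frame in the trace below):

```python
import os

# Purely illustrative: the 5th worker (local rank 4) on an 8-GPU node.
rank = 4

if os.environ.get("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"):
    # Ray leaves CUDA_VISIBLE_DEVICES untouched: the worker addresses
    # GPU 4 directly, so ordinal 4 is what gets recorded in the cached
    # compiled artifacts.
    device_index = rank
else:
    # Ray narrows CUDA_VISIBLE_DEVICES to the single assigned GPU:
    # the process only ever sees device 0.
    device_index = 0

print(f"cuda:{device_index}")
# Reusing a cache recorded under one setting while running under the other
# yields either "invalid device ordinal" or every rank piling onto GPU 0.
```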

This PR solves this by also including these env vars as factors in the cache hash (a minimal sketch of the idea follows the stack trace below).

    self.inference_engine = LLM(
                            ^^^^
  File "vllm/entrypoints/llm.py", line 324, in __init__
    self.llm_engine = LLMEngine.from_engine_args(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/v1/engine/llm_engine.py", line 188, in from_engine_args
    return cls(
           ^^^^
  File "vllm/v1/engine/llm_engine.py", line 122, in __init__
    self.engine_core = EngineCoreClient.make_client(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/v1/engine/core_client.py", line 95, in make_client
    return InprocClient(vllm_config, executor_class, log_stats)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/v1/engine/core_client.py", line 264, in __init__
    self.engine_core = EngineCore(*args, **kwargs)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/v1/engine/core.py", line 113, in __init__
    num_gpu_blocks, num_cpu_blocks, kv_cache_config = self._initialize_kv_caches(
                                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/v1/engine/core.py", line 224, in _initialize_kv_caches
    available_gpu_memory = self.model_executor.determine_available_memory()
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/v1/executor/abstract.py", line 136, in determine_available_memory
    memory = super().determine_available_memory()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/v1/executor/abstract.py", line 88, in determine_available_memory
    return self.collective_rpc("determine_available_memory")
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/executor/uniproc_executor.py", line 75, in collective_rpc
    return [run_method(self.driver_worker, method, args, kwargs)]
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/utils/__init__.py", line 1047, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "vllm/v1/worker/gpu_worker.py", line 281, in determine_available_memory
    self.model_runner.profile_run()
  File "vllm/v1/worker/gpu_model_runner.py", line 3737, in profile_run
    hidden_states, last_hidden_states = self._dummy_run(
                                        ^^^^^^^^^^^^^^^^
  File "torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "vllm/v1/worker/gpu_model_runner.py", line 3470, in _dummy_run
    outputs = self.model(
              ^^^^^^^^^^^
  File "vllm/compilation/cuda_graph.py", line 126, in __call__
    return self.runnable(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/nn/modules/module.py", line 1784, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/model_executor/models/qwen3_vl.py", line 1746, in forward
    hidden_states = self.language_model.model(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/compilation/decorators.py", line 408, in __call__
    output = self.compiled_callable(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/eval_frame.py", line 749, in compile_wrapper
    raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/output_graph.py", line 1871, in _call_user_compiler
    raise BackendCompilerFailed(
  File "torch/_dynamo/output_graph.py", line 1846, in _call_user_compiler
    compiled_fn = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_dynamo/repro/after_dynamo.py", line 150, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/__init__.py", line 2425, in __call__
    return self.compiler_fn(model_, inputs_, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/compilation/backends.py", line 658, in __call__
    PiecewiseCompileInterpreter(
  File "vllm/compilation/backends.py", line 384, in run
    return super().run(*fake_args)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/fx/interpreter.py", line 173, in run
    self.env[node] = self.run_node(node)
                     ^^^^^^^^^^^^^^^^^^^
  File "torch/fx/interpreter.py", line 242, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/compilation/backends.py", line 404, in call_module
    self.vllm_backend.compiler_manager.compile(
  File "vllm/compilation/backends.py", line 198, in compile
    compiled_graph = self.load(graph, example_inputs, graph_index, runtime_shape)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "vllm/compilation/backends.py", line 158, in load
    compiled_graph = self.compiler.load(
                     ^^^^^^^^^^^^^^^^^^^
  File "vllm/compilation/compiler_interface.py", line 537, in load
    inductor_compiled_graph, _ = FxGraphCache._lookup_graph(
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_inductor/codecache.py", line 1316, in _lookup_graph
    return FxGraphCache.cache_hit_post_compile(graph, cache_info, constants)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_inductor/codecache.py", line 1213, in cache_hit_post_compile
    artifact_path = graph.after_deserialization(constants)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_inductor/output_code.py", line 698, in after_deserialization
    code_cache = PyCodeCache.load_by_key_path(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_inductor/codecache.py", line 3296, in load_by_key_path
    mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_inductor/runtime/compile_tasks.py", line 31, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "~/.cache/vllm/torch_compile_cache/de5ae84ed5/rank_4_0/inductor_cache/ux/cuxpwzi4jcfj5ozmjblzut4dmz4wd7rvb3l2lhnk6cratergustx.py", line 56, in <module>
    triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0 = async_compile.triton('triton_red_fused__to_copy_add_mean_mul_pow_rsqrt_0', '''
                                                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "torch/_inductor/async_compile.py", line 350, in triton
    return future.result()
           ^^^^^^^^^^^^^^^
  File "torch/_inductor/codecache.py", line 4037, in result
    self.static_autotuner.precompile(  # type: ignore[union-attr]
  File "torch/_inductor/runtime/triton_heuristics.py", line 411, in precompile
    self._make_launchers()
  File "torch/_inductor/runtime/triton_heuristics.py", line 561, in _make_launchers
    with DeviceGuard(device_interface, self.triton_meta["device"]):
  File "torch/_dynamo/device_interface.py", line 193, in __enter__
    self.prev_idx = self.device_interface.exchange_device(self.idx)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._dynamo.exc.BackendCompilerFailed: backend='<vllm.compilation.backends.VllmBackend object at 0x7ee4e050f210>' raised:
AcceleratorError: CUDA error: invalid device ordinal
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
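
For reference, the following is a minimal, self-contained sketch of the approach, not the actual vLLM implementation: the helper name compile_cache_hash and the variable list are illustrative assumptions (the authoritative list of RAY_EXPERIMENTAL_NOSET_* variables lives in Ray's codebase). It simply folds the presence of these env vars into the factors that feed the cache key:

```python
import hashlib
import os

# Illustrative subset only: the authoritative list of these variables is
# maintained in Ray's codebase and may change over time.
RAY_NOSET_VISIBLE_DEVICES_ENV_VARS = (
    "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
    "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES",
    "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES",
)


def compile_cache_hash(base_factors: list[str]) -> str:
    """Hypothetical helper: fold the Ray NOSET_* env vars into the cache key,
    so runs with and without them resolve to different cache directories."""
    factors = list(base_factors)
    for name in RAY_NOSET_VISIBLE_DEVICES_ENV_VARS:
        # Record whether each variable is set; a cache built under one
        # setting will then never be looked up under the other.
        factors.append(f"{name}={name in os.environ}")
    return hashlib.sha256(str(factors).encode()).hexdigest()[:10]
```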

Test Plan

Tested locally in both situations (with and without the RAY_EXPERIMENTAL_NOSET_* environment variables set).

Test Result

Now two different caches with different hash IDs are produced (with or without RAY_EXPERIMENTAL_NOSET_* set).
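
Continuing the hypothetical sketch above, toggling the env var flips the key, which is what yields the two separate cache directories:

```python
# Illustrative factors only; compile_cache_hash is the hypothetical helper
# from the sketch above.
base = ["model=some-model", "dtype=bfloat16"]

os.environ.pop("RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", None)
hash_without = compile_cache_hash(base)

os.environ["RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES"] = "1"
hash_with = compile_cache_hash(base)

# Two distinct keys mean two distinct cache directories, e.g.
# ~/.cache/vllm/torch_compile_cache/<hash_without>/... vs .../<hash_with>/...
assert hash_without != hash_with
```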


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request correctly addresses a caching bug related to Ray's RAY_EXPERIMENTAL_NOSET_* environment variables by including them in the compilation cache hash. This prevents crashes and incorrect device allocation when these environment variables change. My review includes a suggestion to improve the maintainability of the fix by adding a more explicit warning about the hardcoded list of environment variables, which could become a source of bugs in the future if not kept in sync with Ray's codebase.

@HollowMan6 (Contributor, Author) commented

Just saw some related issues (#23107, #16501) and a related ongoing PR (#26468), but that won't solve the problem here, since the RAY_EXPERIMENTAL_NOSET_* variables are not defined as vLLM environment variables.

cc: @ProExpertProg to get some feedback

@ProExpertProg (Collaborator) commented

cc @zou3519 for caching

@HollowMan6 force-pushed the torch-compile-cache branch 2 times, most recently from b040e1a to a409475 on October 27, 2025 19:28
@zou3519 added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Oct 27, 2025
@HollowMan6 force-pushed the torch-compile-cache branch 2 times, most recently from 2e6779f to 5b46295 on October 28, 2025 06:35
…ion cache

Signed-off-by: Hollow Man <hollowman@opensuse.org>
@HollowMan6 (Contributor, Author) commented

Now CI has passed! cc @zou3519 @ProExpertProg

@zou3519 merged commit 936643a into vllm-project:main on Oct 28, 2025
45 checks passed
@HollowMan6 deleted the torch-compile-cache branch on October 28, 2025 14:28
bhagyashrigai pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Oct 29, 2025
…ion cache (vllm-project#27294)

Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Bhagyashri <Bhagyashri.Gaikwad2@ibm.com>
ilmarkov pushed a commit to neuralmagic/vllm that referenced this pull request Nov 7, 2025
…ion cache (vllm-project#27294)

Signed-off-by: Hollow Man <hollowman@opensuse.org>
ZhengHongming888 pushed a commit to ZhengHongming888/vllm that referenced this pull request Nov 8, 2025
…ion cache (vllm-project#27294)

Signed-off-by: Hollow Man <hollowman@opensuse.org>
