[Bug]: Triton error when initializing LLM(...) when enable_lora=True and cuda device != cuda:0 #12967

Closed
1 task done
tchang1997 opened this issue Feb 8, 2025 · 7 comments · Fixed by #13027
Labels
bug Something isn't working

Comments


tchang1997 commented Feb 8, 2025

Your current environment

The output of `python collect_env.py`
2025-02-08 12:26:31.089573: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1739035591.107776 3635054 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1739035591.111986 3635054 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-08 12:26:31.127454: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO 02-08 12:26:33 __init__.py:183] Automatically detected platform cuda.
Collecting environment information...
PyTorch version: 2.5.1+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.5 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.26.4
Libc version: glibc-2.35

Python version: 3.12.8 (main, Jan 14 2025, 22:49:14) [Clang 19.1.6 ] (64-bit runtime)
Python platform: Linux-5.15.0-131-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA RTX A6000
GPU 1: NVIDIA RTX A6000
GPU 2: NVIDIA RTX A6000
GPU 3: NVIDIA RTX A6000
GPU 4: NVIDIA RTX A6000
GPU 5: NVIDIA RTX A6000
GPU 6: NVIDIA RTX A6000
GPU 7: NVIDIA RTX A6000

Nvidia driver version: 550.120
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        48 bits physical, 48 bits virtual
Byte Order:                           Little Endian
CPU(s):                               256
On-line CPU(s) list:                  0-255
Vendor ID:                            AuthenticAMD
Model name:                           AMD EPYC 7763 64-Core Processor
CPU family:                           25
Model:                                1
Thread(s) per core:                   2
Core(s) per socket:                   64
Socket(s):                            2
Stepping:                             1
Frequency boost:                      enabled
CPU max MHz:                          3529.0520
CPU min MHz:                          1500.0000
BogoMIPS:                             4890.88
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 invpcid_single hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca
Virtualization:                       AMD-V
L1d cache:                            4 MiB (128 instances)
L1i cache:                            4 MiB (128 instances)
L2 cache:                             64 MiB (128 instances)
L3 cache:                             512 MiB (16 instances)
NUMA node(s):                         2
NUMA node0 CPU(s):                    0-63,128-191
NUMA node1 CPU(s):                    64-127,192-255
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Mitigation; safe RET
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Not affected

Versions of relevant libraries:
[pip3] msgpack-numpy==0.4.8
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-ml-py==12.570.86
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] optree==0.14.0
[pip3] pytorch-triton-rocm==0.0.1
[pip3] pyzmq==26.2.1
[pip3] torch==2.5.1
[pip3] torchaudio==2.5.1
[pip3] torchvision==0.20.1
[pip3] transformers==4.48.2
[pip3] triton==3.1.0
[conda] numpy                     1.25.0                   pypi_0    pypi
[conda] nvidia-cublas-cu11        11.10.3.66               pypi_0    pypi
[conda] nvidia-cuda-cupti-cu11    11.7.101                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu11    11.7.99                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu11  11.7.99                  pypi_0    pypi
[conda] nvidia-cudnn-cu11         8.5.0.96                 pypi_0    pypi
[conda] nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
[conda] nvidia-curand-cu11        10.2.10.91               pypi_0    pypi
[conda] nvidia-cusolver-cu11      11.4.0.1                 pypi_0    pypi
[conda] nvidia-cusparse-cu11      11.7.4.91                pypi_0    pypi
[conda] nvidia-nccl-cu11          2.14.3                   pypi_0    pypi
[conda] nvidia-nvtx-cu11          11.7.91                  pypi_0    pypi
[conda] torch                     2.0.1                    pypi_0    pypi
[conda] torch-ema                 0.3                      pypi_0    pypi
[conda] triton                    2.0.0                    pypi_0    pypi
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.7.1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	GPU6	GPU7	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NODE	NODE	NODE	SYS	SYS	SYS	SYS	0-63,128-191	0		N/A
GPU1	NODE	 X 	NODE	NODE	SYS	SYS	SYS	SYS	0-63,128-191	0		N/A
GPU2	NODE	NODE	 X 	NODE	SYS	SYS	SYS	SYS	0-63,128-191	0		N/A
GPU3	NODE	NODE	NODE	 X 	SYS	SYS	SYS	SYS	0-63,128-191	0		N/A
GPU4	SYS	SYS	SYS	SYS	 X 	NODE	NODE	NODE	64-127,192-255	1		N/A
GPU5	SYS	SYS	SYS	SYS	NODE	 X 	NODE	NODE	64-127,192-255	1		N/A
GPU6	SYS	SYS	SYS	SYS	NODE	NODE	 X 	NODE	64-127,192-255	1		N/A
GPU7	SYS	SYS	SYS	SYS	NODE	NODE	NODE	 X 	64-127,192-255	1		N/A

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

LD_LIBRARY_PATH=/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/cv2/../../lib64:
NCCL_CUMEM_ENABLE=0
TORCHINDUCTOR_COMPILE_THREADS=1
CUDA_MODULE_LOADING=LAZY

🐛 Describe the bug

When using the LLM class for offline inference with LoRA in a multi-GPU setting, Triton raises a ValueError about a pointer argument that cannot be accessed.

I'm using vllm for generation in trl's GRPOTrainer class, and I get this error during their __init__ method. I've managed to isolate the issue to the following snippets, run in the Python interpreter; it appears to be a device-indexing problem. Here are two examples, one of which works and one of which produces the error I described:

Works OK:

from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM
from vllm import LLM

model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
peft_config = LoraConfig(r=16, lora_alpha=64, lora_dropout=0.05, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
model = get_peft_model(model, peft_config)
llm = LLM(
    model.name_or_path,
    device="cuda:0", 
    gpu_memory_utilization=0.5,
    dtype="auto",
    enable_prefix_caching=True,
    max_model_len=4096,
    enable_lora=True,
    max_lora_rank=16
)

Does not work:

from peft import get_peft_model, LoraConfig
from transformers import AutoModelForCausalLM
from vllm import LLM

model = AutoModelForCausalLM.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
peft_config = LoraConfig(r=16, lora_alpha=64, lora_dropout=0.05, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"])
model = get_peft_model(model, peft_config)
llm = LLM(
    model.name_or_path,
    device="cuda:1", # the only difference 
    gpu_memory_utilization=0.5,
    dtype="auto",
    enable_prefix_caching=True,
    max_model_len=4096,
    enable_lora=True,
    max_lora_rank=16
)
Console output immediately prior to error
INFO 02-08 12:17:15 config.py:526] This model supports multiple tasks: {'reward', 'generate', 'embed', 'score', 'classify'}. Defaulting to 'generate'.
INFO 02-08 12:17:15 llm_engine.py:232] Initializing a V0 LLM engine (v0.7.1) with config: model='deepseek-ai/DeepSeek-R1-Distill-Llama-8B', speculative_config=None, tokenizer='deepseek-ai/DeepSeek-R1-Distill-Llama-8B', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, over
ride_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False,
kv_cache_dtype=auto,  device_config=cuda:1, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=deepseek-ai
/DeepSeek-R1-Distill-Llama-8B, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=False, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting
_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":256}, use_cached_outputs=False,
INFO 02-08 12:17:16 cuda.py:235] Using Flash Attention backend.
INFO 02-08 12:17:17 model_runner.py:1111] Starting to load model deepseek-ai/DeepSeek-R1-Distill-Llama-8B...
INFO 02-08 12:17:17 weight_utils.py:251] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  1.23s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.35s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:02<00:00,  1.33s/it]
INFO 02-08 12:17:20 model_runner.py:1116] Loading model weights took 0.0000 GB
INFO 02-08 12:17:20 punica_selector.py:16] Using PunicaWrapperGPU.
INFO 02-08 12:17:21 model_runner_base.py:120] Writing input of failed execution to /tmp/err_execute_model_input_20250208-121721.pkl...
INFO 02-08 12:17:21 model_runner_base.py:149] Completed writing input of failed execution to /tmp/err_execute_model_input_20250208-121721.pkl.
Full stack trace:
[rank0]: Traceback (most recent call last):
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 116, in _wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1721, in execute_model
[rank0]:     hidden_or_intermediate_states = model_executable(
[rank0]:                                     ^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 539, in forward
[rank0]:     model_output = self.model(input_ids, positions, kv_caches,
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/compilation/decorators.py", line 170, in __call__
[rank0]:     return self.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 354, in forward
[rank0]:     hidden_states = self.get_input_embeddings(input_ids)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/model_executor/models/llama.py", line 339, in get_input_embeddings
[rank0]:     return self.embed_tokens(input_ids)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/lora/layers.py", line 260, in forward
[rank0]:     self.punica_wrapper.add_lora_embedding(full_output,
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/lora/punica_wrapper/punica_gpu.py", line 203, in add_lora_embedding
[rank0]:     sgmv_expand(
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/lora/ops/triton_ops/sgmv_expand.py", line 220, in _sgmv_expand
[rank0]:     _sgmv_expand_kernel[grid](
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/triton/runtime/jit.py", line 345, in <lambda>
[rank0]:     return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
[rank0]:                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/triton/runtime/jit.py", line 691, in run
[rank0]:     kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/triton/backends/nvidia/driver.py", line 365, in __call__
[rank0]:     self.launch(*args, **kwargs)
[rank0]: ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

[rank0]: The above exception was the direct cause of the following exception:

[rank0]: Traceback (most recent call last):
[rank0]:   File "<stdin>", line 1, in <module>
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/utils.py", line 1039, in inner
[rank0]:     return fn(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 240, in __init__
[rank0]:     self.llm_engine = self.engine_class.from_engine_args(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 482, in from_engine_args
[rank0]:     engine = cls(
[rank0]:              ^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 274, in __init__
[rank0]:     self._initialize_kv_caches()
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/engine/llm_engine.py", line 414, in _initialize_kv_caches
[rank0]:     self.model_executor.determine_num_available_blocks())
[rank0]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 99, in determine_num_available_blocks
[rank0]:     results = self.collective_rpc("determine_num_available_blocks")
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 49, in collective_rpc
[rank0]:     answer = run_method(self.driver_worker, method, args, kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/utils.py", line 2208, in run_method
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/worker.py", line 228, in determine_num_available_blocks
[rank0]:     self.model_runner.profile_run()
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1236, in profile_run
[rank0]:     self._dummy_run(max_num_batched_tokens, max_num_seqs)
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner.py", line 1347, in _dummy_run
[rank0]:     self.execute_model(model_input, kv_caches, intermediate_tensors)
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/data2/ctrenton/uv/llm_server/lib/python3.12/site-packages/vllm/worker/model_runner_base.py", line 152, in _wrapper
[rank0]:     raise type(err)(
[rank0]: ValueError: Error in model execution (input dumped to /tmp/err_execute_model_input_20250208-121721.pkl): Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)

Since I'm using vllm as part of a training loop in trl, I'd rather not poke around their device assignment logic if possible.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.
tchang1997 added the bug (Something isn't working) label on Feb 8, 2025
@jeejeelee
Collaborator

See: triton-lang/triton#2925
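
For context, that Triton issue boils down to kernels being launched on the current CUDA device, so pointers to tensors that live on a different device cannot be dereferenced. A minimal standalone sketch of the failure and of the device-context workaround (assuming Triton is installed and at least two GPUs are visible; illustrative only, not vLLM code):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n, BLOCK: tl.constexpr):
    # Plain element-wise copy, used only to show the device mismatch.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    vals = tl.load(src_ptr + offs, mask=mask)
    tl.store(dst_ptr + offs, vals, mask=mask)

x = torch.ones(1024, device="cuda:1")  # tensors live on cuda:1
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)

# The current device is still cuda:0, so this launch fails with
# "Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)":
# copy_kernel[grid](x, y, x.numel(), BLOCK=256)

# Launching inside the tensors' device context works:
with torch.cuda.device(x.device):
    copy_kernel[grid](x, y, x.numel(), BLOCK=256)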

@Zzhiter
Contributor

Zzhiter commented Feb 9, 2025

Same here!

@tchang1997
Author

Thanks! Unfortunately, I just tried this new snippet following that issue, and I still hit the same error:

import torch 
from vllm import LLM

print(torch.cuda.current_device())  # output: 0
torch.cuda.set_device("cuda:1") # also tried .set_device(1)
print(torch.cuda.current_device()) # output: 1
llm = LLM(
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    device="cuda:1", 
    gpu_memory_utilization=0.5,
    dtype="auto",
    enable_prefix_caching=True,
    max_model_len=4096,
    enable_lora=True,
    max_lora_rank=16,
)
print(torch.cuda.current_device()) # output: 0

I noticed that when I use torch.cuda.set_device, the issue persists even if I take out the enable_lora and max_lora_rank kwargs. It seems like pinning vLLM to a specific GPU might not be feasible, as per #3750?
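
One way to sidestep the index mismatch entirely is to restrict GPU visibility before anything initializes CUDA, so that the target GPU shows up as cuda:0 inside the process. A minimal sketch, assuming the process only needs that single GPU and that the environment variable is set before torch/vLLM touch CUDA (so this wouldn't fit the trl case, where the same process also trains on the other GPUs):

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # expose only physical GPU 1; must happen before CUDA init

from vllm import LLM

llm = LLM(
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    device="cuda:0",  # inside this process, cuda:0 now maps to physical GPU 1
    gpu_memory_utilization=0.5,
    dtype="auto",
    enable_prefix_caching=True,
    max_model_len=4096,
    enable_lora=True,
    max_lora_rank=16,
)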

@Zzhiter
Contributor

Zzhiter commented Feb 9, 2025

> Thanks! Unfortunately, I just tried this new snippet following that issue, and I still hit the same error: [...]
>
> I noticed that when I use torch.cuda.set_device, the issue persists even if I take out the enable_lora and max_lora_rank kwargs. It seems like pinning vLLM to a specific GPU might not be feasible, as per #3750?

Did you use the Triton backend?

@jeejeelee
Collaborator

You can try the following code; it works around the issue for me locally.

import torch 
from vllm import LLM

print(torch.cuda.current_device())  # output: 0
torch.cuda.set_device("cuda:1") # also tried .set_device(1)
print(torch.cuda.current_device()) # output: 1
llm = LLM(
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
    device="cuda:0", # or device="cuda" 
    gpu_memory_utilization=0.5,
    dtype="auto",
    enable_prefix_caching=True,
    max_model_len=4096,
    enable_lora=True,
    max_lora_rank=16,
)
print(torch.cuda.current_device()) # output: 0

@tchang1997
Author

This doesn't raise an error but does put the vLLM model on the wrong GPU (the original cuda:0 instead of cuda:1) for me.

I found another possible solution here; it seems doable for my own training script, but incorporating it into trl would require a major refactor to the trainer code. I'll update if I get the chance to test that.

@Zzhiter
Contributor

Zzhiter commented Feb 11, 2025

I have found a temporary solution from [this commit](https://huggingface.co/microsoft/Phi-3-small-128k-instruct/commit/ed7de9a074b0760e6cf050fe1d103b90834933c8) and [this discussion](https://huggingface.co/microsoft/Phi-3-small-8k-instruct/discussions/23).

It only requires wrapping all of the Triton operations in LoRA with `with torch.cuda.device(...)`.

[Image: screenshot of the proposed change]
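
As a sketch of that pattern (a hypothetical helper, not the actual vLLM patch): run each Triton-backed LoRA op inside the device context of the tensors it operates on.

import functools

import torch

def run_in_tensor_device(op):
    """Launch a Triton-backed op on the device of its first tensor argument."""
    @functools.wraps(op)
    def wrapper(first_tensor, *args, **kwargs):
        # torch.cuda.device accepts a torch.device, so ops on cuda:1 tensors
        # are launched with cuda:1 as the current device.
        with torch.cuda.device(first_tensor.device):
            return op(first_tensor, *args, **kwargs)
    return wrapper

# Hypothetical usage: sgmv_expand = run_in_tensor_device(sgmv_expand)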
