
Conversation

@Qubitium
Contributor

@Qubitium Qubitium commented Mar 22, 2025

With vLLM main + DeepSeek R1 GPTQModel-quantized model + Marlin + Triton MLA V1 + kv/prefill caching, the server crashes because x (the input) is not contiguous.

Core cause: the kv/prefill caching path passes a non-contiguous x (input) down to lower-level kernels, and some kernels (Marlin) only operate on a contiguous memory layout.

Reproduce: execute the same prompt/request twice and it crashes with the stacktrace below.

I only fixed the Marlin code path. I am not sure whether other kernels are affected or whether my fix should be moved further up the call chain.
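
For context, here is a minimal sketch of the kind of guard the Marlin path needs (the function and argument names are illustrative, not the exact vLLM code):

```python
import torch

def apply_marlin_linear_guarded(x: torch.Tensor, marlin_gemm, *gemm_args):
    # Hypothetical guard mirroring the idea of this PR: the Marlin GEMM only
    # accepts a densely packed activation matrix (it raises
    # "A is not contiguous" otherwise), so materialize a contiguous copy when
    # the prefix-cache/prefill path hands over a strided view.
    reshaped_x = x.reshape(-1, x.shape[-1])
    if not reshaped_x.is_contiguous():
        reshaped_x = reshaped_x.contiguous()  # copy only on the strided path
    return marlin_gemm(reshaped_x, *gemm_args)
```

The check is kept local to the Marlin kernel wrapper so that code paths that already accept strided inputs do not pay for an extra copy.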

@mgoin

(vm312) root@gpu-base:~/vllm#  CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 CUDA_DEVICE_ORDER=PCI_BUS_ID PYTROCH_CUDA_ALLOC_CONF=expandable_segments:True VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve gptqmodel-deepseek-r1 --max-model-len 8196 --gpu_memory_utilization 0.90 --trust-remote-code --tensor-parallel-size 8
INFO 03-22 02:25:31 [__init__.py:256] Automatically detected platform cuda.
INFO 03-22 02:25:33 [api_server.py:981] vLLM API server version 0.8.2.dev60+gcfbb8c93.d20250321
INFO 03-22 02:25:33 [api_server.py:982] args: Namespace(subparser='serve', model_tag='gptqmodel-deepseek-r1', config='', host=None, port=8000, uvicorn_log_level='info', disable_uvicorn_access_log=False, allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, chat_template_content_format='auto', response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, enable_ssl_refresh=False, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_request_id_headers=False, enable_auto_tool_choice=False, tool_call_parser=None, tool_parser_plugin='', model='/monster/data/xl/thyrotherapyformol_dataset_1024_maxtokens_512_bits_4_group_size_128_damp_percent_0_0025_true_sequential_True_desc_act_True_mse_2_4', task='auto', tokenizer=None, hf_config_path=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=True, allowed_local_media_path=None, download_dir=None, load_format='auto', config_format=<ConfigFormat.AUTO: 'auto'>, dtype='auto', kv_cache_dtype='auto', max_model_len=8196, guided_decoding_backend='xgrammar', logits_processor_pattern=None, model_impl='auto', distributed_executor_backend=None, pipeline_parallel_size=1, tensor_parallel_size=8, enable_expert_parallel=False, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=None, enable_prefix_caching=None, disable_sliding_window=False, use_v2_block_manager=True, num_lookahead_slots=0, seed=None, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_partial_prefills=1, max_long_partial_prefills=1, long_prefill_token_threshold=0, max_num_seqs=None, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, hf_overrides=None, enforce_eager=False, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, disable_mm_preprocessor_cache=False, enable_lora=False, enable_lora_bias=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, use_tqdm_on_load=True, multi_step_stream_outputs=True, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_disable_mqa_scorer=False, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, show_hidden_metrics_for_version=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, scheduling_policy='fcfs', scheduler_cls='vllm.core.scheduler.Scheduler', 
override_neuron_config=None, override_pooler_config=None, compilation_config=None, kv_transfer_config=None, worker_cls='auto', worker_extension_cls='', generation_config='auto', override_generation_config=None, enable_sleep_mode=False, calculate_kv_scales=False, additional_config=None, enable_reasoning=False, reasoning_parser=None, disable_cascade_attn=False, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, enable_prompt_tokens_details=False, enable_server_load_tracking=False, dispatch_function=<function ServeSubcommand.cmd at 0x729c358d1580>)
INFO 03-22 02:25:33 [config.py:208] Replacing legacy 'type' key with 'rope_type'
INFO 03-22 02:25:41 [config.py:585] This model supports multiple tasks: {'score', 'reward', 'generate', 'classify', 'embed'}. Defaulting to 'generate'.
INFO 03-22 02:25:42 [gptq_marlin.py:143] The model is convertible to gptq_marlin during runtime. Using gptq_marlin kernel.
INFO 03-22 02:25:42 [config.py:1509] Defaulting to use mp for distributed inference
INFO 03-22 02:25:42 [config.py:1687] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 03-22 02:25:47 [__init__.py:256] Automatically detected platform cuda.
INFO 03-22 02:25:50 [core.py:54] Initializing a V1 LLM engine (v0.8.2.dev60+gcfbb8c93.d20250321) with config: model='/monster/data/xl/thyrotherapyformol_dataset_1024_maxtokens_512_bits_4_group_size_128_damp_percent_0_0025_true_sequential_True_desc_act_True_mse_2_4', speculative_config=None, tokenizer='/monster/data/xl/thyrotherapyformol_dataset_1024_maxtokens_512_bits_4_group_size_128_damp_percent_0_0025_true_sequential_True_desc_act_True_mse_2_4', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=True, dtype=torch.bfloat16, max_seq_len=8196, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=8, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=gptq_marlin, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=/monster/data/xl/thyrotherapyformol_dataset_1024_maxtokens_512_bits_4_group_size_128_damp_percent_0_0025_true_sequential_True_desc_act_True_mse_2_4, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=True, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"level":3,"custom_ops":["none"],"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"use_inductor":true,"compile_sizes":[],"use_cudagraph":true,"cudagraph_num_of_warmups":1,"cudagraph_capture_sizes":[512,504,496,488,480,472,464,456,448,440,432,424,416,408,400,392,384,376,368,360,352,344,336,328,320,312,304,296,288,280,272,264,256,248,240,232,224,216,208,200,192,184,176,168,160,152,144,136,128,120,112,104,96,88,80,72,64,56,48,40,32,24,16,8,4,2,1],"max_capture_size":512}
WARNING 03-22 02:25:50 [multiproc_worker_utils.py:306] Reducing Torch parallelism from 32 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 03-22 02:25:50 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0, 1, 2, 3, 4, 5, 6, 7], buffer_handle=(8, 10485760, 10, 'psm_a70d36fc'), local_subscribe_addr='ipc:///tmp/fd0e6d82-8afa-4674-9912-476465be5ac0', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 03-22 02:25:55 [__init__.py:256] Automatically detected platform cuda.
WARNING 03-22 02:25:57 [utils.py:2316] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7d3918afddc0>
(VllmWorker rank=0 pid=3742482) INFO 03-22 02:25:57 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_1c82d3af'), local_subscribe_addr='ipc:///tmp/809bff3f-4320-4f14-946f-bcee88ff5616', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 03-22 02:26:02 [__init__.py:256] Automatically detected platform cuda.
WARNING 03-22 02:26:05 [utils.py:2316] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7b1b728176e0>
(VllmWorker rank=1 pid=3742520) INFO 03-22 02:26:05 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_12e12717'), local_subscribe_addr='ipc:///tmp/99d7933c-43d5-4105-ba2b-c88113a0bc05', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 03-22 02:26:10 [__init__.py:256] Automatically detected platform cuda.
WARNING 03-22 02:26:13 [utils.py:2316] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x76bb9aafa030>
(VllmWorker rank=2 pid=3742599) INFO 03-22 02:26:13 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_624dc89a'), local_subscribe_addr='ipc:///tmp/7751d295-0592-4f2c-b1e4-a411da44abaf', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 03-22 02:26:18 [__init__.py:256] Automatically detected platform cuda.
WARNING 03-22 02:26:21 [utils.py:2316] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x77439e3a0650>
(VllmWorker rank=3 pid=3742640) INFO 03-22 02:26:21 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_0c3bb63c'), local_subscribe_addr='ipc:///tmp/32e07bcc-13a8-4568-b324-0683fe92071d', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 03-22 02:26:26 [__init__.py:256] Automatically detected platform cuda.
WARNING 03-22 02:26:29 [utils.py:2316] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7723d097e000>
(VllmWorker rank=4 pid=3742678) INFO 03-22 02:26:29 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_1ee5d2f3'), local_subscribe_addr='ipc:///tmp/871cdfa5-be9c-425b-b306-4902da087dc7', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 03-22 02:26:34 [__init__.py:256] Automatically detected platform cuda.
WARNING 03-22 02:26:37 [utils.py:2316] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x746981c6e000>
(VllmWorker rank=5 pid=3742717) INFO 03-22 02:26:37 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_48117b61'), local_subscribe_addr='ipc:///tmp/99ef1818-60da-4f55-9437-1e6b51151174', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 03-22 02:26:43 [__init__.py:256] Automatically detected platform cuda.
WARNING 03-22 02:26:46 [utils.py:2316] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7aaaa80de000>
(VllmWorker rank=6 pid=3742758) INFO 03-22 02:26:46 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_07b12084'), local_subscribe_addr='ipc:///tmp/b0bc0bc0-8f0a-4171-9784-c63ad1c54755', remote_subscribe_addr=None, remote_addr_ipv6=False)
INFO 03-22 02:26:51 [__init__.py:256] Automatically detected platform cuda.
WARNING 03-22 02:26:54 [utils.py:2316] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7d926c11c980>
(VllmWorker rank=7 pid=3742791) INFO 03-22 02:26:54 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_ea0f719b'), local_subscribe_addr='ipc:///tmp/de0cf33d-63e6-4a8e-be01-f66c0b18601d', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=2 pid=3742599) INFO 03-22 02:26:55 [utils.py:931] Found nccl from library libnccl.so.2
(VllmWorker rank=1 pid=3742520) INFO 03-22 02:26:55 [utils.py:931] Found nccl from library libnccl.so.2
(VllmWorker rank=6 pid=3742758) INFO 03-22 02:26:55 [utils.py:931] Found nccl from library libnccl.so.2
(VllmWorker rank=3 pid=3742640) INFO 03-22 02:26:55 [utils.py:931] Found nccl from library libnccl.so.2
(VllmWorker rank=7 pid=3742791) INFO 03-22 02:26:55 [utils.py:931] Found nccl from library libnccl.so.2
(VllmWorker rank=2 pid=3742599) INFO 03-22 02:26:55 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=6 pid=3742758) INFO 03-22 02:26:55 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=3 pid=3742640) INFO 03-22 02:26:55 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=7 pid=3742791) INFO 03-22 02:26:55 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=1 pid=3742520) INFO 03-22 02:26:55 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=4 pid=3742678) INFO 03-22 02:26:55 [utils.py:931] Found nccl from library libnccl.so.2
(VllmWorker rank=4 pid=3742678) INFO 03-22 02:26:55 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=0 pid=3742482) INFO 03-22 02:26:55 [utils.py:931] Found nccl from library libnccl.so.2
(VllmWorker rank=0 pid=3742482) INFO 03-22 02:26:55 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=5 pid=3742717) INFO 03-22 02:26:55 [utils.py:931] Found nccl from library libnccl.so.2
(VllmWorker rank=5 pid=3742717) INFO 03-22 02:26:55 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=4 pid=3742678) WARNING 03-22 02:26:55 [custom_all_reduce.py:137] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=7 pid=3742791) WARNING 03-22 02:26:55 [custom_all_reduce.py:137] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=6 pid=3742758) WARNING 03-22 02:26:55 [custom_all_reduce.py:137] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=5 pid=3742717) WARNING 03-22 02:26:55 [custom_all_reduce.py:137] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=1 pid=3742520) WARNING 03-22 02:26:55 [custom_all_reduce.py:137] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=2 pid=3742599) WARNING 03-22 02:26:55 [custom_all_reduce.py:137] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=3 pid=3742640) WARNING 03-22 02:26:55 [custom_all_reduce.py:137] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=0 pid=3742482) WARNING 03-22 02:26:55 [custom_all_reduce.py:137] Custom allreduce is disabled because it's not supported on more than two PCIe-only GPUs. To silence this warning, specify disable_custom_all_reduce=True explicitly.
(VllmWorker rank=0 pid=3742482) INFO 03-22 02:26:55 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[1, 2, 3, 4, 5, 6, 7], buffer_handle=(7, 4194304, 6, 'psm_f860853a'), local_subscribe_addr='ipc:///tmp/5c8c7281-8809-42ef-96b4-41029dc9b654', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=4 pid=3742678) INFO 03-22 02:26:55 [parallel_state.py:967] rank 4 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 4
(VllmWorker rank=7 pid=3742791) INFO 03-22 02:26:55 [parallel_state.py:967] rank 7 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 7
(VllmWorker rank=2 pid=3742599) INFO 03-22 02:26:55 [parallel_state.py:967] rank 2 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 2
(VllmWorker rank=3 pid=3742640) INFO 03-22 02:26:55 [parallel_state.py:967] rank 3 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 3
(VllmWorker rank=6 pid=3742758) INFO 03-22 02:26:55 [parallel_state.py:967] rank 6 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 6
(VllmWorker rank=5 pid=3742717) INFO 03-22 02:26:55 [parallel_state.py:967] rank 5 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 5
(VllmWorker rank=7 pid=3742791) INFO 03-22 02:26:55 [cuda.py:187] Using Triton MLA backend on V1 engine.
(VllmWorker rank=0 pid=3742482) INFO 03-22 02:26:55 [parallel_state.py:967] rank 0 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 0
(VllmWorker rank=6 pid=3742758) INFO 03-22 02:26:55 [cuda.py:187] Using Triton MLA backend on V1 engine.
(VllmWorker rank=2 pid=3742599) INFO 03-22 02:26:55 [cuda.py:187] Using Triton MLA backend on V1 engine.
(VllmWorker rank=4 pid=3742678) INFO 03-22 02:26:55 [cuda.py:187] Using Triton MLA backend on V1 engine.
(VllmWorker rank=5 pid=3742717) INFO 03-22 02:26:55 [cuda.py:187] Using Triton MLA backend on V1 engine.
(VllmWorker rank=3 pid=3742640) INFO 03-22 02:26:55 [cuda.py:187] Using Triton MLA backend on V1 engine.
(VllmWorker rank=1 pid=3742520) INFO 03-22 02:26:55 [parallel_state.py:967] rank 1 in world size 8 is assigned as DP rank 0, PP rank 0, TP rank 1
(VllmWorker rank=1 pid=3742520) INFO 03-22 02:26:55 [cuda.py:187] Using Triton MLA backend on V1 engine.
(VllmWorker rank=0 pid=3742482) INFO 03-22 02:26:55 [cuda.py:187] Using Triton MLA backend on V1 engine.
(VllmWorker rank=7 pid=3742791) WARNING 03-22 02:26:55 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorker rank=6 pid=3742758) WARNING 03-22 02:26:55 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorker rank=5 pid=3742717) WARNING 03-22 02:26:55 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorker rank=2 pid=3742599) WARNING 03-22 02:26:55 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorker rank=4 pid=3742678) WARNING 03-22 02:26:55 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorker rank=1 pid=3742520) WARNING 03-22 02:26:55 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorker rank=3 pid=3742640) WARNING 03-22 02:26:55 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
(VllmWorker rank=0 pid=3742482) WARNING 03-22 02:26:55 [triton_decode_attention.py:44] The following error message 'operation scheduled before its operands' can be ignored.
ERROR 03-22 02:12:54 [core.py:343] EngineCore hit an exception: Traceback (most recent call last):
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/engine/core.py", line 336, in run_engine_core
ERROR 03-22 02:12:54 [core.py:343]     engine_core.run_busy_loop()
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/engine/core.py", line 370, in run_busy_loop
ERROR 03-22 02:12:54 [core.py:343]     outputs = step_fn()
ERROR 03-22 02:12:54 [core.py:343]               ^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/engine/core.py", line 195, in step
ERROR 03-22 02:12:54 [core.py:343]     output = self.model_executor.execute_model(scheduler_output)
ERROR 03-22 02:12:54 [core.py:343]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/executor/abstract.py", line 77, in execute_model
ERROR 03-22 02:12:54 [core.py:343]     output = self.collective_rpc("execute_model",
ERROR 03-22 02:12:54 [core.py:343]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/executor/multiproc_executor.py", line 134, in collective_rpc
ERROR 03-22 02:12:54 [core.py:343]     raise e
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/executor/multiproc_executor.py", line 123, in collective_rpc
ERROR 03-22 02:12:54 [core.py:343]     raise result
ERROR 03-22 02:12:54 [core.py:343] RuntimeError: A is not contiguous
ERROR 03-22 02:12:54 [core.py:343] Traceback (most recent call last):
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/executor/multiproc_executor.py", line 372, in worker_busy_loop
ERROR 03-22 02:12:54 [core.py:343]     output = func(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]              ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 03-22 02:12:54 [core.py:343]     return func(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/worker/gpu_worker.py", line 242, in execute_model
ERROR 03-22 02:12:54 [core.py:343]     output = self.model_runner.execute_model(scheduler_output)
ERROR 03-22 02:12:54 [core.py:343]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
ERROR 03-22 02:12:54 [core.py:343]     return func(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/worker/gpu_model_runner.py", line 1036, in execute_model
ERROR 03-22 02:12:54 [core.py:343]     hidden_states = self.model(
ERROR 03-22 02:12:54 [core.py:343]                     ^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 03-22 02:12:54 [core.py:343]     return self._call_impl(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 03-22 02:12:54 [core.py:343]     return forward_call(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/model_executor/models/deepseek_v2.py", line 689, in forward
ERROR 03-22 02:12:54 [core.py:343]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
ERROR 03-22 02:12:54 [core.py:343]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/compilation/decorators.py", line 245, in __call__
ERROR 03-22 02:12:54 [core.py:343]     model_output = self.forward(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/model_executor/models/deepseek_v2.py", line 627, in forward
ERROR 03-22 02:12:54 [core.py:343]     def forward(
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 03-22 02:12:54 [core.py:343]     return self._call_impl(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 03-22 02:12:54 [core.py:343]     return forward_call(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/_dynamo/eval_frame.py", line 745, in _fn
ERROR 03-22 02:12:54 [core.py:343]     return fn(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/fx/graph_module.py", line 822, in call_wrapped
ERROR 03-22 02:12:54 [core.py:343]     return self._wrapped_call(self, *args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/fx/graph_module.py", line 400, in __call__
ERROR 03-22 02:12:54 [core.py:343]     raise e
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/fx/graph_module.py", line 387, in __call__
ERROR 03-22 02:12:54 [core.py:343]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 03-22 02:12:54 [core.py:343]     return self._call_impl(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 03-22 02:12:54 [core.py:343]     return forward_call(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "<eval_with_key>.124", line 2070, in forward
ERROR 03-22 02:12:54 [core.py:343]     submod_1 = self.submod_1(getitem, s0, getitem_1, getitem_2, getitem_3);  getitem = getitem_1 = getitem_2 = submod_1 = None
ERROR 03-22 02:12:54 [core.py:343]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/fx/graph_module.py", line 822, in call_wrapped
ERROR 03-22 02:12:54 [core.py:343]     return self._wrapped_call(self, *args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/fx/graph_module.py", line 400, in __call__
ERROR 03-22 02:12:54 [core.py:343]     raise e
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/fx/graph_module.py", line 387, in __call__
ERROR 03-22 02:12:54 [core.py:343]     return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 03-22 02:12:54 [core.py:343]     return self._call_impl(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 03-22 02:12:54 [core.py:343]     return forward_call(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "<eval_with_key>.2", line 5, in forward
ERROR 03-22 02:12:54 [core.py:343]     unified_attention_with_output = torch.ops.vllm.unified_attention_with_output(x_7, x_11, k_pe, output_5, 'model.layers.0.self_attn.attn');  x_7 = x_11 = k_pe = output_5 = unified_attention_with_output = None
ERROR 03-22 02:12:54 [core.py:343]                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
ERROR 03-22 02:12:54 [core.py:343]     return self._op(*args, **(kwargs or {}))
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/attention/layer.py", line 374, in unified_attention_with_output
ERROR 03-22 02:12:54 [core.py:343]     self.impl.forward(self,
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/attention/backends/mla/common.py", line 929, in forward
ERROR 03-22 02:12:54 [core.py:343]     output[num_decode_tokens:] = self._forward_prefill(
ERROR 03-22 02:12:54 [core.py:343]                                  ^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/attention/backends/mla/common.py", line 826, in _forward_prefill
ERROR 03-22 02:12:54 [core.py:343]     context_output, context_lse = self._compute_prefill_context( \
ERROR 03-22 02:12:54 [core.py:343]                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/attention/backends/mla/common.py", line 742, in _compute_prefill_context
ERROR 03-22 02:12:54 [core.py:343]     kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
ERROR 03-22 02:12:54 [core.py:343]               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
ERROR 03-22 02:12:54 [core.py:343]     return self._call_impl(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
ERROR 03-22 02:12:54 [core.py:343]     return forward_call(*args, **kwargs)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/model_executor/layers/linear.py", line 474, in forward
ERROR 03-22 02:12:54 [core.py:343]     output_parallel = self.quant_method.apply(self, input_, bias)
ERROR 03-22 02:12:54 [core.py:343]                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/model_executor/layers/quantization/gptq_marlin.py", line 341, in apply
ERROR 03-22 02:12:54 [core.py:343]     return self.kernel.apply_weights(layer, x, bias)
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py", line 123, in apply_weights
ERROR 03-22 02:12:54 [core.py:343]     return apply_gptq_marlin_linear(
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/model_executor/layers/quantization/utils/marlin_utils.py", line 334, in apply_gptq_marlin_linear
ERROR 03-22 02:12:54 [core.py:343]     output = ops.gptq_marlin_gemm(reshaped_x,
ERROR 03-22 02:12:54 [core.py:343]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/_custom_ops.py", line 741, in gptq_marlin_gemm
ERROR 03-22 02:12:54 [core.py:343]     return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vm312/lib/python3.12/site-packages/torch/_ops.py", line 1123, in __call__
ERROR 03-22 02:12:54 [core.py:343]     return self._op(*args, **(kwargs or {}))
ERROR 03-22 02:12:54 [core.py:343]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343] RuntimeError: A is not contiguous

Signed-off-by: Qubitium <qubitium@modelcloud.ai>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small but essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

Signed-off-by: Qubitium <qubitium@modelcloud.ai>
@DefTruth
Contributor

I have made the same fix in a previous PR: #14946

@Qubitium
Contributor Author

I have made the same fix in a previous PR: #14946

@DefTruth Yes! You fixed the bug at the source! I only fixed it downstream. This PR is no longer needed, imho.

Collaborator

@LucasWilkinson LucasWilkinson left a comment


LGTM. I prefer this to #14946, so that we don't subject other code paths without this requirement (namely block fp8) to a potentially slow copy.

@DefTruth
Contributor

LGTM. I prefer this to #14946, so that we don't subject other code paths without this requirement (namely block fp8) to a potentially slow copy.

I think you are right; this PR seems like a more general solution.

@Qubitium Qubitium requested a review from youkaichao March 22, 2025 07:02
@youkaichao
Member


isn't it prefix caching?

Contributor Author

@Qubitium Qubitium Mar 22, 2025


@youkaichao Is it prefix or prefill, or both? I am a bit confused myself. Let me know which is the correct terminology to use in the comment and I will update it. My understanding: the prefix is the cache, and prefill is the step that reads from the prefix cache. Is this correct?

The stacktrace shows it going through the prefill code path:

ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/attention/backends/mla/common.py", line 929, in forward
ERROR 03-22 02:12:54 [core.py:343]     output[num_decode_tokens:] = self._forward_prefill(
ERROR 03-22 02:12:54 [core.py:343]                                  ^^^^^^^^^^^^^^^^^^^^^^
ERROR 03-22 02:12:54 [core.py:343]   File "/root/vllm/vllm/v1/attention/backends/mla/common.py", line 826, in _forward_prefill
ERROR 03-22 02:12:54 [core.py:343]     context_output, context_lse = self._compute_prefill_context( \
ERROR 03-22 02:12:54 [core.py:343]                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
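
To make the failure mode concrete, here is a toy, self-contained illustration (not vLLM code, and the shapes are only an example): slicing the last dimension of a packed KV buffer produces a strided view, which any kernel with a strict layout check rejects exactly like the Marlin GEMM in the trace.

```python
import torch

# Toy buffer shaped like [tokens, kv_lora_rank + rope_dim]; sizes are made up.
kv_buffer = torch.randn(16, 512 + 64)
kv_c_normed = kv_buffer[:, :512]          # drop the rope part -> strided view
assert not kv_c_normed.is_contiguous()

def strict_gemm_input_check(a: torch.Tensor) -> None:
    # Stand-in for a kernel that only supports a dense row-major layout.
    if not a.is_contiguous():
        raise RuntimeError("A is not contiguous")

try:
    strict_gemm_input_check(kv_c_normed)              # reproduces the failure
except RuntimeError as e:
    print(e)                                          # "A is not contiguous"

strict_gemm_input_check(kv_c_normed.contiguous())     # fine after an explicit copy
```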

@youkaichao
Member


Oh sorry, I just saw the code and thought "prefill caching" was a typo. I'm not familiar with the MLA code path though. cc @LucasWilkinson to confirm.

Contributor Author


@youkaichao I have updated the comment. I believe you are correct.

@Qubitium Qubitium requested a review from youkaichao March 22, 2025 07:54
@mgoin
Member

mgoin commented Mar 22, 2025

I worry about this non-contiguous issue for other kernels, but for performance reasons you are right to keep the fix local to the kernel. We will just have to audit the other kernels. It would be good to add a test specifically verifying that quant kernels can accept non-contiguous input.
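
A rough sketch of what such a test could look like; `quant_linear` is an assumed, already-constructed quantized linear module and the shapes are arbitrary (this is not an existing vLLM test):

```python
import torch

def check_noncontiguous_activation(quant_linear: torch.nn.Module,
                                   in_features: int, batch: int = 8) -> None:
    # Build a strided activation by slicing the last dim of a wider buffer.
    padded = torch.randn(batch, in_features + 64,
                         dtype=torch.bfloat16, device="cuda")
    x_strided = padded[:, :in_features]
    assert not x_strided.is_contiguous()

    ref = quant_linear(x_strided.contiguous())  # known-good dense input
    out = quant_linear(x_strided)               # must not raise on strided input
    torch.testing.assert_close(out, ref)
```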

Signed-off-by: Qubitium <qubitium@modelcloud.ai>
@Qubitium Qubitium force-pushed the fix-v1-marlin-prefill branch from 89a2cba to d3ea7fe Compare March 22, 2025 13:34
@LucasWilkinson LucasWilkinson enabled auto-merge (squash) March 22, 2025 19:10
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 22, 2025
@LucasWilkinson
Collaborator

NOTE: Another (more performant) fix is in flight: #14658 (see the linked comment in that PR).

@LucasWilkinson LucasWilkinson merged commit d20e261 into vllm-project:main Mar 24, 2025
44 checks passed
@Qubitium Qubitium deleted the fix-v1-marlin-prefill branch March 24, 2025 03:10
tlrmchlsmth added a commit that referenced this pull request Mar 24, 2025
erictang000 pushed a commit to erictang000/vllm that referenced this pull request Mar 25, 2025
wrmedford pushed a commit to wrmedford/vllm that referenced this pull request Mar 26, 2025
Signed-off-by: Wes Medford <wryanmedford@gmail.com>
wrmedford pushed a commit to wrmedford/vllm that referenced this pull request Mar 26, 2025
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025