
Conversation

@Isotr0py (Member) commented Mar 1, 2025

Related issue: #12724

  • Rename the V1 ROCmAttention backend to TritonAttention and allow users to select it on NVIDIA GPUs via VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 (see the example below).
  • Since the V1 ROCm attention backend is implemented in Triton, it can also be used on NVIDIA GPUs for Triton kernel development.
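
For context, a minimal sketch of selecting this backend explicitly from Python (the model name and prompt are only examples; assumes a CUDA GPU and a vLLM build that includes this change):

import os

# Select the V1 engine and the Triton attention backend before vLLM is
# imported, so the platform code picks them up when resolving the backend.
os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_ATTN_VLLM_V1"

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2.5-3B-Instruct", dtype="half", enforce_eager=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)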

Tested on T4 GPU
$ VLLM_USE_V1=1 python examples/offline_inference/basic/generate.py --model Qwen/Qwen2.5-3B-Instruct --dtype half --enforce-eager --max-num-seqs 1 -tp 2
INFO 03-01 03:23:53 [__init__.py:207] Automatically detected platform cuda.
WARNING 03-01 03:23:55 [arg_utils.py:1410] Setting max_num_batched_tokens to 8192 for LLM_CLASS usage context.
WARNING 03-01 03:23:56 [config.py:2552] Casting torch.bfloat16 to torch.float16.
INFO 03-01 03:24:05 [config.py:575] This model supports multiple tasks: {'reward', 'classify', 'generate', 'score', 'embed'}. Defaulting to 'generate'.
INFO 03-01 03:24:05 [config.py:1485] Defaulting to use mp for distributed inference
INFO 03-01 03:24:05 [config.py:1660] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 03-01 03:24:05 [cuda.py:95] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
INFO 03-01 03:24:05 [core.py:50] Initializing a V1 LLM engine (v0.1.dev4784+g18e5059) with config: model='Qwen/Qwen2.5-3B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-3B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Qwen/Qwen2.5-3B-Instruct, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":["vllm.unified_attention","vllm.unified_attention_with_output"],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}
WARNING 03-01 03:24:05 [multiproc_worker_utils.py:309] Reducing Torch parallelism from 2 threads to 1 to avoid unnecessary CPU contention. Set OMP_NUM_THREADS in the external environment to tune this value as needed.
INFO 03-01 03:24:05 [custom_cache_manager.py:19] Setting Triton cache manager to: vllm.triton_utils.custom_cache_manager:CustomCacheManager
INFO 03-01 03:24:05 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0, 1], buffer_handle=(2, 10485760, 10, 'psm_1e082990'), local_subscribe_addr='ipc:///tmp/f7491eb8-ecd9-486b-bf1a-dd08898e8aee', remote_subscribe_addr=None, remote_addr_ipv6=False)
WARNING 03-01 03:24:06 [utils.py:2298] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x78579a1cc0b0>
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:06 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_be1dd2a6'), local_subscribe_addr='ipc:///tmp/35d2b250-45a4-411a-a7bd-c25874c79c67', remote_subscribe_addr=None, remote_addr_ipv6=False)
WARNING 03-01 03:24:06 [utils.py:2298] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x78579a1d0950>
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:06 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[0], buffer_handle=(1, 10485760, 10, 'psm_7b4b9014'), local_subscribe_addr='ipc:///tmp/c981e423-630d-4330-b6a4-201ee387fccf', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [utils.py:939] Found nccl from library libnccl.so.2
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [utils.py:939] Found nccl from library libnccl.so.2
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [pynccl.py:69] vLLM is using nccl==2.21.5
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [custom_all_reduce_utils.py:244] reading GPU P2P access cache from /root/.cache/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [shm_broadcast.py:258] vLLM message queue communication handle: Handle(local_reader_ranks=[1], buffer_handle=(1, 4194304, 6, 'psm_4f618998'), local_subscribe_addr='ipc:///tmp/49b66037-1022-4554-aee6-b094b86ce383', remote_subscribe_addr=None, remote_addr_ipv6=False)
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [parallel_state.py:948] rank 0 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 0
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [parallel_state.py:948] rank 1 in world size 2 is assigned as DP rank 0, PP rank 0, TP rank 1
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [cuda.py:202] Cannot use Flash Attention backend for Turing GPUs.
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [cuda.py:204] Using Triton backend on V1 engine.
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [cuda.py:202] Cannot use Flash Attention backend for Turing GPUs.
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [cuda.py:204] Using Triton backend on V1 engine.
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [gpu_model_runner.py:1054] Starting to load model Qwen/Qwen2.5-3B-Instruct...
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [gpu_model_runner.py:1054] Starting to load model Qwen/Qwen2.5-3B-Instruct...
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [cuda.py:202] Cannot use Flash Attention backend for Turing GPUs.
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [cuda.py:202] Cannot use Flash Attention backend for Turing GPUs.
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:07 [cuda.py:204] Using Triton backend on V1 engine.
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:07 [cuda.py:204] Using Triton backend on V1 engine.
(VllmWorker rank=1 pid=42647) WARNING 03-01 03:24:07 [topk_topp_sampler.py:46] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
(VllmWorker rank=0 pid=42605) WARNING 03-01 03:24:07 [topk_topp_sampler.py:46] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:08 [weight_utils.py:254] Using model weights format ['*.safetensors']
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:08 [weight_utils.py:254] Using model weights format ['*.safetensors']
Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  50% Completed | 1/2 [00:01<00:01,  2.00s/it]
(VllmWorker rank=1 pid=42647) INFO 03-01 03:24:12 [gpu_model_runner.py:1066] Loading model weights took 2.9348 GB and 4.151953 seconds
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.45s/it]
Loading safetensors checkpoint shards: 100% Completed | 2/2 [00:03<00:00,  1.53s/it]
(VllmWorker rank=0 pid=42605) 
(VllmWorker rank=0 pid=42605) INFO 03-01 03:24:12 [gpu_model_runner.py:1066] Loading model weights took 2.9348 GB and 4.580717 seconds
INFO 03-01 03:24:18 [kv_cache_utils.py:524] GPU KV cache size: 559,584 tokens
INFO 03-01 03:24:18 [kv_cache_utils.py:527] Maximum concurrency for 32,768 tokens per request: 17.08x
INFO 03-01 03:24:18 [kv_cache_utils.py:524] GPU KV cache size: 559,584 tokens
INFO 03-01 03:24:18 [kv_cache_utils.py:527] Maximum concurrency for 32,768 tokens per request: 17.08x
INFO 03-01 03:24:18 [core.py:116] init engine (profile, create kv cache, warmup model) took 5.81 seconds
Processed prompts: 100%|███████████████████████████████████████████████████| 4/4 [00:06<00:00,  1.67s/it, est. speed input: 3.29 toks/s, output: 9.56 toks/s]
Prompt: 'Hello, my name is', Generated text: " Josh and I'm here to personally walk you through the process of enabling swap with"
Prompt: 'The president of the United States is', Generated text: ' a very important and respected position in the country. A pop quiz came up a'
Prompt: 'The capital of France is', Generated text: ' Paris across the Seine. Paris itself is divided into 20 districts,'
Prompt: 'The future of AI is', Generated text: ' heavily linked to quantum computing, but it will first need to become better at handling'

Isotr0py added 2 commits March 1, 2025 10:50
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
github-actions bot commented Mar 1, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Mar 1, 2025
Signed-off-by: Isotr0py <2037008807@qq.com>
@DarkLight1337 DarkLight1337 added the rocm (Related to AMD ROCm) label Mar 1, 2025
mergify bot commented Mar 1, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Isotr0py.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 1, 2025
Signed-off-by: Isotr0py <2037008807@qq.com>
@mergify mergify bot removed the needs-rebase label Mar 1, 2025
logger.info_once("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends.flash_attn."
        "FlashAttentionBackend")
if cls.has_device_capability(80):
@robertgshaw2-redhat (Collaborator) Mar 2, 2025

I think we should move these logs to debug

Collaborator

Why should we? I think people care about this log (although I really want to provide only one option per hardware).
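
For readers following the thread, here is a condensed, hypothetical sketch of the selection path being discussed (the helper name is made up, and the Triton module path simply follows the rename described in this PR; the real logic lives in vllm/platforms/cuda.py):

from vllm.logger import init_logger

logger = init_logger(__name__)

def pick_v1_attn_backend(device_capability: int) -> str:
    # Hypothetical helper, not the actual vLLM method.
    if device_capability >= 80:
        # Ampere and newer: FlashAttention is available.
        logger.info_once("Using Flash Attention backend on V1 engine.")
        return ("vllm.v1.attention.backends.flash_attn."
                "FlashAttentionBackend")
    # Pre-Ampere GPUs (e.g. Turing): fall back to the Triton backend.
    # This is the info log the thread above debates demoting to debug.
    logger.info_once("Using Triton backend on V1 engine.")
    return ("vllm.v1.attention.backends.triton_attn."
            "TritonAttentionBackend")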

Signed-off-by: Isotr0py <2037008807@qq.com>
mergify bot commented Mar 3, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Isotr0py.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 3, 2025
@WoosukKwon (Collaborator) left a comment

Hi @Isotr0py, thanks for the PR!
Does Triton actually support T4? I think support for it has been discontinued:

Supported Hardware:

NVIDIA GPUs (Compute Capability 8.0+)

@Isotr0py (Member, Author) commented Mar 3, 2025

Hmm, at least they still kept the FMA fallback when deprecating MMAv1 on pre-Ampere GPUs (triton-lang/triton#5066).

I also ran the prefix_prefill tests and they passed on dual T4 GPUs (with triton==3.2.0):
[screenshot: prefix_prefill test results, all passing]
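
As a side note, a quick way to check whether a local GPU falls below Triton's officially supported compute capability 8.0 (assumes a CUDA-enabled PyTorch install):

import torch

# Report the local GPU's compute capability; anything below (8, 0) is a
# pre-Ampere card (e.g. Turing T4, Volta V100) affected by the MMAv1
# deprecation discussed here.
major, minor = torch.cuda.get_device_capability()
print(f"Compute capability: {major}.{minor}")
if (major, minor) < (8, 0):
    print("Pre-Ampere GPU: Triton falls back to the FMA code path.")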

@Isotr0py (Member, Author) commented Mar 3, 2025

@WoosukKwon According to the Triton team's response (triton-lang/triton#5066 (comment)), the FMA code path for older platforms is still maintained on the main branch, even though MMA support has been deprecated for those platforms. So I think we can still use Triton kernels on Volta and Turing GPUs, especially since there are still lots of users serving models on 32GB V100s.

A compromise for this deprecation would be to enable the Triton backend fallback only when the user explicitly specifies VLLM_ATTENTION_BACKEND. That way, developers on older platforms can still use this backend for V1 engine development, even if we don't want it used in production because of the poor performance resulting from the MMA deprecation. A sketch of what that gating could look like follows below. WDYT?
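
A minimal sketch of that opt-in gating (hypothetical helper and backend-name strings based on values shown elsewhere in this thread, not the actual vLLM selector):

import os

def resolve_v1_attn_backend(device_capability: int) -> str:
    # Hypothetical sketch of the proposed gating, not vLLM's real selector.
    requested = os.environ.get("VLLM_ATTENTION_BACKEND")
    if device_capability >= 80:
        # Ampere and newer keep their existing default (FlashAttention).
        return "FLASH_ATTN_VLLM_V1"
    if requested == "TRITON_ATTN_VLLM_V1":
        # Pre-Ampere: only fall back to Triton when explicitly requested,
        # keeping it a development option rather than a production default.
        return "TRITON_ATTN_VLLM_V1"
    raise RuntimeError(
        "No default V1 attention backend for pre-Ampere GPUs; set "
        "VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1 to opt in.")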

@WoosukKwon (Collaborator)

Hi @Isotr0py, thanks for sharing the information.

A compromise for this deprecation would be to enable the Triton backend fallback only when the user explicitly specifies VLLM_ATTENTION_BACKEND. That way, developers on older platforms can still use this backend for V1 engine development.

This sounds like a good compromise, but I'm still a bit worried:

  1. My primary worry is that AMD might want to further optimize the attention backend for their needs, and in doing so they could be hampered by this code sharing.
  2. We may not have anyone to work on it when there's a bug in T4 & V100 support. I think we want to focus more on the latest generations (Hopper & Blackwell).

That being said, I think this needs broader discussion among the community, not just me. Will reach out to you offline!

@tdoublep (Member) commented Mar 4, 2025

I don't really have an opinion on how to handle older hardware, but it would be really nice to have the option of asking vLLM to enable the TritonAttentionBackend even when running on newer NVIDIA GPUs. This would allow folks working on Triton kernels to easily benchmark performance and drive further optimizations.

Isotr0py added 2 commits March 5, 2025 12:58
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
mergify bot commented Mar 18, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @Isotr0py.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 18, 2025
@WoosukKwon (Collaborator) left a comment

Ok.. I'm fine with this PR then.


Signed-off-by: Isotr0py <2037008807@qq.com>
@mergify mergify bot removed the needs-rebase label Mar 19, 2025
Isotr0py and others added 3 commits March 19, 2025 16:33
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Isotr0py <2037008807@qq.com>
Signed-off-by: Isotr0py <2037008807@qq.com>
@taoluo commented Mar 20, 2025

Hi, I tried this PR on a V100, and it seems it can only support --dtype float32; the original --dtype half gives the following error.
Given that the V100 has limited memory, it would be nice to support half.

$ VLLM_ATTENTION_BACKEND=TRITON_ATTN_VLLM_V1  python examples/offline_inference/basic/generate.py --model JackFram/llama-68m  --dtype half --enforce-eager --max-num-seqs 1 -tp 1
INFO 03-20 12:15:02 [__init__.py:256] Automatically detected platform cuda.
INFO 03-20 12:15:03 [config.py:2595] Downcasting torch.float32 to torch.float16.
INFO 03-20 12:15:16 [config.py:583] This model supports multiple tasks: {'classify', 'embed', 'generate', 'score', 'reward'}. Defaulting to 'generate'.
INFO 03-20 12:15:16 [config.py:1693] Chunked prefill is enabled with max_num_batched_tokens=8192.
WARNING 03-20 12:15:16 [cuda.py:95] To see benefits of async output processing, enable CUDA graph. Since, enforce-eager is enabled, async output processor cannot be used
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
INFO 03-20 12:15:19 [core.py:53] Initializing a V1 LLM engine (v0.8.0rc3.dev42+gdf9f570fe) with config: model='JackFram/llama-68m', speculative_config=None, tokenizer='JackFram/llama-68m', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=True, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=None, served_model_name=JackFram/llama-68m, num_scheduler_steps=1, multi_step_stream_outputs=True, enable_prefix_caching=True, chunked_prefill_enabled=True, use_async_output_proc=False, disable_mm_preprocessor_cache=False, mm_processor_kwargs=None, pooler_config=None, compilation_config={"splitting_ops":[],"compile_sizes":[],"cudagraph_capture_sizes":[],"max_capture_size":0}
WARNING 03-20 12:15:20 [utils.py:2282] Methods determine_num_available_blocks,device_config,get_cache_block_size_bytes,initialize_cache not implemented in <vllm.v1.worker.gpu_worker.Worker object at 0x7fbb4f2b9430>
INFO 03-20 12:15:20 [parallel_state.py:967] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 03-20 12:15:20 [cuda.py:216] Using Triton backend on V1 engine.
INFO 03-20 12:15:20 [gpu_model_runner.py:1164] Starting to load model JackFram/llama-68m...
WARNING 03-20 12:15:20 [topk_topp_sampler.py:63] FlashInfer is not available. Falling back to the PyTorch-native implementation of top-p & top-k sampling. For the best performance, please install FlashInfer.
INFO 03-20 12:15:21 [weight_utils.py:257] Using model weights format ['*.bin']
Loading pt checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.55it/s]
Loading pt checkpoint shards: 100% Completed | 1/1 [00:00<00:00,  3.54it/s]

INFO 03-20 12:15:21 [loader.py:429] Loading weights took 0.29 seconds
INFO 03-20 12:15:21 [gpu_model_runner.py:1176] Model loading took 0.1270 GB and 0.613932 seconds
INFO 03-20 12:15:22 [kv_cache_utils.py:537] GPU KV cache size: 2,355,232 tokens
INFO 03-20 12:15:22 [kv_cache_utils.py:540] Maximum concurrency for 2,048 tokens per request: 1150.02x
INFO 03-20 12:15:22 [core.py:138] init engine (profile, create kv cache, warmup model) took 0.68 seconds
Processed prompts:   0%|                                                                                                                       | 0/4 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]
LLVM ERROR: Failed to compute parent layout for slice layout.

@Isotr0py (Member, Author) commented Mar 20, 2025

@taoluo I think that's because newer Triton versions have removed MMAv1 support for Volta and Turing, but I'm not sure whether it's also an actual bug in Triton. You can open an issue about this LLVM error in their repo.

@taoluo commented Mar 20, 2025

Thanks for the explanation.

@robertgshaw2-redhat robertgshaw2-redhat enabled auto-merge (squash) March 20, 2025 20:42
@github-actions github-actions bot added the ready (ONLY add when PR is ready to merge/full CI is needed) label Mar 20, 2025
@robertgshaw2-redhat robertgshaw2-redhat merged commit f8a08cb into vllm-project:main Mar 21, 2025
46 checks passed
@Isotr0py Isotr0py deleted the v1-tesla branch March 21, 2025 03:15
bringlein added a commit to foundation-model-stack/vllm-triton-backend that referenced this pull request Mar 24, 2025
Updating to be able to use vllm-project/vllm#14071. 

---------

Signed-off-by: Burkhard Ringlein <ngl@zurich.ibm.com>
erictang000 pushed a commit to erictang000/vllm that referenced this pull request Mar 25, 2025
…ect#14071)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
…ect#14071)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Louis Ulmer <ulmerlouis@gmail.com>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
…ect#14071)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
…ect#14071)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: Mu Huai <tianbowen.tbw@antgroup.com>

Labels: ready, rocm, v1