From 3c7d19898f4c88c0840b2794483042851a251138 Mon Sep 17 00:00:00 2001 From: amit Date: Tue, 3 Jun 2025 09:33:37 +0300 Subject: [PATCH 01/33] V1 support of priority shedualing Signed-off-by: amit --- docs/usage/v1_guide.md | 2 +- tests/v1/core/test_scheduler.py | 594 ++++++++++++++++++++++++++++++++ vllm/v1/core/sched/scheduler.py | 220 +++++++++--- vllm/v1/engine/__init__.py | 1 + vllm/v1/engine/processor.py | 3 +- vllm/v1/request.py | 9 +- 6 files changed, 771 insertions(+), 58 deletions(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index 3d5d7ce45cce..516b24443d4c 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -69,7 +69,7 @@ This living user guide outlines a few known **important changes and limitations* way by using a simple dictionary (e.g., {request_id: num_tokens}) to dynamically allocate a fixed token budget per request, enabling features like chunked prefills, prefix caching, and speculative decoding without a strict separation between prefill -and decode phases. +and decode phases. The V1 scheduler supports multiple scheduling policies, including First-Come, First-Served (FCFS) and priority-based scheduling (where requests are processed based on assigned priority, with FCFS as a tie-breaker), configurable via the `--scheduling-policy` argument. ### Semantic Changes and Deprecated Features diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index f38454b1b288..079179730d5d 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -1246,3 +1246,597 @@ def test_memory_leak(): # Confirm no memory leak. assert_scheduler_empty(scheduler) + + +def create_scheduler_with_priority( + model: str = "facebook/opt-125m", + max_num_seqs: int = 16, + max_num_batched_tokens: int = 8192, + enable_prefix_caching: Optional[bool] = None, + long_prefill_token_threshold: int = 0, + disable_chunked_mm_input: bool = False, + use_kv_connector: bool = False, + num_blocks: int = 10000, + block_size: int = 16, + max_model_len: Optional[int] = None, + num_speculative_tokens: Optional[int] = None, +) -> Scheduler: + '''Create scheduler with priority policy enabled. 
+ + Args: + model: model under test + max_num_seqs: max sequences to schedule + max_num_batch_tokens: max num tokens to batch + enable_prefix_caching: optionally force APC config + (True/False) or use default + (None) + + Returns: + {class}`Scheduler` instance with priority scheduling + ''' + if max_model_len is None: + max_model_len = max_num_batched_tokens + scheduler_config = SchedulerConfig( + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + max_model_len=max_model_len, + long_prefill_token_threshold=long_prefill_token_threshold, + disable_chunked_mm_input=disable_chunked_mm_input, + enable_chunked_prefill=True, + policy="priority", # Enable priority scheduling + ) + model_config = ModelConfig( + model=model, + task="auto", + tokenizer=model, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="float16", + seed=42, + ) + # Cache config, optionally force APC + kwargs_cache = ({} if enable_prefix_caching is None else { + 'enable_prefix_caching': enable_prefix_caching + }) + cache_config = CacheConfig( + block_size=block_size, + gpu_memory_utilization=0.9, + swap_space=0, + cache_dtype="auto", + **kwargs_cache, + ) + kv_transfer_config = KVTransferConfig( + kv_connector="SharedStorageConnector", + kv_role="kv_both", + kv_connector_extra_config={"shared_storage_path": "local_storage"}, + ) if use_kv_connector else None + + speculative_config: Optional[SpeculativeConfig] = None + if num_speculative_tokens is not None: + speculative_config = SpeculativeConfig( + model="ngram", num_speculative_tokens=num_speculative_tokens) + + vllm_config = VllmConfig( + scheduler_config=scheduler_config, + model_config=model_config, + cache_config=cache_config, + kv_transfer_config=kv_transfer_config, + speculative_config=speculative_config, + ) + kv_cache_config = KVCacheConfig( + num_blocks=num_blocks, # A large number of blocks to hold all requests + tensors={}, + kv_cache_groups=[ + KVCacheGroupSpec(['layer'], + FullAttentionSpec(block_size, 1, 1, torch.float32, + False)) + ], + ) + cache_config.num_gpu_blocks = num_blocks + return Scheduler( + vllm_config=vllm_config, + kv_cache_config=kv_cache_config, + log_stats=True, + structured_output_manager=StructuredOutputManager(vllm_config), + ) + + +def create_requests_with_priority( + num_requests: int, + priorities: list[int], + arrival_times: Optional[list[float]] = None, + num_tokens: int = 10, + mm_positions: Optional[list[PlaceholderRange]] = None, + max_tokens: int = 16, + stop_token_ids: Optional[list[int]] = None, + prompt_logprobs: Optional[int] = None): + """Create requests with specified priorities and arrival times.""" + assert len(priorities) == num_requests + if arrival_times is not None: + assert len(arrival_times) == num_requests + else: + arrival_times = [float(i) for i in range(num_requests)] + + sampling_params = SamplingParams(ignore_eos=False, + max_tokens=max_tokens, + stop_token_ids=stop_token_ids, + prompt_logprobs=prompt_logprobs) + requests = [] + for i in range(num_requests): + if mm_positions is not None: + mm_position = mm_positions[i] + mm_inputs = [MultiModalKwargs({})] * len(mm_position) + else: + mm_position = None + mm_inputs = None + request = Request( + request_id=f"{i}", + prompt_token_ids=[i] * num_tokens, + sampling_params=sampling_params, + multi_modal_inputs=mm_inputs, + multi_modal_placeholders=mm_position, + multi_modal_hashes=None, + eos_token_id=EOS_TOKEN_ID, + arrival_time=arrival_times[i], + priority=priorities[i], + ) + requests.append(request) + return requests + + +def 
test_priority_scheduling_basic_ordering(): + """Test that requests are scheduled in priority order + (lower value = higher priority).""" + scheduler = create_scheduler_with_priority() + + # Create requests with different priorities + # Priority 0 (highest), 1, 2 (lowest) + priorities = [2, 0, 1] # Add in non-priority order + arrival_times = [1.0, 2.0, 3.0] # All different arrival times + requests = create_requests_with_priority(num_requests=3, + priorities=priorities, + arrival_times=arrival_times) + + # Add requests in non-priority order + for request in requests: + scheduler.add_request(request) + + # Schedule and verify priority order + output = scheduler.schedule() + + # Should schedule all requests since they fit in budget + assert len(output.scheduled_new_reqs) == 3 + + # Verify they are scheduled in priority order: + # req_1 (priority 0), req_2 (priority 1), req_0 (priority 2) + scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] + assert scheduled_req_ids == ["1", "2", "0"] + + +def test_priority_scheduling_arrival_time_tiebreaker(): + """Test that arrival time is used + as tiebreaker when priorities are equal.""" + scheduler = create_scheduler_with_priority() + + # Create requests with same priority but different arrival times + priorities = [1, 1, 1] # All same priority + arrival_times = [3.0, 1.0, 2.0] # Different arrival times + requests = create_requests_with_priority(num_requests=3, + priorities=priorities, + arrival_times=arrival_times) + + # Add requests in non-arrival order + for request in requests: + scheduler.add_request(request) + + # Schedule and verify arrival time order + output = scheduler.schedule() + + # Should schedule all requests since they fit in budget + assert len(output.scheduled_new_reqs) == 3 + + # Verify they are scheduled in arrival time order: + # req_1 (1.0), req_2 (2.0), req_0 (3.0) + scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] + assert scheduled_req_ids == ["1", "2", "0"] + + +def test_priority_scheduling_mixed_priority_and_arrival(): + """Test priority scheduling with mixed priorities and arrival times.""" + scheduler = create_scheduler_with_priority() + + # Create requests with mixed priorities and arrival times + priorities = [2, 1, 1, 0] # Mixed priorities + arrival_times = [1.0, 3.0, 2.0, 4.0] # Mixed arrival times + requests = create_requests_with_priority(num_requests=4, + priorities=priorities, + arrival_times=arrival_times) + + # Add requests + for request in requests: + scheduler.add_request(request) + + # Schedule and verify order + output = scheduler.schedule() + + # Should schedule all requests since they fit in budget + assert len(output.scheduled_new_reqs) == 4 + + # Expected order: + # 1. req_3 (priority 0, arrival 4.0) + # 2. req_2 (priority 1, arrival 2.0) - earlier arrival than req_1 + # 3. req_1 (priority 1, arrival 3.0) + # 4. 
req_0 (priority 2, arrival 1.0) + scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] + assert scheduled_req_ids == ["3", "2", "1", "0"] + + +def test_priority_scheduling_preemption(): + """Test that priority scheduling preempts + lower priority requests when memory is constrained.""" + # Create scheduler with very limited memory to force preemption + scheduler = create_scheduler_with_priority( + max_num_seqs=3, # Allow multiple requests + max_num_batched_tokens=200, + num_blocks=6, # Very limited blocks to force memory pressure + block_size=16, # Standard block size + ) + + # Create initial low-priority requests that will consume most memory + low_priority_requests = create_requests_with_priority( + num_requests=2, + priorities=[5, 5], # Low priority + arrival_times=[1.0, 2.0], + num_tokens=30 # Large enough to consume significant memory + ) + + # Add and schedule low priority requests + for request in low_priority_requests: + scheduler.add_request(request) + + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 2 + + # Simulate model execution to move requests to running state + model_output = ModelRunnerOutput( + req_ids=[req.request_id for req in low_priority_requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(low_priority_requests) + }, + sampled_token_ids=[[100] for _ in low_priority_requests], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + scheduler.update_from_output(output, model_output) + + # Verify both requests are running + assert len(scheduler.running) == 2 + + # Now add a high-priority request that requires memory allocation + # This should trigger preemption due to memory constraints + high_priority_request = create_requests_with_priority( + num_requests=1, + priorities=[0], # High priority + arrival_times=[3.0], + num_tokens=30 # Large enough to require significant memory + )[0] + + scheduler.add_request(high_priority_request) + + # Schedule again - this should trigger + # preemption when trying to allocate memory + output = scheduler.schedule() + + # Due to the scheduler's design, if preemption happens + # during running request scheduling, + # waiting requests won't be scheduled in the same step + # Let's check if preemption occurred by looking at the waiting queue + + # If preemption happened, we should see requests in the + # waiting queue + if len(scheduler.waiting) > 1: # high priority + preempted request + # Preemption occurred - verify the high priority request + # gets scheduled next + output2 = scheduler.schedule() + assert len(output2.scheduled_new_reqs) == 1 + # High priority request + assert output2.scheduled_new_reqs[0].req_id == "0" + else: + # No preemption needed - all requests fit + # This is also valid behavior if memory allows + assert len(output.scheduled_new_reqs) == 1 + # High priority request + assert output.scheduled_new_reqs[0].req_id == "0" + + +def test_priority_scheduling_no_preemption_when_space_available(): + """Test that preemption doesn't happen + when there's space for new requests.""" + scheduler = create_scheduler_with_priority( + max_num_seqs=3, # Allow 3 concurrent requests + max_num_batched_tokens=200, # Sufficient token budget + ) + + # Add two low-priority running requests + low_priority_requests = create_requests_with_priority( + num_requests=2, + priorities=[5, 5], + arrival_times=[1.0, 2.0], + num_tokens=30) + + for request in low_priority_requests: + scheduler.add_request(request) + + output = scheduler.schedule() + model_output = 
ModelRunnerOutput( + req_ids=[req.request_id for req in low_priority_requests], + req_id_to_index={ + req.request_id: i + for i, req in enumerate(low_priority_requests) + }, + sampled_token_ids=[[100] for _ in low_priority_requests], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + scheduler.update_from_output(output, model_output) + + # Add high-priority request + high_priority_request = create_requests_with_priority(num_requests=1, + priorities=[0], + arrival_times=[3.0], + num_tokens=30)[0] + + scheduler.add_request(high_priority_request) + + # Schedule - should not preempt since there's space + output = scheduler.schedule() + + # Should schedule the new request without preemption + assert len(output.scheduled_new_reqs) == 1 + assert len(scheduler.running) == 3 # All three requests running + assert len(scheduler.waiting) == 0 # No requests waiting + + +def test_priority_scheduling_preemption_victim_selection(): + """Test that the correct victim is selected for + preemption based on priority and arrival time.""" + # This test verifies the priority-based victim selection logic + # by checking the waiting queue order after adding requests with different + # priorities + scheduler = create_scheduler_with_priority( + max_num_seqs=1, # Force sequential processing to test priority order + ) + + # Create requests with different priorities + requests = create_requests_with_priority( + num_requests=3, + priorities=[3, 2, 0], # Different priorities: low, medium, high + arrival_times=[1.0, 2.0, 3.0], + num_tokens=10) + + # Add all requests + for request in requests: + scheduler.add_request(request) + + # Schedule - should only schedule the highest priority request + # (req_2, priority 0) + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 1 + assert output.scheduled_new_reqs[0].req_id == "2" # Highest priority + + # Verify the waiting queue has the remaining requests in priority order + assert len(scheduler.waiting) == 2 + + # Extract waiting requests and verify priority order + temp_waiting = list(scheduler.waiting) + temp_waiting.sort() # Sort by (priority, arrival_time, request) + + waiting_priorities = [priority for priority, _, _ in temp_waiting] + waiting_req_ids = [req.request_id for _, _, req in temp_waiting] + + # Should be req_1 (priority 2) then req_0 (priority 3) + assert waiting_priorities == [2, 3] + assert waiting_req_ids == ["1", "0"] + + +def test_priority_scheduling_equal_priority_preemption(): + """Test arrival time tiebreaker when requests have equal priority.""" + # This test verifies that arrival time is used as a tiebreaker for equal + # priorities + scheduler = create_scheduler_with_priority( + max_num_seqs=1, # Force sequential processing + ) + + # Create requests with same priority but different arrival times + requests = create_requests_with_priority( + num_requests=3, + priorities=[2, 2, 2], # Same priority + arrival_times=[3.0, 1.0, 2.0], # Different arrival times + num_tokens=10) + + # Add all requests + for request in requests: + scheduler.add_request(request) + + # Schedule - should schedule the request with earliest arrival time + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 1 + assert output.scheduled_new_reqs[0].req_id == "1" # Earliest arrival (1.0) + + # Verify the waiting queue has remaining requests in arrival time order + assert len(scheduler.waiting) == 2 + + # Extract waiting requests and verify arrival time order + temp_waiting = list(scheduler.waiting) + temp_waiting.sort() # Sort 
by (priority, arrival_time, request) + + waiting_arrival_times = [ + arrival_time for _, arrival_time, _ in temp_waiting + ] + waiting_req_ids = [req.request_id for _, _, req in temp_waiting] + + # Should be req_2 (arrival 2.0) then req_0 (arrival 3.0) + assert waiting_arrival_times == [2.0, 3.0] + assert waiting_req_ids == ["2", "0"] + + +def test_priority_scheduling_waiting_queue_order(): + """Test that the waiting queue maintains priority order.""" + scheduler = create_scheduler_with_priority( + max_num_seqs=1, # Only one request can run at a time + ) + + # Create multiple requests with different priorities + requests = create_requests_with_priority( + num_requests=4, + priorities=[3, 1, 2, 0], # Mixed priorities + arrival_times=[1.0, 2.0, 3.0, 4.0], + num_tokens=10) + + # Add all requests + for request in requests: + scheduler.add_request(request) + + # Schedule - should only schedule the highest priority request + # (req_3, priority 0) + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 1 + assert output.scheduled_new_reqs[0].req_id == "3" + + # Verify waiting queue has remaining requests in priority order + assert len(scheduler.waiting) == 3 + + # Extract requests from waiting queue + # (it's a heap, so we need to pop to see order) + waiting_priorities = [] + waiting_req_ids = [] + temp_waiting = list(scheduler.waiting) # Copy the heap + temp_waiting.sort() # Sort by (priority, arrival_time, request) + + for priority, arrival_time, request in temp_waiting: + waiting_priorities.append(priority) + waiting_req_ids.append(request.request_id) + + # Should be ordered by priority: req_1 (1), req_2 (2), req_0 (3) + assert waiting_req_ids == ["1", "2", "0"] + assert waiting_priorities == [1, 2, 3] + + +def test_priority_scheduling_fcfs_fallback(): + """Test that FCFS behavior is maintained when all + requests have same priority.""" + scheduler = create_scheduler_with_priority() + + # Create requests with same priority but different arrival times + priorities = [1, 1, 1, 1] # All same priority + arrival_times = [4.0, 1.0, 3.0, 2.0] # Different arrival times + requests = create_requests_with_priority(num_requests=4, + priorities=priorities, + arrival_times=arrival_times) + + # Add requests + for request in requests: + scheduler.add_request(request) + + # Schedule + output = scheduler.schedule() + + # Should schedule all requests in arrival time order + assert len(output.scheduled_new_reqs) == 4 + scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] + + # Expected order by arrival time: + # req_1 (1.0), req_3 (2.0), req_2 (3.0), req_0 (4.0) + assert scheduled_req_ids == ["1", "3", "2", "0"] + + +def test_priority_scheduling_with_limited_slots(): + """Test priority scheduling when max_num_seqs limits concurrent requests.""" + scheduler = create_scheduler_with_priority( + max_num_seqs=2, # Only allow 2 concurrent requests + max_num_batched_tokens=1000, # Plenty of token budget + ) + + # Create requests with different priorities + requests = create_requests_with_priority( + num_requests=4, + priorities=[3, 1, 2, 0], # Mixed priorities + arrival_times=[1.0, 2.0, 3.0, 4.0], + num_tokens=10) + + # Add all requests + for request in requests: + scheduler.add_request(request) + + # Schedule - should only schedule the 2 highest priority requests + output = scheduler.schedule() + assert len(output.scheduled_new_reqs) == 2 + + # Should schedule req_3 (priority 0) and req_1 (priority 1) + scheduled_req_ids = [req.req_id for req in output.scheduled_new_reqs] + assert 
"3" in scheduled_req_ids # Priority 0 + assert "1" in scheduled_req_ids # Priority 1 + + # Remaining requests should be in waiting queue in priority order + assert len(scheduler.waiting) == 2 + + # Extract waiting requests and verify order + temp_waiting = list(scheduler.waiting) + temp_waiting.sort() + waiting_priorities = [priority for priority, _, _ in temp_waiting] + waiting_req_ids = [req.request_id for _, _, req in temp_waiting] + + # Should be req_2 (priority 2) then req_0 (priority 3) + assert waiting_priorities == [2, 3] + assert waiting_req_ids == ["2", "0"] + + +def test_priority_scheduling_heap_property(): + """Test that the waiting queue maintains heap + property for priority scheduling.""" + scheduler = create_scheduler_with_priority( + max_num_seqs=1, # Only one request can run at a time + ) + + # Add requests in random priority order + priorities = [5, 1, 8, 3, 2, 7, 4, 6] + arrival_times = [float(i) for i in range(len(priorities))] + requests = create_requests_with_priority(num_requests=len(priorities), + priorities=priorities, + arrival_times=arrival_times, + num_tokens=10) + + # Add all requests + for request in requests: + scheduler.add_request(request) + + # Schedule one request at a time and verify priority order + scheduled_priorities = [] + + while scheduler.waiting: + output = scheduler.schedule() + if output.scheduled_new_reqs: + req = output.scheduled_new_reqs[0] + scheduled_priorities.append(requests[int(req.req_id)].priority) + + # Simulate completion to make room for next request + model_output = ModelRunnerOutput( + req_ids=[req.req_id], + req_id_to_index={req.req_id: 0}, + sampled_token_ids=[[100]], + spec_token_ids=None, + logprobs=None, + prompt_logprobs_dict={}, + ) + scheduler.update_from_output(output, model_output) + + # Finish the request to make room for the next one + scheduler.finish_requests(req.req_id, + RequestStatus.FINISHED_STOPPED) + + # Verify requests were scheduled in priority order (lowest value first) + expected_priorities = sorted(priorities) + assert scheduled_priorities == expected_priorities diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index ce16a1ed5a09..89e4016a7b97 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -2,10 +2,11 @@ from __future__ import annotations +import heapq import time from collections import defaultdict, deque from collections.abc import Iterable -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch @@ -63,8 +64,8 @@ def __init__( # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs - self.max_num_scheduled_tokens = \ - self.scheduler_config.max_num_batched_tokens + self.max_num_scheduled_tokens = ( + self.scheduler_config.max_num_batched_tokens) self.max_model_len = self.scheduler_config.max_model_len self.enable_kv_cache_events = ( self.kv_events_config is not None @@ -88,8 +89,12 @@ def __init__( # req_id -> Request self.requests: dict[str, Request] = {} + # Scheduling policy + self.policy = self.scheduler_config.policy # Priority queues for requests. 
- self.waiting: deque[Request] = deque() + self.waiting: Union[list[tuple[int, float, Request]], + deque[Request]] = ([] if self.policy == "priority" + else deque()) self.running: list[Request] = [] # The request IDs that are finished in between the previous and the @@ -104,8 +109,8 @@ def __init__( # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> deque of CachedRequestData - self._cached_reqs_data: dict[ - str, deque[CachedRequestData]] = defaultdict(deque) + self._cached_reqs_data: dict[str, deque[CachedRequestData]] = ( + defaultdict(deque)) # Encoder-related. # Calculate encoder cache size if applicable @@ -209,10 +214,16 @@ def schedule(self) -> SchedulerOutput: encoder_inputs_to_schedule = None new_encoder_budget = encoder_budget if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget) = self._try_schedule_encoder_inputs( - request, request.num_computed_tokens, num_new_tokens, - encoder_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_budget, + ) = self._try_schedule_encoder_inputs( + request, + request.num_computed_tokens, + num_new_tokens, + encoder_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled because one of the following @@ -230,18 +241,33 @@ def schedule(self) -> SchedulerOutput: num_draft_tokens = max( num_new_tokens + request.num_computed_tokens - - request.num_tokens, 0) + request.num_tokens, + 0, + ) while True: new_blocks = self.kv_cache_manager.allocate_slots( request, num_new_tokens, num_draft_tokens=num_draft_tokens, - num_lookahead_tokens=self.num_lookahead_tokens) + num_lookahead_tokens=self.num_lookahead_tokens, + ) if new_blocks is None: # The request cannot be scheduled. # Preempt the lowest-priority request. - preempted_req = self.running.pop() + if not self.running: + # No request to preempt. + can_schedule = False + break + if self.policy == "priority": + preempted_req = min( + self.running, + key=lambda r: (r.priority, r.arrival_time), + ) + self.running.remove(preempted_req) + else: + preempted_req = self.running.pop() + self.kv_cache_manager.free(preempted_req) preempted_req.status = RequestStatus.PREEMPTED preempted_req.num_computed_tokens = 0 @@ -249,7 +275,19 @@ def schedule(self) -> SchedulerOutput: preempted_req.record_event( EngineCoreEventType.PREEMPTED, scheduled_timestamp) - self.waiting.appendleft(preempted_req) + if self.policy == "priority": + heapq.heappush( + cast(list[tuple[int, float, Request]], + self.waiting), + ( + preempted_req.priority, + preempted_req.arrival_time, + preempted_req, + ), + ) + else: + cast(deque[Request], + self.waiting).appendleft(preempted_req) preempted_reqs.append(preempted_req) if preempted_req == request: # No more request to preempt. @@ -307,7 +345,10 @@ def schedule(self) -> SchedulerOutput: # Use a temporary deque to collect requests that need to be skipped # and put back at the head of the waiting queue later - skipped_waiting_requests: deque[Request] = deque() + skipped_waiting_requests: Union[list[tuple[int, float, Request]], + deque[Request]] = ([] if self.policy + == "priority" else + deque()) # Next, schedule the WAITING requests. 
if not preempted_reqs: @@ -315,7 +356,14 @@ def schedule(self) -> SchedulerOutput: if len(self.running) == self.max_num_running_reqs: break - request = self.waiting[0] + if self.policy == "priority": + if (not self.waiting + ): # Should not happen due to outer loop condition + break + priority_val, arrival_time_val, request = heapq.heappop( + cast(list[tuple[int, float, Request]], self.waiting)) + else: + request = cast(deque[Request], self.waiting).popleft() # KVTransfer: skip request if still waiting for remote kvs. if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS: @@ -325,9 +373,18 @@ def schedule(self) -> SchedulerOutput: else: logger.debug( "%s is still in WAITING_FOR_REMOTE_KVS state.", - request.request_id) - self.waiting.popleft() - skipped_waiting_requests.appendleft(request) + request.request_id, + ) + if self.policy == "priority": + waiting_queue = cast( + list[tuple[int, float, Request]], + skipped_waiting_requests) + heapq.heappush( + waiting_queue, + (priority_val, arrival_time_val, request)) + else: + cast(deque[Request], + skipped_waiting_requests).appendleft(request) continue # Skip request if the structured output request is still waiting @@ -337,19 +394,33 @@ def schedule(self) -> SchedulerOutput: if structured_output_req and structured_output_req.grammar: request.status = RequestStatus.WAITING else: - self.waiting.popleft() - skipped_waiting_requests.appendleft(request) + if self.policy == "priority": + waiting_queue = cast( + list[tuple[int, float, Request]], + skipped_waiting_requests) + heapq.heappush( + waiting_queue, + (priority_val, arrival_time_val, request)) + else: + cast(deque[Request], + skipped_waiting_requests).appendleft(request) continue # Check that adding the request still respects the max_loras # constraint. - if self.lora_config and request.lora_request and ( - len(scheduled_loras) == self.lora_config.max_loras - and request.lora_request.lora_int_id - not in scheduled_loras): + if (self.lora_config and request.lora_request and + (len(scheduled_loras) == self.lora_config.max_loras and + request.lora_request.lora_int_id not in scheduled_loras)): # Scheduling would exceed max_loras, skip. - self.waiting.popleft() - skipped_waiting_requests.appendleft(request) + if self.policy == "priority": + waiting_queue = cast(list[tuple[int, float, Request]], + skipped_waiting_requests) + heapq.heappush( + waiting_queue, + (priority_val, arrival_time_val, request)) + else: + cast(deque[Request], + skipped_waiting_requests).appendleft(request) continue num_external_computed_tokens = 0 @@ -358,9 +429,8 @@ def schedule(self) -> SchedulerOutput: # Get already-cached tokens. if request.num_computed_tokens == 0: # Get locally-cached tokens. - new_computed_blocks, num_new_local_computed_tokens = \ - self.kv_cache_manager.get_computed_blocks( - request) + new_computed_blocks, num_new_local_computed_tokens = ( + self.kv_cache_manager.get_computed_blocks(request)) # Get externally-cached tokens if using a KVConnector. if self.connector is not None: @@ -400,11 +470,16 @@ def schedule(self) -> SchedulerOutput: # Schedule encoder inputs. 
if request.has_encoder_inputs: - (encoder_inputs_to_schedule, num_new_tokens, - new_encoder_budget - ) = self._try_schedule_encoder_inputs( - request, num_computed_tokens, num_new_tokens, - encoder_budget) + ( + encoder_inputs_to_schedule, + num_new_tokens, + new_encoder_budget, + ) = self._try_schedule_encoder_inputs( + request, + num_computed_tokens, + num_new_tokens, + encoder_budget, + ) if num_new_tokens == 0: # The request cannot be scheduled. break @@ -419,6 +494,15 @@ def schedule(self) -> SchedulerOutput: ) if new_blocks is None: # The request cannot be scheduled. + if self.policy == "priority": + waiting_queue = cast(list[tuple[int, float, Request]], + skipped_waiting_requests) + heapq.heappush( + waiting_queue, + (priority_val, arrival_time_val, request)) + else: + # For FCFS, push back to the front of the deque. + cast(deque[Request], self.waiting).appendleft(request) break # KVConnector: update internal state after allocation. @@ -432,17 +516,27 @@ def schedule(self) -> SchedulerOutput: num_external_computed_tokens, ) - self.waiting.popleft() + # Request was already popped from self.waiting + # (either via heapq.heappop or self.waiting.popleft()) + # unless it was re-added above due to new_blocks being None. if load_kv_async: # If loading async, allocate memory and put request # into the WAITING_FOR_REMOTE_KV state. - skipped_waiting_requests.appendleft(request) + if self.policy == "priority": + waiting_queue = cast(list[tuple[int, float, Request]], + skipped_waiting_requests) + heapq.heappush( + waiting_queue, + (priority_val, arrival_time_val, request)) + else: + cast(deque[Request], + skipped_waiting_requests).appendleft(request) request.status = RequestStatus.WAITING_FOR_REMOTE_KVS continue if request.use_structured_output: - structured_output_request_ids[ - request.request_id] = req_index + structured_output_request_ids[request.request_id] = ( + req_index) req_index += 1 self.running.append(request) if self.log_stats: @@ -478,7 +572,17 @@ def schedule(self) -> SchedulerOutput: # Put back any skipped requests at the head of the waiting queue if skipped_waiting_requests: - self.waiting.extendleft(skipped_waiting_requests) + if self.policy == "priority": + waiting_queue = cast(list[tuple[int, float, Request]], + self.waiting) + skipped_queue = cast(list[tuple[int, float, Request]], + skipped_waiting_requests) + for item in skipped_queue: + heapq.heappush(waiting_queue, item) + else: # FCFS + waiting_deque = cast(deque[Request], self.waiting) + skipped_deque = cast(deque[Request], skipped_waiting_requests) + waiting_deque.extendleft(skipped_deque) # Check if the scheduling constraints are satisfied. total_num_scheduled_tokens = sum(num_scheduled_tokens.values()) @@ -488,8 +592,8 @@ def schedule(self) -> SchedulerOutput: # Since some requests in the RUNNING queue may not be scheduled in # this step, the total number of scheduled requests can be smaller than # len(self.running). - assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + - len(scheduled_running_reqs) <= len(self.running)) + assert len(scheduled_new_reqs) + len(scheduled_resumed_reqs) + len( + scheduled_running_reqs) <= len(self.running) # Get the longest common prefix among all requests in the running queue. # This can be potentially used for cascade attention. 
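The heap-based waiting queue introduced above stores `(priority, arrival_time, request)` tuples, so `heapq.heappop` always yields the lowest priority value first and falls back to arrival time on ties. A minimal, self-contained sketch of that ordering (plain request-id strings stand in for `Request` objects; the values are illustrative only):

import heapq

# Each waiting entry is (priority, arrival_time, request): lower priority
# values are popped first, and arrival_time breaks ties, which preserves
# FCFS ordering within a priority level.
waiting: list[tuple[int, float, str]] = []
heapq.heappush(waiting, (2, 1.0, "req-a"))
heapq.heappush(waiting, (0, 3.0, "req-b"))
heapq.heappush(waiting, (1, 2.0, "req-c"))
heapq.heappush(waiting, (1, 1.5, "req-d"))

while waiting:
    priority, arrival_time, req_id = heapq.heappop(waiting)
    print(priority, arrival_time, req_id)
# 0 3.0 req-b
# 1 1.5 req-d
# 1 2.0 req-c
# 2 1.0 req-a

Because tuples compare element-wise, equal-priority entries are ordered purely by arrival time, which is exactly the FCFS tie-breaker behavior the tests above assert.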
@@ -734,7 +838,8 @@ def update_from_output( spec_decoding_stats = self.make_spec_decoding_stats( spec_decoding_stats, num_draft_tokens=len(scheduled_spec_token_ids), - num_accepted_tokens=len(generated_token_ids) - 1) + num_accepted_tokens=len(generated_token_ids) - 1, + ) cached_encoder_input_ids = ( self.encoder_cache_manager.get_cached_input_ids(request)) @@ -862,7 +967,13 @@ def get_request_counts(self) -> tuple[int, int]: return len(self.running), len(self.waiting) def add_request(self, request: Request) -> None: - self.waiting.append(request) + if self.policy == "priority": + heapq.heappush( + cast(list[tuple[int, float, Request]], self.waiting), + (request.priority, request.arrival_time, request), + ) + else: + cast(deque[Request], self.waiting).append(request) self.requests[request.request_id] = request if self.log_stats: request.record_event(EngineCoreEventType.QUEUED) @@ -892,7 +1003,7 @@ def finish_requests( if request.status == RequestStatus.RUNNING: self.running.remove(request) else: - self.waiting.remove(request) + cast(deque[Request], self.waiting).remove(request) request.status = finished_status self._free_request(request) @@ -921,7 +1032,7 @@ def _free_blocks(self, request: Request): del self.requests[request.request_id] def get_num_unfinished_requests(self) -> int: - return len(self.waiting) + len(self.running) + return len(cast(deque[Request], self.waiting)) + len(self.running) def has_finished_requests(self) -> bool: return len(self.finished_req_ids) > 0 @@ -939,7 +1050,7 @@ def make_stats( assert prefix_cache_stats is not None return SchedulerStats( num_running_reqs=len(self.running), - num_waiting_reqs=len(self.waiting), + num_waiting_reqs=len(cast(deque[Request], self.waiting)), gpu_cache_usage=self.kv_cache_manager.usage, prefix_cache_stats=prefix_cache_stats, spec_decoding_stats=spec_decoding_stats, @@ -957,7 +1068,8 @@ def make_spec_decoding_stats( spec_decoding_stats = SpecDecodingStats.new(self.num_spec_tokens) spec_decoding_stats.observe_draft( num_draft_tokens=num_draft_tokens, - num_accepted_tokens=num_accepted_tokens) + num_accepted_tokens=num_accepted_tokens, + ) return spec_decoding_stats def shutdown(self) -> None: @@ -981,8 +1093,8 @@ def _connector_finished( """ if self.connector is None: return False, None - assert len(self.kv_cache_config.kv_cache_groups - ) == 1, "KV connector only supports one KV cache group now" + assert (len(self.kv_cache_config.kv_cache_groups) == 1 + ), "KV connector only supports one KV cache group now" block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] return self.connector.request_finished(request, block_ids) @@ -1000,8 +1112,8 @@ def _update_waiting_for_remote_kv(self, request: Request) -> bool: """ if request.request_id not in self.finished_recving_kv_req_ids: return False - assert len(self.kv_cache_config.kv_cache_groups - ) == 1, "KV connector only supports one KV cache group now" + assert (len(self.kv_cache_config.kv_cache_groups) == 1 + ), "KV connector only supports one KV cache group now" # Now that the blocks are ready, actually cache them. block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] num_computed_tokens = len(block_ids) * self.block_size @@ -1032,9 +1144,9 @@ def _update_from_kv_xfer_finished(self, scheduler the request during the next step. """ # P/D: update recv and send status from last step. 
- for req_id in (model_runner_output.finished_recving or ()): + for req_id in model_runner_output.finished_recving or (): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id) - for req_id in (model_runner_output.finished_sending or ()): + for req_id in model_runner_output.finished_sending or (): logger.debug("Finished sending KV transfer for request %s", req_id) self._free_blocks(self.requests[req_id]) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index 0c9f61a76427..46710e98db97 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -63,6 +63,7 @@ class EngineCoreRequest( # belong to, to cover a race condition where the request is sent before # a wave finished notification is received. current_wave: int = 0 + priority: int = 0 class EngineCoreEventType(enum.IntEnum): diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 64a756148780..15d9f5e917dd 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -217,8 +217,6 @@ def process_inputs( # TODO(woosuk): Support encoder-decoder models. self._validate_lora(lora_request) self._validate_params(params, lora_request) - if priority != 0: - raise ValueError("V1 does not support priority yet.") if trace_headers is not None: raise ValueError("V1 does not support tracing yet.") if prompt_adapter_request is not None: @@ -327,6 +325,7 @@ def process_inputs( arrival_time=arrival_time, lora_request=lora_request, cache_salt=decoder_inputs.get("cache_salt"), + priority=priority, ) def _validate_model_inputs(self, diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 42c75ef96401..350826386c0a 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 import enum +import time from typing import TYPE_CHECKING, Any, Optional, Union from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange @@ -27,17 +28,22 @@ def __init__( sampling_params: SamplingParams, eos_token_id: Optional[int], client_index: int = 0, + arrival_time: Optional[float] = None, lora_request: Optional["LoRARequest"] = None, structured_output_request: Optional["StructuredOutputRequest"] = None, cache_salt: Optional[str] = None, + priority: int = 0, ) -> None: self.request_id = request_id self.client_index = client_index + self.priority = priority self.sampling_params = sampling_params # Because of LoRA, the eos token id can be different for each request. 
self.eos_token_id = eos_token_id self.lora_request = lora_request self.structured_output_request = structured_output_request + self.arrival_time = arrival_time if arrival_time is not None else \ + time.time() self.status = (RequestStatus.WAITING_FOR_FSM if sampling_params.guided_decoding is not None else @@ -91,17 +97,18 @@ def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request": return cls( request_id=request.request_id, - client_index=request.client_index, prompt_token_ids=request.prompt_token_ids, multi_modal_inputs=request.mm_inputs, multi_modal_hashes=request.mm_hashes, multi_modal_placeholders=request.mm_placeholders, sampling_params=request.sampling_params, eos_token_id=request.eos_token_id, + arrival_time=request.arrival_time, lora_request=request.lora_request, structured_output_request=StructuredOutputRequest( sampling_params=request.sampling_params), cache_salt=request.cache_salt, + priority=request.priority, ) def append_output_token_ids( From 4b2e513da7e23f1030adec07e78f7a1e599a4f10 Mon Sep 17 00:00:00 2001 From: amit Date: Tue, 3 Jun 2025 18:35:24 +0300 Subject: [PATCH 02/33] style(docs): split long line and wrap paragraph in note admonition Signed-off-by: amit --- docs/usage/v1_guide.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index b5a35c04a287..af0ef74026e2 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -72,7 +72,13 @@ This living user guide outlines a few known **important changes and limitations* way by using a simple dictionary (e.g., `{request_id: num_tokens}`) to dynamically allocate a fixed token budget per request, enabling features like chunked prefills, prefix caching, and speculative decoding without a strict separation between prefill -and decode phases. The V1 scheduler supports multiple scheduling policies, including First-Come, First-Served (FCFS) and priority-based scheduling (where requests are processed based on assigned priority, with FCFS as a tie-breaker), configurable via the `--scheduling-policy` argument. +and decode phases. + +!!! note + The V1 scheduler supports multiple scheduling policies, including First-Come, + First-Served (FCFS) and priority-based scheduling (where requests are processed + based on assigned priority, with FCFS as a tie-breaker), configurable via the + `--scheduling-policy` argument. ### Semantic Changes and Deprecated Features From 419de95f761068c34277f6cf54917e151326ee76 Mon Sep 17 00:00:00 2001 From: amit Date: Tue, 3 Jun 2025 18:54:04 +0300 Subject: [PATCH 03/33] style(docs): fix pymarkdown error Signed-off-by: amit --- docs/usage/v1_guide.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docs/usage/v1_guide.md b/docs/usage/v1_guide.md index af0ef74026e2..9992884a6ab1 100644 --- a/docs/usage/v1_guide.md +++ b/docs/usage/v1_guide.md @@ -68,17 +68,17 @@ This living user guide outlines a few known **important changes and limitations* - **🟠 Delayed**: Temporarily dropped in V1 but planned to be re-introduced later. - **🔴 Deprecated**: Not planned for V1 unless there is strong demand. -**Note**: vLLM V1’s unified scheduler treats both prompt and output tokens the same -way by using a simple dictionary (e.g., `{request_id: num_tokens}`) to dynamically -allocate a fixed token budget per request, enabling features like chunked prefills, -prefix caching, and speculative decoding without a strict separation between prefill -and decode phases. - !!! 
note - The V1 scheduler supports multiple scheduling policies, including First-Come, - First-Served (FCFS) and priority-based scheduling (where requests are processed - based on assigned priority, with FCFS as a tie-breaker), configurable via the - `--scheduling-policy` argument. + vLLM V1’s unified scheduler treats both prompt and output tokens the same + way by using a simple dictionary (e.g., `{request_id: num_tokens}`) to dynamically + allocate a fixed token budget per request, enabling features like chunked prefills, + prefix caching, and speculative decoding without a strict separation between prefill + and decode phases. + +The V1 scheduler supports multiple scheduling policies, including First-Come, +First-Served (FCFS) and priority-based scheduling (where requests are processed +based on assigned priority, with FCFS as a tie-breaker), configurable via the +`--scheduling-policy` argument. ### Semantic Changes and Deprecated Features From de76b7e861f1c865e5ab1620c416f66938235675 Mon Sep 17 00:00:00 2001 From: amit Date: Sun, 8 Jun 2025 10:00:10 +0300 Subject: [PATCH 04/33] Merge remote-tracking branch 'upstream/main' into v1-priority-schedular Signed-off-by: amit --- .buildkite/check-wheel-size.py | 1 + .buildkite/generate_index.py | 1 + .buildkite/lm-eval-harness/conftest.py | 1 + .../test_lm_eval_correctness.py | 1 + .../convert-results-json-to-markdown.py | 1 + .../scripts/download-tokenizer.py | 1 + .../scripts/generate-nightly-markdown.py | 1 + .../scripts/get-lmdeploy-modelname.py | 1 + .../scripts/summary-nightly-results.py | 1 + .buildkite/release-pipeline.yaml | 19 +- .buildkite/scripts/annotate-release.sh | 31 + .../hardware_ci/run-cpu-test-ppc64le.sh | 3 +- .../scripts/hardware_ci/run-cpu-test.sh | 13 +- .../scripts/hardware_ci/run-tpu-v1-test.sh | 2 +- .buildkite/scripts/tpu/cleanup_docker.sh | 24 + .buildkite/scripts/tpu/config_v6e_1.env | 14 + .buildkite/scripts/tpu/docker_run_bm.sh | 102 ++++ .buildkite/scripts/tpu/run_bm.sh | 94 +++ .buildkite/test-pipeline.yaml | 6 + .github/CODEOWNERS | 16 +- .github/PULL_REQUEST_TEMPLATE.md | 13 +- .pre-commit-config.yaml | 2 + CMakeLists.txt | 5 +- README.md | 10 +- benchmarks/auto_tune.sh | 164 +++-- benchmarks/backend_request_func.py | 1 + benchmarks/benchmark_dataset.py | 11 +- benchmarks/benchmark_latency.py | 1 + .../benchmark_long_document_qa_throughput.py | 1 + benchmarks/benchmark_prefix_caching.py | 1 + benchmarks/benchmark_prioritization.py | 1 + benchmarks/benchmark_serving.py | 1 + .../benchmark_serving_structured_output.py | 2 +- benchmarks/benchmark_throughput.py | 1 + benchmarks/benchmark_utils.py | 8 +- .../cutlass_benchmarks/sparse_benchmarks.py | 1 + benchmarks/cutlass_benchmarks/utils.py | 1 + .../cutlass_benchmarks/w8a8_benchmarks.py | 1 + .../cutlass_benchmarks/weight_shapes.py | 1 + .../disagg_prefill_proxy_server.py | 1 + .../disagg_benchmarks/round_robin_proxy.py | 1 + .../visualize_benchmark_results.py | 1 + .../fused_kernels/layernorm_rms_benchmarks.py | 1 + benchmarks/kernels/bench_fp8_gemm.py | 3 +- benchmarks/kernels/benchmark_aqlm.py | 1 + benchmarks/kernels/benchmark_bitblas.py | 1 + .../kernels/benchmark_cutlass_fp4_moe.py | 3 +- .../kernels/benchmark_grouped_gemm_cutlass.py | 62 +- benchmarks/kernels/benchmark_layernorm.py | 1 + benchmarks/kernels/benchmark_lora.py | 1 + benchmarks/kernels/benchmark_machete.py | 1 + benchmarks/kernels/benchmark_marlin.py | 1 + benchmarks/kernels/benchmark_moe.py | 1 + .../benchmark_moe_permute_unpermute.py | 1 + .../kernels/benchmark_paged_attention.py | 1 + 
benchmarks/kernels/benchmark_quant.py | 1 + benchmarks/kernels/benchmark_rmsnorm.py | 1 + benchmarks/kernels/benchmark_rope.py | 1 + benchmarks/kernels/benchmark_shapes.py | 1 + .../kernels/benchmark_w8a8_block_fp8.py | 1 + .../benchmark_fp8_block_dense_gemm.py | 1 + benchmarks/kernels/graph_machete_bench.py | 1 + benchmarks/kernels/utils.py | 1 + benchmarks/kernels/weight_shapes.py | 1 + benchmarks/overheads/benchmark_hashing.py | 1 + cmake/hipify.py | 1 + csrc/attention/mla/cutlass_mla_kernels.cu | 2 +- .../vllm_cutlass_library_extension.py | 1 + csrc/moe/marlin_moe_wna16/generate_kernels.py | 1 + csrc/moe/moe_ops.h | 6 +- csrc/moe/moe_permute_unpermute_op.cu | 56 ++ csrc/moe/permute_unpermute_kernels/dispatch.h | 18 +- csrc/moe/topk_softmax_kernels.cu | 16 +- csrc/moe/torch_bindings.cpp | 6 + csrc/ops.h | 19 +- .../c3x/scaled_mm_blockwise_sm100_fp8.cu | 4 - ...scaled_mm_blockwise_sm100_fp8_dispatch.cuh | 206 +++++-- .../c3x/scaled_mm_sm100_fp8_dispatch.cuh | 53 +- .../cutlass_w8a8/moe/grouped_mm_c3x.cu | 29 +- .../cutlass_w8a8/moe/grouped_mm_c3x.cuh | 6 +- .../quantization/cutlass_w8a8/moe/moe_data.cu | 78 ++- .../cutlass_w8a8/scaled_mm_entry.cu | 48 +- csrc/quantization/fp8/common.cu | 35 +- csrc/quantization/fp8/common.cuh | 68 ++- .../fused_kernels/layernorm_utils.cuh | 99 +-- .../gptq_marlin/generate_kernels.py | 1 + csrc/quantization/machete/generate.py | 1 + csrc/quantization/vectorization.cuh | 23 +- csrc/sampler.cu | 86 +++ csrc/torch_bindings.cpp | 28 +- docker/Dockerfile.nightly_torch | 3 + docker/Dockerfile.ppc64le | 96 ++- docker/Dockerfile.rocm | 5 +- docs/cli/README.md | 2 + docs/contributing/ci-failures.md | 120 ++++ docs/deployment/docker.md | 4 +- docs/features/tool_calling.md | 6 +- docs/mkdocs/hooks/generate_examples.py | 1 + docs/mkdocs/hooks/remove_announcement.py | 1 + docs/mkdocs/hooks/url_schemes.py | 1 + docs/models/extensions/tensorizer.md | 2 +- docs/models/supported_models.md | 1 + docs/usage/v1_guide.md | 8 +- examples/offline_inference/audio_language.py | 1 + .../automatic_prefix_caching.py | 1 + examples/offline_inference/basic/basic.py | 1 + examples/offline_inference/basic/chat.py | 1 + examples/offline_inference/basic/classify.py | 1 + examples/offline_inference/basic/embed.py | 1 + examples/offline_inference/basic/generate.py | 1 + examples/offline_inference/basic/score.py | 1 + .../offline_inference/batch_llm_inference.py | 1 + examples/offline_inference/chat_with_tools.py | 1 + .../offline_inference/context_extension.py | 82 ++- examples/offline_inference/data_parallel.py | 1 + .../decode_example.py | 1 + .../prefill_example.py | 1 + .../disaggregated_prefill.py | 1 + examples/offline_inference/eagle.py | 1 + .../embed_jina_embeddings_v3.py | 1 + .../offline_inference/embed_matryoshka_fy.py | 1 + examples/offline_inference/encoder_decoder.py | 1 + .../encoder_decoder_multimodal.py | 1 + .../offline_inference/llm_engine_example.py | 1 + .../offline_inference/load_sharded_state.py | 1 + .../lora_with_quantization_inference.py | 1 + examples/offline_inference/metrics.py | 1 + examples/offline_inference/mistral-small.py | 1 + examples/offline_inference/mlpspeculator.py | 1 + .../offline_inference/multilora_inference.py | 1 + examples/offline_inference/neuron.py | 1 + examples/offline_inference/neuron_eagle.py | 1 + .../neuron_int8_quantization.py | 1 + .../offline_inference/neuron_multimodal.py | 1 + .../offline_inference/neuron_speculation.py | 1 + examples/offline_inference/prefix_caching.py | 1 + .../prithvi_geospatial_mae.py | 1 + 
examples/offline_inference/profiling.py | 1 + .../profiling_tpu/profiling.py | 1 + .../prompt_embed_inference.py | 1 + .../qwen2_5_omni/only_thinker.py | 1 + examples/offline_inference/qwen_1m.py | 1 + examples/offline_inference/reproducibility.py | 1 + examples/offline_inference/rlhf.py | 1 + examples/offline_inference/rlhf_colocate.py | 1 + examples/offline_inference/rlhf_utils.py | 1 + .../offline_inference/save_sharded_state.py | 1 + .../offline_inference/simple_profiling.py | 1 + .../offline_inference/structured_outputs.py | 1 + .../offline_inference/torchrun_example.py | 1 + examples/offline_inference/tpu.py | 1 + examples/offline_inference/vision_language.py | 1 + .../vision_language_embedding.py | 1 + .../vision_language_multi_image.py | 37 +- examples/online_serving/api_client.py | 1 + .../online_serving/cohere_rerank_client.py | 1 + .../disagg_proxy_demo.py | 1 + .../gradio_openai_chatbot_webserver.py | 1 + examples/online_serving/gradio_webserver.py | 1 + .../online_serving/jinaai_rerank_client.py | 1 + .../online_serving/kv_events_subscriber.py | 1 + .../multi_instance_data_parallel.py | 58 ++ .../openai_chat_completion_client.py | 1 + ...i_chat_completion_client_for_multimodal.py | 1 + ...penai_chat_completion_client_with_tools.py | 1 + ...t_completion_client_with_tools_required.py | 1 + ...enai_chat_completion_structured_outputs.py | 2 +- ...etion_structured_outputs_structural_tag.py | 1 + ...etion_structured_outputs_with_reasoning.py | 1 + ...at_completion_tool_calls_with_reasoning.py | 1 + .../openai_chat_completion_with_reasoning.py | 1 + ...hat_completion_with_reasoning_streaming.py | 1 + ...ai_chat_embedding_client_for_multimodal.py | 1 + .../openai_classification_client.py | 1 + .../openai_completion_client.py | 1 + .../openai_cross_encoder_score.py | 1 + .../online_serving/openai_embedding_client.py | 1 + .../openai_embedding_matryoshka_fy.py | 1 + .../online_serving/openai_pooling_client.py | 1 + .../openai_transcription_client.py | 1 + .../opentelemetry/dummy_client.py | 1 + ...ompt_embed_inference_with_openai_client.py | 1 + examples/online_serving/ray_serve_deepseek.py | 1 + ...val_augmented_generation_with_langchain.py | 1 + ...al_augmented_generation_with_llamaindex.py | 1 + .../streamlit_openai_chatbot_webserver.py | 1 + examples/online_serving/utils.py | 1 + .../others/lmcache/cpu_offload_lmcache.py | 1 + .../lmcache/disagg_prefill_lmcache_v0.py | 1 + .../disagg_example_nixl.sh | 6 +- .../disagg_proxy_server.py | 1 + .../lmcache/kv_cache_sharing_lmcache_v1.py | 1 + examples/others/tensorize_vllm_model.py | 1 + examples/tool_chat_template_deepseekr1.jinja | 92 +++ find_cuda_init.py | 1 + requirements/common.txt | 2 +- requirements/cpu.txt | 3 + requirements/nightly_torch_test.txt | 9 +- requirements/rocm.txt | 3 +- requirements/test.in | 2 +- requirements/test.txt | 2 +- requirements/tpu.txt | 10 +- setup.py | 2 + tests/async_engine/api_server_async_engine.py | 1 + tests/async_engine/conftest.py | 1 + tests/async_engine/test_api_server.py | 1 + tests/async_engine/test_async_llm_engine.py | 23 + tests/async_engine/test_request_tracker.py | 1 + .../test_basic_correctness.py | 29 +- .../basic_correctness/test_chunked_prefill.py | 1 + tests/basic_correctness/test_cpu_offload.py | 1 + tests/basic_correctness/test_cumem.py | 1 + tests/basic_correctness/test_preemption.py | 1 + tests/benchmarks/test_latency_cli.py | 1 + tests/benchmarks/test_serve_cli.py | 1 + tests/benchmarks/test_throughput_cli.py | 1 + tests/build_cython.py | 1 + tests/compile/backend.py | 1 + 
tests/compile/conftest.py | 1 + .../compile/piecewise/test_full_cudagraph.py | 8 +- tests/compile/piecewise/test_simple.py | 3 +- tests/compile/piecewise/test_toy_llama.py | 7 +- tests/compile/test_async_tp.py | 1 + tests/compile/test_basic_correctness.py | 1 + tests/compile/test_config.py | 43 ++ tests/compile/test_full_graph.py | 1 + tests/compile/test_functionalization.py | 1 + tests/compile/test_fusion.py | 1 + tests/compile/test_pass_manager.py | 1 + tests/compile/test_sequence_parallelism.py | 1 + tests/compile/test_silu_mul_quant_fusion.py | 1 + tests/compile/test_wrapper.py | 1 + tests/conftest.py | 1 + tests/core/block/conftest.py | 1 + tests/core/block/e2e/conftest.py | 1 + tests/core/block/e2e/test_correctness.py | 1 + .../e2e/test_correctness_sliding_window.py | 1 + tests/core/block/test_block_manager.py | 1 + tests/core/block/test_block_table.py | 1 + tests/core/block/test_common.py | 1 + .../block/test_cpu_gpu_block_allocator.py | 1 + tests/core/block/test_naive_block.py | 1 + tests/core/block/test_prefix_caching_block.py | 1 + tests/core/conftest.py | 1 + tests/core/test_chunked_prefill_scheduler.py | 1 + tests/core/test_num_computed_tokens_update.py | 1 + tests/core/test_scheduler.py | 1 + tests/core/test_scheduler_encoder_decoder.py | 1 + tests/core/test_serialization.py | 1 + tests/core/utils.py | 1 + tests/detokenizer/conftest.py | 1 + .../test_disable_detokenization.py | 1 + tests/detokenizer/test_stop_checker.py | 1 + tests/detokenizer/test_stop_reason.py | 1 + tests/detokenizer/test_stop_strings.py | 1 + tests/distributed/conftest.py | 108 ++-- tests/distributed/test_ca_buffer_sharing.py | 1 + tests/distributed/test_comm_ops.py | 1 + tests/distributed/test_custom_all_reduce.py | 1 + tests/distributed/test_distributed_oot.py | 1 + tests/distributed/test_events.py | 70 ++- tests/distributed/test_expert_parallel.py | 1 + .../distributed/test_multi_node_assignment.py | 1 + tests/distributed/test_pipeline_parallel.py | 1 + tests/distributed/test_pipeline_partition.py | 1 + tests/distributed/test_pp_cudagraph.py | 1 + tests/distributed/test_pynccl.py | 1 + tests/distributed/test_same_node.py | 1 + tests/distributed/test_sequence_parallel.py | 1 + tests/distributed/test_shm_broadcast.py | 1 + tests/distributed/test_torchrun_example.py | 1 + tests/distributed/test_utils.py | 1 + tests/encoder_decoder/test_e2e_correctness.py | 1 + tests/engine/conftest.py | 1 + tests/engine/test_arg_utils.py | 19 +- tests/engine/test_computed_prefix_blocks.py | 1 + tests/engine/test_executor.py | 1 + .../test_multi_step_output_processor.py | 1 + tests/engine/test_multiproc_workers.py | 1 + tests/engine/test_options.py | 1 + tests/engine/test_short_mm_context.py | 1 + tests/entrypoints/conftest.py | 1 + tests/entrypoints/llm/test_accuracy.py | 1 + tests/entrypoints/llm/test_chat.py | 1 + tests/entrypoints/llm/test_collective_rpc.py | 1 + tests/entrypoints/llm/test_encode.py | 1 + tests/entrypoints/llm/test_generate.py | 1 + .../llm/test_generate_multiple_loras.py | 1 + tests/entrypoints/llm/test_gpu_utilization.py | 1 + tests/entrypoints/llm/test_guided_generate.py | 1 + tests/entrypoints/llm/test_lazy_outlines.py | 1 + .../entrypoints/llm/test_prompt_validation.py | 1 + .../offline_mode/test_offline_mode.py | 1 + .../openai/correctness/test_lmeval.py | 1 + .../openai/correctness/test_mteb.py | 1 + .../test_transcription_api_correctness.py | 1 + .../openai/test_async_tokenization.py | 1 + tests/entrypoints/openai/test_audio.py | 1 + tests/entrypoints/openai/test_basic.py | 1 + 
tests/entrypoints/openai/test_chat.py | 1 + tests/entrypoints/openai/test_chat_echo.py | 1 + .../openai/test_chat_logit_bias_validation.py | 1 + .../entrypoints/openai/test_chat_template.py | 1 + .../openai/test_chat_with_tool_reasoning.py | 1 + .../entrypoints/openai/test_chunked_prompt.py | 1 + .../entrypoints/openai/test_classification.py | 1 + tests/entrypoints/openai/test_cli_args.py | 1 + tests/entrypoints/openai/test_completion.py | 1 + .../test_completion_with_function_calling.py | 76 ++- .../test_completion_with_prompt_embeds.py | 1 + tests/entrypoints/openai/test_embedding.py | 1 + .../openai/test_embedding_dimensions.py | 1 + .../openai/test_encoder_decoder.py | 1 + .../entrypoints/openai/test_lora_adapters.py | 1 + .../entrypoints/openai/test_lora_resolvers.py | 1 + tests/entrypoints/openai/test_metrics.py | 1 + tests/entrypoints/openai/test_models.py | 1 + .../openai/test_oot_registration.py | 1 + .../entrypoints/openai/test_openai_schema.py | 1 + tests/entrypoints/openai/test_pooling.py | 1 + .../openai/test_prompt_validation.py | 1 + tests/entrypoints/openai/test_rerank.py | 1 + .../openai/test_return_tokens_as_ids.py | 1 + tests/entrypoints/openai/test_root_path.py | 1 + tests/entrypoints/openai/test_run_batch.py | 1 + tests/entrypoints/openai/test_score.py | 1 + tests/entrypoints/openai/test_serving_chat.py | 1 + .../entrypoints/openai/test_serving_models.py | 1 + tests/entrypoints/openai/test_shutdown.py | 1 + tests/entrypoints/openai/test_sleep.py | 1 + .../openai/test_tensorizer_entrypoint.py | 1 + tests/entrypoints/openai/test_tokenization.py | 1 + .../openai/test_transcription_validation.py | 1 + tests/entrypoints/openai/test_truncation.py | 1 + tests/entrypoints/openai/test_video.py | 1 + tests/entrypoints/openai/test_vision.py | 1 + .../openai/test_vision_embedding.py | 1 + .../test_llama4_pythonic_tool_parser.py | 1 + .../tool_parsers/test_pythonic_tool_parser.py | 1 + .../entrypoints/openai/tool_parsers/utils.py | 1 + .../test_api_server_process_manager.py | 1 + tests/entrypoints/test_chat_utils.py | 1 + tests/entrypoints/test_ssl_cert_refresher.py | 1 + .../test_fastsafetensors_loader.py | 1 + .../test_weight_utils.py | 1 + tests/kernels/allclose_default.py | 1 + tests/kernels/attention/conftest.py | 1 + tests/kernels/attention/test_attention.py | 1 + .../attention/test_attention_selector.py | 6 +- .../attention/test_blocksparse_attention.py | 1 + tests/kernels/attention/test_cache.py | 1 + .../attention/test_cascade_flash_attn.py | 1 + .../attention/test_encoder_decoder_attn.py | 1 + tests/kernels/attention/test_flash_attn.py | 1 + tests/kernels/attention/test_flashinfer.py | 1 + tests/kernels/attention/test_flashmla.py | 1 + .../kernels/attention/test_lightning_attn.py | 1 + .../attention/test_merge_attn_states.py | 1 + tests/kernels/attention/test_mha_attn.py | 1 + .../kernels/attention/test_mla_decode_cpu.py | 1 + .../kernels/attention/test_prefix_prefill.py | 1 + .../attention/test_rocm_attention_selector.py | 1 + .../attention/test_triton_decode_attention.py | 1 + .../test_triton_unified_attention.py | 1 + tests/kernels/core/test_activation.py | 1 + .../core/test_fused_quant_layernorm.py | 1 + tests/kernels/core/test_layernorm.py | 1 + tests/kernels/core/test_opcheck.py | 1 + tests/kernels/core/test_permute_cols.py | 1 + tests/kernels/core/test_pos_encoding.py | 1 + tests/kernels/core/test_rotary_embedding.py | 1 + tests/kernels/core/test_uva.py | 1 + tests/kernels/mamba/test_causal_conv1d.py | 1 + tests/kernels/mamba/test_mamba_mixer2.py | 1 + 
tests/kernels/mamba/test_mamba_ssm.py | 1 + tests/kernels/mamba/test_mamba_ssm_ssd.py | 1 + tests/kernels/moe/__init__.py | 0 tests/kernels/moe/deepep_utils.py | 191 ++++++ tests/kernels/moe/test_batched_moe.py | 1 + tests/kernels/moe/test_cutlass_moe.py | 9 +- tests/kernels/moe/test_deepep_deepgemm_moe.py | 513 ++++++++++++++++ tests/kernels/moe/test_deepep_moe.py | 459 ++++++++++++++ tests/kernels/moe/test_moe.py | 1 + .../kernels/moe/test_moe_permute_unpermute.py | 1 + tests/kernels/moe/test_nvfp4_moe.py | 6 +- tests/kernels/moe/test_pplx_cutlass_moe.py | 287 +++++++++ tests/kernels/moe/test_pplx_moe.py | 132 +--- tests/kernels/moe/test_rocm_aiter_topk.py | 1 + tests/kernels/moe/test_triton_moe_ptpc_fp8.py | 1 + tests/kernels/quant_utils.py | 1 + tests/kernels/quantization/nvfp4_utils.py | 1 + .../quantization/test_allspark_gemm.py | 1 + tests/kernels/quantization/test_aqlm.py | 1 + tests/kernels/quantization/test_awq.py | 1 + tests/kernels/quantization/test_awq_triton.py | 1 + tests/kernels/quantization/test_block_fp8.py | 1 + tests/kernels/quantization/test_block_int8.py | 1 + .../quantization/test_cutlass_2of4_sparse.py | 1 + .../quantization/test_cutlass_scaled_mm.py | 4 +- tests/kernels/quantization/test_fp8_quant.py | 1 + tests/kernels/quantization/test_ggml.py | 1 + tests/kernels/quantization/test_gguf.py | 1 + tests/kernels/quantization/test_gptq.py | 1 + .../kernels/quantization/test_int8_kernel.py | 1 + tests/kernels/quantization/test_int8_quant.py | 1 + tests/kernels/quantization/test_machete_mm.py | 1 + .../kernels/quantization/test_marlin_gemm.py | 1 + .../kernels/quantization/test_nvfp4_quant.py | 1 + .../quantization/test_nvfp4_scaled_mm.py | 1 + .../quantization/test_rocm_skinny_gemms.py | 1 + .../quantization/test_triton_scaled_mm.py | 1 + .../test_apply_repetition_penalties.py | 76 +++ tests/kernels/test_cutlass_mla_decode.py | 5 +- tests/kernels/test_flex_attention.py | 93 +++ tests/kernels/test_fused_quant_activation.py | 1 + tests/kernels/test_triton_flash_attention.py | 1 + tests/kernels/utils.py | 1 + tests/kv_transfer/test_disagg.py | 1 + tests/kv_transfer/test_lookup_buffer.py | 1 + tests/kv_transfer/test_module.py | 1 + tests/kv_transfer/test_send_recv.py | 1 + tests/lora/conftest.py | 1 + tests/lora/test_add_lora.py | 1 + tests/lora/test_baichuan.py | 1 + tests/lora/test_chatglm3_tp.py | 1 + tests/lora/test_layers.py | 1 + tests/lora/test_llama_tp.py | 1 + tests/lora/test_lora_allowed_token_ids.py | 1 + tests/lora/test_lora_checkpoints.py | 1 + tests/lora/test_lora_functions.py | 1 + tests/lora/test_lora_huggingface.py | 1 + tests/lora/test_lora_manager.py | 1 + tests/lora/test_minicpmv_tp.py | 1 + tests/lora/test_mixtral.py | 1 + tests/lora/test_peft_helper.py | 1 + tests/lora/test_phi.py | 1 + tests/lora/test_punica_ops.py | 1 + tests/lora/test_quant_model.py | 1 + tests/lora/test_qwen2vl.py | 1 + tests/lora/test_resolver.py | 1 + tests/lora/test_tokenizer_group.py | 1 + tests/lora/test_transfomers_model.py | 1 + tests/lora/test_utils.py | 1 + tests/lora/test_worker.py | 1 + tests/lora/utils.py | 1 + tests/metrics/test_metrics.py | 1 + tests/mistral_tool_use/conftest.py | 1 + .../test_mistral_tool_calls.py | 1 + tests/mistral_tool_use/utils.py | 1 + tests/model_executor/conftest.py | 1 + .../model_executor/test_enabled_custom_ops.py | 1 + .../model_executor/test_guided_processors.py | 1 + tests/model_executor/test_logits_processor.py | 1 + .../test_model_load_with_params.py | 1 + tests/model_executor/test_weight_utils.py | 1 + 
tests/models/language/generation/test_bart.py | 1 + .../models/language/generation/test_common.py | 2 +- .../language/generation/test_granite.py | 1 + .../generation/test_granitemoehybrid.py | 1 + .../models/language/generation/test_hybrid.py | 1 + .../language/generation/test_mistral.py | 1 + .../models/language/generation/test_phimoe.py | 1 + tests/models/language/pooling/embed_utils.py | 7 +- tests/models/language/pooling/mteb_utils.py | 13 +- tests/models/language/pooling/test_baai.py | 1 + .../language/pooling/test_classification.py | 1 + .../models/language/pooling/test_embedding.py | 1 + tests/models/language/pooling/test_gritlm.py | 1 + tests/models/language/pooling/test_gte.py | 8 +- .../models/language/pooling/test_intfloat.py | 46 ++ tests/models/language/pooling/test_jina.py | 4 +- tests/models/language/pooling/test_nomic.py | 4 +- .../pooling/test_nomic_max_model_len.py | 1 + tests/models/language/pooling/test_scoring.py | 1 + .../pooling/test_snowflake_arctic_embed.py | 1 + .../pooling/test_truncation_control.py | 1 + .../multimodal/generation/test_common.py | 10 +- .../multimodal/generation/test_florence2.py | 3 + .../generation/test_granite_speech.py | 3 +- .../multimodal/generation/test_interleaved.py | 1 + .../multimodal/generation/test_mllama.py | 1 + .../multimodal/generation/test_phi4mm.py | 5 + .../multimodal/generation/test_pixtral.py | 1 + .../multimodal/generation/test_qwen2_vl.py | 1 + .../multimodal/generation/test_ultravox.py | 1 + .../multimodal/generation/test_whisper.py | 1 + .../generation/vlm_utils/builders.py | 1 + .../generation/vlm_utils/case_filtering.py | 1 + .../multimodal/generation/vlm_utils/core.py | 1 + .../generation/vlm_utils/custom_inputs.py | 1 + .../generation/vlm_utils/model_utils.py | 19 +- .../generation/vlm_utils/runners.py | 1 + .../multimodal/generation/vlm_utils/types.py | 1 + .../multimodal/pooling/test_dse_qwen2_vl.py | 1 + .../multimodal/pooling/test_intern_vit.py | 1 + .../multimodal/pooling/test_llava_next.py | 1 + tests/models/multimodal/pooling/test_phi3v.py | 1 + .../multimodal/processing/test_common.py | 3 +- .../multimodal/processing/test_h2ovl.py | 1 + .../multimodal/processing/test_idefics3.py | 1 + .../multimodal/processing/test_internvl.py | 1 + .../multimodal/processing/test_llama4.py | 1 + .../multimodal/processing/test_llava_next.py | 1 + .../processing/test_llava_onevision.py | 1 + .../processing/test_minimax_vl_01.py | 1 + .../multimodal/processing/test_mllama.py | 1 + .../multimodal/processing/test_phi3v.py | 1 + .../multimodal/processing/test_phi4mm.py | 1 + .../multimodal/processing/test_qwen2_vl.py | 1 + .../multimodal/processing/test_smolvlm.py | 1 + tests/models/quantization/test_aqlm.py | 1 + tests/models/quantization/test_awq.py | 1 + tests/models/quantization/test_bitblas.py | 1 + tests/models/quantization/test_fp8.py | 1 + tests/models/quantization/test_gguf.py | 3 +- .../models/quantization/test_gptq_bitblas.py | 1 + tests/models/quantization/test_gptq_marlin.py | 1 + .../quantization/test_gptq_marlin_24.py | 1 + tests/models/quantization/test_modelopt.py | 1 + tests/models/quantization/test_mxfp4.py | 1 + tests/models/quantization/test_nvfp4.py | 1 + tests/models/registry.py | 50 +- tests/models/test_initialization.py | 14 + tests/models/test_oot_registration.py | 1 + tests/models/test_registry.py | 1 + tests/models/test_transformers.py | 1 + tests/models/test_utils.py | 1 + tests/models/test_vision.py | 1 + tests/models/utils.py | 1 + tests/mq_llm_engine/conftest.py | 1 + 
tests/mq_llm_engine/test_abort.py | 1 + tests/mq_llm_engine/test_error_handling.py | 1 + tests/mq_llm_engine/test_load.py | 1 + tests/mq_llm_engine/utils.py | 1 + .../multi_step/test_correctness_async_llm.py | 1 + tests/multi_step/test_correctness_llm.py | 1 + tests/multimodal/test_hasher.py | 1 + tests/multimodal/test_image.py | 1 + tests/multimodal/test_inputs.py | 1 + tests/multimodal/test_processing.py | 1 + tests/multimodal/test_utils.py | 112 +++- tests/multimodal/test_video.py | 1 + tests/multimodal/utils.py | 1 + tests/neuron/1_core/test_activation.py | 1 + tests/neuron/1_core/test_block_table.py | 1 + tests/neuron/1_core/test_cache.py | 1 + tests/neuron/1_core/test_layernorm.py | 1 + tests/neuron/1_core/test_logits_processor.py | 1 + .../neuron/1_core/test_neuron_model_runner.py | 1 + tests/neuron/1_core/test_neuron_quant.py | 1 + tests/neuron/1_core/test_prefix_prefill.py | 1 + tests/neuron/1_core/test_rotary_embedding.py | 1 + tests/neuron/2_core/test_comm_ops.py | 1 + tests/neuron/2_core/test_eagle.py | 1 + tests/neuron/2_core/test_mistral.py | 1 + tests/neuron/2_core/test_multi_lora.py | 1 + .../test_filesystem_resolver.py | 1 + tests/plugins/vllm_add_dummy_model/setup.py | 1 + .../vllm_add_dummy_model/__init__.py | 1 + .../my_gemma_embedding.py | 1 + .../vllm_add_dummy_model/my_llava.py | 1 + .../vllm_add_dummy_model/my_opt.py | 1 + .../plugins/vllm_add_dummy_platform/setup.py | 1 + .../vllm_add_dummy_platform/__init__.py | 1 + .../dummy_attention_backend.py | 1 + .../vllm_add_dummy_platform/dummy_platform.py | 1 + tests/plugins_tests/conftest.py | 1 + tests/plugins_tests/test_platform_plugins.py | 1 + tests/plugins_tests/test_scheduler_plugins.py | 1 + tests/pplx_utils.py | 123 ++++ .../test_disable_sliding_window.py | 1 + tests/prefix_caching/test_prefix_caching.py | 1 + tests/prompt_adapter/test_bloom.py | 1 + .../test_multi_adapter_inference.py | 1 + tests/prompt_adapter/test_pa_lora.py | 1 + tests/quantization/test_auto_round.py | 1 + tests/quantization/test_bitsandbytes.py | 1 + tests/quantization/test_compressed_tensors.py | 3 +- tests/quantization/test_configs.py | 1 + tests/quantization/test_cpu_offload.py | 3 +- tests/quantization/test_experts_int8.py | 1 + tests/quantization/test_fp8.py | 1 + tests/quantization/test_gptq_dynamic.py | 1 + tests/quantization/test_ipex_quant.py | 1 + tests/quantization/test_lm_head.py | 1 + tests/quantization/test_ptpc_fp8.py | 1 + tests/quantization/test_quark.py | 1 + .../test_register_quantization_config.py | 1 + tests/quantization/test_torchao.py | 7 +- tests/quantization/utils.py | 1 + .../test_deepseekr1_reasoning_parser.py | 1 + .../test_granite_reasoning_parser.py | 1 + .../reasoning/test_qwen3_reasoning_parser.py | 1 + tests/reasoning/utils.py | 1 + .../test_runai_model_streamer_loader.py | 1 + .../test_weight_utils.py | 1 + tests/samplers/test_beam_search.py | 1 + tests/samplers/test_ignore_eos.py | 1 + tests/samplers/test_logits_processor.py | 1 + tests/samplers/test_logprobs.py | 1 + tests/samplers/test_no_bad_words.py | 1 + tests/samplers/test_ranks.py | 1 + tests/samplers/test_rejection_sampler.py | 1 + tests/samplers/test_sampler.py | 1 + tests/samplers/test_seeded_generate.py | 1 + .../test_typical_acceptance_sampler.py | 1 + tests/spec_decode/conftest.py | 1 + tests/spec_decode/e2e/conftest.py | 1 + tests/spec_decode/e2e/test_compatibility.py | 1 + .../spec_decode/e2e/test_eagle_correctness.py | 1 + tests/spec_decode/e2e/test_integration.py | 1 + .../e2e/test_integration_dist_tp2.py | 1 + 
.../e2e/test_integration_dist_tp4.py | 1 + tests/spec_decode/e2e/test_logprobs.py | 1 + .../e2e/test_medusa_correctness.py | 1 + tests/spec_decode/e2e/test_mlp_correctness.py | 1 + tests/spec_decode/e2e/test_mtp_correctness.py | 1 + .../e2e/test_multistep_correctness.py | 1 + .../spec_decode/e2e/test_ngram_correctness.py | 1 + tests/spec_decode/e2e/test_seed.py | 1 + tests/spec_decode/test_batch_expansion.py | 1 + tests/spec_decode/test_dynamic_spec_decode.py | 1 + tests/spec_decode/test_memory_usage.py | 1 + tests/spec_decode/test_metrics.py | 1 + tests/spec_decode/test_multi_step_worker.py | 1 + tests/spec_decode/test_ngram_worker.py | 1 + tests/spec_decode/test_scorer.py | 1 + tests/spec_decode/test_spec_decode_worker.py | 1 + tests/spec_decode/test_utils.py | 1 + tests/spec_decode/utils.py | 1 + tests/standalone_tests/lazy_imports.py | 1 + tests/tensorizer_loader/conftest.py | 1 + tests/tensorizer_loader/test_tensorizer.py | 1 + tests/test_cache_block_hashing.py | 1 + tests/test_config.py | 14 + tests/test_embedded_commit.py | 1 + tests/test_inputs.py | 1 + tests/test_logger.py | 1 + tests/test_outputs.py | 1 + tests/test_regression.py | 1 + tests/test_sampling_params.py | 1 + tests/test_scalartype.py | 1 + tests/test_seed_behavior.py | 3 +- tests/test_sequence.py | 1 + tests/test_sharded_state_loader.py | 1 + tests/test_triton_utils.py | 1 + tests/test_utils.py | 1 + tests/test_version.py | 1 + tests/test_vllm_port.py | 1 + tests/tokenization/test_cached_tokenizer.py | 1 + tests/tokenization/test_detokenize.py | 4 +- tests/tokenization/test_get_eos.py | 1 + tests/tokenization/test_mistral_tokenizer.py | 1 + tests/tokenization/test_tokenizer.py | 1 + tests/tokenization/test_tokenizer_group.py | 1 + tests/tokenization/test_tokenizer_registry.py | 1 + tests/tool_use/conftest.py | 1 + ...est_chat_completion_request_validations.py | 1 + tests/tool_use/test_chat_completions.py | 1 + tests/tool_use/test_jamba_tool_parser.py | 1 + tests/tool_use/test_parallel_tool_calls.py | 1 + tests/tool_use/test_tool_calls.py | 1 + tests/tool_use/test_tool_choice_required.py | 1 + tests/tool_use/utils.py | 1 + tests/tpu/lora/test_lora.py | 1 + tests/tpu/test_compilation.py | 6 +- tests/tpu/test_custom_dispatcher.py | 1 + tests/tpu/test_moe_pallas.py | 3 +- tests/tpu/test_quantization_accuracy.py | 1 + tests/tracing/test_tracing.py | 1 + tests/utils.py | 1 + tests/v1/core/test_kv_cache_utils.py | 326 ++++++++-- tests/v1/core/test_prefix_caching.py | 451 +++++++++++--- tests/v1/core/test_scheduler.py | 17 +- tests/v1/core/test_scheduler_e2e.py | 1 + tests/v1/core/test_specialized_manager.py | 51 +- tests/v1/e2e/test_cascade_attention.py | 1 + .../v1/e2e/test_correctness_sliding_window.py | 5 +- tests/v1/e2e/test_spec_decode.py | 1 + tests/v1/engine/conftest.py | 1 + tests/v1/engine/test_async_llm.py | 199 ++++-- tests/v1/engine/test_engine_args.py | 1 + tests/v1/engine/test_engine_core.py | 2 + tests/v1/engine/test_engine_core_client.py | 191 +++++- tests/v1/engine/test_llm_engine.py | 1 + tests/v1/engine/test_output_processor.py | 6 + tests/v1/engine/utils.py | 1 + tests/v1/entrypoints/conftest.py | 1 + .../llm/test_struct_output_generate.py | 1 + .../openai/test_chat_completion.py | 1 + .../v1/entrypoints/openai/test_completion.py | 1 + .../openai/test_multi_api_servers.py | 1 + .../nixl_integration/run_accuracy_test.sh | 11 +- .../nixl_integration/test_accuracy.py | 2 + .../nixl_integration/test_edge_cases.py | 1 + .../nixl_integration/toy_proxy_server.py | 1 + 
.../kv_connector/unit/test_multi_connector.py | 78 ++- .../kv_connector/unit/test_nixl_connector.py | 5 +- .../unit/test_remote_decode_lifecycle.py | 5 +- .../unit/test_remote_prefill_lifecycle.py | 25 +- tests/v1/kv_connector/unit/utils.py | 11 +- tests/v1/metrics/test_ray_metrics.py | 6 +- tests/v1/sample/test_logprobs.py | 14 +- tests/v1/sample/test_logprobs_e2e.py | 1 + tests/v1/sample/test_rejection_sampler.py | 1 + tests/v1/sample/test_sampler.py | 1 + tests/v1/sample/test_sampling_params_e2e.py | 1 + tests/v1/sample/test_topk_topp_sampler.py | 1 + tests/v1/sample/utils.py | 1 + tests/v1/shutdown/test_delete.py | 1 + tests/v1/shutdown/test_forward_error.py | 1 + tests/v1/shutdown/test_processor_error.py | 1 + tests/v1/shutdown/test_startup_error.py | 1 + tests/v1/shutdown/utils.py | 1 + tests/v1/spec_decode/test_eagle.py | 65 +- tests/v1/spec_decode/test_max_len.py | 1 + tests/v1/spec_decode/test_ngram.py | 1 + tests/v1/structured_output/test_utils.py | 1 + tests/v1/test_async_llm_dp.py | 26 +- tests/v1/test_metrics_reader.py | 1 + tests/v1/test_oracle.py | 1 + tests/v1/test_serial_utils.py | 1 + tests/v1/test_utils.py | 1 + tests/v1/tpu/test_basic.py | 1 + tests/v1/tpu/test_mha_attn.py | 1 + tests/v1/tpu/test_multimodal.py | 1 + tests/v1/tpu/test_pallas.py | 1 + tests/v1/tpu/test_perf.py | 1 + tests/v1/tpu/test_sampler.py | 1 + .../v1/tpu/test_spmd_model_weight_loading.py | 13 +- tests/v1/tpu/test_topk_topp_sampler.py | 1 + tests/v1/tpu/worker/test_tpu_model_runner.py | 235 ++++++- tests/v1/worker/test_gpu_input_batch.py | 30 +- tests/v1/worker/test_gpu_model_runner.py | 258 +++++++- tests/vllm_test_utils/setup.py | 1 + .../vllm_test_utils/__init__.py | 1 + .../vllm_test_utils/vllm_test_utils/blame.py | 1 + .../vllm_test_utils/monitor.py | 1 + tests/weight_loading/test_weight_loading.py | 1 + tests/worker/conftest.py | 1 + .../test_encoder_decoder_model_runner.py | 1 + tests/worker/test_model_input.py | 1 + tests/worker/test_model_runner.py | 1 + tests/worker/test_profile.py | 1 + tests/worker/test_swap.py | 1 + tools/check_spdx_header.py | 5 +- tools/check_triton_import.py | 1 + tools/enforce_regex_import.py | 1 + tools/profiler/print_layerwise_table.py | 1 + tools/profiler/visualize_layerwise_profile.py | 1 + tools/report_build_time_ninja.py | 1 + use_existing_torch.py | 1 + vllm/__init__.py | 1 + vllm/_custom_ops.py | 114 +++- vllm/_ipex_ops.py | 1 + vllm/adapter_commons/layers.py | 1 + vllm/adapter_commons/models.py | 1 + vllm/adapter_commons/request.py | 1 + vllm/adapter_commons/utils.py | 1 + vllm/adapter_commons/worker_manager.py | 1 + vllm/assets/audio.py | 1 + vllm/assets/base.py | 1 + vllm/assets/image.py | 1 + vllm/assets/video.py | 1 + vllm/attention/__init__.py | 1 + vllm/attention/backends/abstract.py | 2 + vllm/attention/backends/blocksparse_attn.py | 4 + vllm/attention/backends/cpu_mla.py | 10 +- .../backends/dual_chunk_flash_attn.py | 4 + vllm/attention/backends/flash_attn.py | 4 + vllm/attention/backends/flashinfer.py | 4 + vllm/attention/backends/flashmla.py | 4 +- vllm/attention/backends/hpu_attn.py | 4 + vllm/attention/backends/ipex_attn.py | 4 + vllm/attention/backends/mla/common.py | 4 + vllm/attention/backends/pallas.py | 4 + vllm/attention/backends/placeholder_attn.py | 1 + vllm/attention/backends/rocm_aiter_mla.py | 4 +- vllm/attention/backends/rocm_flash_attn.py | 4 + vllm/attention/backends/torch_sdpa.py | 20 +- vllm/attention/backends/triton_mla.py | 4 +- vllm/attention/backends/utils.py | 1 + vllm/attention/backends/xformers.py | 4 + 
vllm/attention/layer.py | 18 +- .../blocksparse_attention_kernel.py | 1 + .../ops/blocksparse_attention/interface.py | 1 + .../ops/blocksparse_attention/utils.py | 1 + .../ops/chunked_prefill_paged_decode.py | 1 + vllm/attention/ops/flashmla.py | 1 + vllm/attention/ops/hpu_paged_attn.py | 1 + vllm/attention/ops/ipex_attn.py | 1 + vllm/attention/ops/merge_attn_states.py | 1 + vllm/attention/ops/nki_flash_attn.py | 1 + vllm/attention/ops/paged_attn.py | 1 + vllm/attention/ops/prefix_prefill.py | 1 + vllm/attention/ops/rocm_aiter_mla.py | 1 + vllm/attention/ops/rocm_aiter_paged_attn.py | 1 + vllm/attention/ops/triton_decode_attention.py | 1 + vllm/attention/ops/triton_flash_attention.py | 1 + .../attention/ops/triton_merge_attn_states.py | 1 + .../attention/ops/triton_unified_attention.py | 1 + vllm/attention/selector.py | 1 + vllm/attention/utils/fa_utils.py | 1 + vllm/beam_search.py | 1 + vllm/benchmarks/datasets.py | 54 +- vllm/benchmarks/endpoint_request_func.py | 1 + vllm/benchmarks/latency.py | 1 + vllm/benchmarks/serve.py | 1 + vllm/benchmarks/throughput.py | 1 + vllm/benchmarks/utils.py | 1 + vllm/collect_env.py | 6 +- vllm/compilation/activation_quant_fusion.py | 1 + vllm/compilation/backends.py | 1 + vllm/compilation/base_piecewise_backend.py | 1 + vllm/compilation/collective_fusion.py | 1 + vllm/compilation/compiler_interface.py | 1 + vllm/compilation/counter.py | 3 +- vllm/compilation/cuda_piecewise_backend.py | 3 +- vllm/compilation/decorators.py | 1 + vllm/compilation/fix_functionalization.py | 1 + vllm/compilation/fusion.py | 1 + vllm/compilation/fx_utils.py | 1 + vllm/compilation/inductor_pass.py | 1 + vllm/compilation/monitor.py | 1 + vllm/compilation/multi_output_match.py | 1 + vllm/compilation/noop_elimination.py | 1 + vllm/compilation/pass_manager.py | 1 + vllm/compilation/sequence_parallelism.py | 1 + vllm/compilation/torch25_custom_graph_pass.py | 1 + vllm/compilation/vllm_inductor_pass.py | 1 + vllm/compilation/wrapper.py | 8 +- vllm/config.py | 66 +- vllm/connections.py | 1 + vllm/core/block/block_table.py | 1 + vllm/core/block/common.py | 1 + vllm/core/block/cpu_gpu_block_allocator.py | 1 + vllm/core/block/interfaces.py | 1 + vllm/core/block/naive_block.py | 1 + vllm/core/block/prefix_caching_block.py | 1 + vllm/core/block/utils.py | 1 + vllm/core/block_manager.py | 1 + vllm/core/evictor.py | 1 + vllm/core/interfaces.py | 1 + vllm/core/placeholder_block_space_manager.py | 1 + vllm/core/scheduler.py | 1 + vllm/device_allocator/cumem.py | 1 + vllm/distributed/__init__.py | 1 + vllm/distributed/communication_op.py | 1 + .../device_communicators/all2all.py | 149 ++++- .../base_device_communicator.py | 4 +- .../device_communicators/cpu_communicator.py | 1 + .../device_communicators/cuda_communicator.py | 9 + .../device_communicators/cuda_wrapper.py | 1 + .../device_communicators/custom_all_reduce.py | 1 + .../custom_all_reduce_utils.py | 1 + .../device_communicators/hpu_communicator.py | 1 + .../neuron_communicator.py | 1 + .../device_communicators/pynccl.py | 1 + .../device_communicators/pynccl_wrapper.py | 1 + .../device_communicators/shm_broadcast.py | 48 +- .../device_communicators/tpu_communicator.py | 1 + .../device_communicators/xpu_communicator.py | 1 + vllm/distributed/kv_events.py | 78 ++- vllm/distributed/kv_transfer/__init__.py | 1 + .../kv_transfer/kv_connector/base.py | 1 + .../kv_transfer/kv_connector/factory.py | 1 + .../kv_connector/lmcache_connector.py | 1 + .../kv_connector/mooncake_store_connector.py | 1 + .../kv_connector/simple_connector.py | 1 + 
.../kv_transfer/kv_connector/utils.py | 19 +- .../kv_transfer/kv_connector/v1/__init__.py | 1 + .../kv_transfer/kv_connector/v1/base.py | 30 +- .../kv_connector/v1/lmcache_connector.py | 1 + .../kv_connector/v1/multi_connector.py | 42 +- .../kv_connector/v1/nixl_connector.py | 404 ++++++++---- .../v1/shared_storage_connector.py | 1 + .../kv_transfer/kv_connector_agent.py | 1 + .../kv_transfer/kv_lookup_buffer/base.py | 1 + .../kv_lookup_buffer/mooncake_store.py | 1 + .../kv_lookup_buffer/simple_buffer.py | 1 + vllm/distributed/kv_transfer/kv_pipe/base.py | 1 + .../kv_transfer/kv_pipe/mooncake_pipe.py | 1 + .../kv_transfer/kv_pipe/pynccl_pipe.py | 1 + .../kv_transfer/kv_transfer_state.py | 1 + vllm/distributed/parallel_state.py | 1 + vllm/distributed/utils.py | 1 + vllm/engine/arg_utils.py | 44 +- vllm/engine/async_llm_engine.py | 17 +- vllm/engine/async_timeout.py | 1 + vllm/engine/llm_engine.py | 1 + vllm/engine/metrics.py | 1 + vllm/engine/metrics_types.py | 1 + vllm/engine/multiprocessing/__init__.py | 1 + vllm/engine/multiprocessing/client.py | 1 + vllm/engine/multiprocessing/engine.py | 1 + vllm/engine/output_processor/interfaces.py | 1 + vllm/engine/output_processor/multi_step.py | 1 + vllm/engine/output_processor/single_step.py | 1 + vllm/engine/output_processor/stop_checker.py | 1 + vllm/engine/output_processor/util.py | 1 + vllm/engine/protocol.py | 1 + vllm/entrypoints/api_server.py | 1 + vllm/entrypoints/chat_utils.py | 1 + vllm/entrypoints/cli/benchmark/base.py | 1 + vllm/entrypoints/cli/benchmark/latency.py | 1 + vllm/entrypoints/cli/benchmark/main.py | 1 + vllm/entrypoints/cli/benchmark/serve.py | 1 + vllm/entrypoints/cli/benchmark/throughput.py | 1 + vllm/entrypoints/cli/collect_env.py | 1 + vllm/entrypoints/cli/main.py | 5 +- vllm/entrypoints/cli/openai.py | 1 + vllm/entrypoints/cli/run_batch.py | 9 +- vllm/entrypoints/cli/serve.py | 7 +- vllm/entrypoints/cli/types.py | 1 + vllm/entrypoints/launcher.py | 1 + vllm/entrypoints/llm.py | 1 + vllm/entrypoints/logger.py | 1 + vllm/entrypoints/openai/api_server.py | 1 + vllm/entrypoints/openai/cli_args.py | 1 + vllm/entrypoints/openai/logits_processors.py | 1 + vllm/entrypoints/openai/protocol.py | 3 + vllm/entrypoints/openai/run_batch.py | 1 + vllm/entrypoints/openai/serving_chat.py | 35 +- .../openai/serving_classification.py | 1 + vllm/entrypoints/openai/serving_completion.py | 1 + vllm/entrypoints/openai/serving_embedding.py | 1 + vllm/entrypoints/openai/serving_engine.py | 1 + vllm/entrypoints/openai/serving_models.py | 1 + vllm/entrypoints/openai/serving_pooling.py | 1 + vllm/entrypoints/openai/serving_score.py | 1 + .../openai/serving_tokenization.py | 1 + .../openai/serving_transcription.py | 1 + .../openai/tool_parsers/__init__.py | 1 + .../tool_parsers/abstract_tool_parser.py | 1 + .../tool_parsers/deepseekv3_tool_parser.py | 1 + .../granite_20b_fc_tool_parser.py | 1 + .../tool_parsers/granite_tool_parser.py | 1 + .../openai/tool_parsers/hermes_tool_parser.py | 1 + .../tool_parsers/internlm2_tool_parser.py | 1 + .../openai/tool_parsers/jamba_tool_parser.py | 1 + .../llama4_pythonic_tool_parser.py | 1 + .../openai/tool_parsers/llama_tool_parser.py | 1 + .../tool_parsers/mistral_tool_parser.py | 35 +- .../tool_parsers/phi4mini_tool_parser.py | 1 + .../tool_parsers/pythonic_tool_parser.py | 1 + vllm/entrypoints/openai/tool_parsers/utils.py | 1 + vllm/entrypoints/score_utils.py | 1 + vllm/entrypoints/ssl.py | 1 + vllm/entrypoints/utils.py | 15 +- vllm/env_override.py | 19 +- vllm/envs.py | 9 + vllm/executor/executor_base.py 
| 1 + vllm/executor/mp_distributed_executor.py | 1 + vllm/executor/msgspec_utils.py | 1 + vllm/executor/multiproc_worker_utils.py | 1 + vllm/executor/ray_distributed_executor.py | 1 + vllm/executor/ray_utils.py | 1 + vllm/executor/uniproc_executor.py | 1 + vllm/forward_context.py | 1 + vllm/inputs/__init__.py | 1 + vllm/inputs/data.py | 1 + vllm/inputs/parse.py | 1 + vllm/inputs/preprocess.py | 1 + vllm/inputs/registry.py | 27 +- vllm/jsontree.py | 1 + vllm/logger.py | 1 + vllm/logging_utils/__init__.py | 1 + vllm/logging_utils/dump_input.py | 1 + vllm/logging_utils/formatter.py | 1 + vllm/logits_process.py | 1 + vllm/lora/fully_sharded_layers.py | 1 + vllm/lora/layers.py | 1 + vllm/lora/lora.py | 1 + vllm/lora/models.py | 1 + vllm/lora/ops/torch_ops/__init__.py | 1 + vllm/lora/ops/torch_ops/lora_ops.py | 1 + vllm/lora/ops/triton_ops/__init__.py | 1 + vllm/lora/ops/triton_ops/kernel_utils.py | 1 + vllm/lora/ops/triton_ops/lora_expand_op.py | 1 + .../ops/triton_ops/lora_kernel_metadata.py | 1 + vllm/lora/ops/triton_ops/lora_shrink_op.py | 1 + vllm/lora/ops/triton_ops/utils.py | 1 + vllm/lora/ops/xla_ops/__init__.py | 1 + vllm/lora/ops/xla_ops/lora_ops.py | 1 + vllm/lora/peft_helper.py | 1 + vllm/lora/punica_wrapper/__init__.py | 1 + vllm/lora/punica_wrapper/punica_base.py | 1 + vllm/lora/punica_wrapper/punica_cpu.py | 1 + vllm/lora/punica_wrapper/punica_gpu.py | 1 + vllm/lora/punica_wrapper/punica_hpu.py | 1 + vllm/lora/punica_wrapper/punica_selector.py | 1 + vllm/lora/punica_wrapper/punica_tpu.py | 1 + vllm/lora/punica_wrapper/utils.py | 1 + vllm/lora/request.py | 1 + vllm/lora/resolver.py | 1 + vllm/lora/utils.py | 1 + vllm/lora/worker_manager.py | 1 + vllm/model_executor/__init__.py | 1 + vllm/model_executor/custom_op.py | 1 + .../guided_decoding/__init__.py | 1 + .../guided_decoding/guidance_decoding.py | 1 + .../guidance_logits_processors.py | 1 + .../guided_decoding/guided_fields.py | 1 + .../lm_format_enforcer_decoding.py | 1 + .../guided_decoding/outlines_decoding.py | 1 + .../outlines_logits_processors.py | 1 + vllm/model_executor/guided_decoding/utils.py | 1 + .../guided_decoding/xgrammar_decoding.py | 1 + vllm/model_executor/layers/activation.py | 1 + .../layers/fused_moe/__init__.py | 1 + .../layers/fused_moe/batched_deep_gemm_moe.py | 125 ++++ .../batched_triton_or_deep_gemm_moe.py | 117 ++++ ...,dtype=fp8_w8a8,block_shape=[128,128].json | 146 +++++ .../layers/fused_moe/cutlass_moe.py | 379 +++++++----- .../layers/fused_moe/deep_gemm_moe.py | 34 +- .../fused_moe/deepep_ht_prepare_finalize.py | 236 ++++++++ .../fused_moe/deepep_ll_prepare_finalize.py | 186 ++++++ .../layers/fused_moe/fused_batched_moe.py | 60 +- .../layers/fused_moe/fused_marlin_moe.py | 1 + .../layers/fused_moe/fused_moe.py | 4 +- vllm/model_executor/layers/fused_moe/layer.py | 203 +++++-- .../layers/fused_moe/modular_kernel.py | 164 +++-- .../layers/fused_moe/moe_align_block_size.py | 1 + .../layers/fused_moe/moe_pallas.py | 1 + .../layers/fused_moe/moe_permute_unpermute.py | 6 +- .../layers/fused_moe/moe_torch_iterative.py | 1 + .../layers/fused_moe/pplx_prepare_finalize.py | 33 +- .../layers/fused_moe/prepare_finalize.py | 13 +- .../layers/fused_moe/rocm_aiter_fused_moe.py | 1 + .../layers/fused_moe/triton_deep_gemm_moe.py | 13 +- vllm/model_executor/layers/fused_moe/utils.py | 5 +- vllm/model_executor/layers/layernorm.py | 1 + vllm/model_executor/layers/lightning_attn.py | 1 + vllm/model_executor/layers/linear.py | 1 + .../model_executor/layers/logits_processor.py | 1 + .../layers/mamba/mamba2_metadata.py | 
1 + .../layers/mamba/mamba_mixer.py | 1 + .../layers/mamba/mamba_mixer2.py | 1 + .../layers/mamba/ops/causal_conv1d.py | 1 + .../layers/mamba/ops/mamba_ssm.py | 1 + .../layers/mamba/ops/ssd_bmm.py | 1 + .../layers/mamba/ops/ssd_chunk_scan.py | 1 + .../layers/mamba/ops/ssd_chunk_state.py | 1 + .../layers/mamba/ops/ssd_combined.py | 1 + .../layers/mamba/ops/ssd_state_passing.py | 1 + vllm/model_executor/layers/pooler.py | 1 + .../layers/quantization/__init__.py | 1 + .../layers/quantization/aqlm.py | 1 + .../layers/quantization/auto_round.py | 1 + .../model_executor/layers/quantization/awq.py | 1 + .../layers/quantization/awq_marlin.py | 1 + .../layers/quantization/awq_triton.py | 1 + .../layers/quantization/base_config.py | 1 + .../layers/quantization/bitblas.py | 1 + .../layers/quantization/bitsandbytes.py | 1 + .../compressed_tensors/compressed_tensors.py | 7 +- .../compressed_tensors_moe.py | 88 +-- .../compressed_tensors/schemes/__init__.py | 1 + .../schemes/compressed_tensors_24.py | 1 + .../schemes/compressed_tensors_scheme.py | 1 + .../schemes/compressed_tensors_w4a16_24.py | 1 + .../schemes/compressed_tensors_w4a16_nvfp4.py | 1 + .../schemes/compressed_tensors_w8a16_fp8.py | 1 + .../schemes/compressed_tensors_w8a8_fp8.py | 1 + .../schemes/compressed_tensors_w8a8_int8.py | 1 + .../schemes/compressed_tensors_wNa16.py | 1 + .../compressed_tensors/triton_scaled_mm.py | 1 + .../quantization/compressed_tensors/utils.py | 1 + .../layers/quantization/deepspeedfp.py | 1 + .../layers/quantization/experts_int8.py | 1 + .../layers/quantization/fbgemm_fp8.py | 1 + .../model_executor/layers/quantization/fp8.py | 46 +- .../layers/quantization/gguf.py | 1 + .../layers/quantization/gptq.py | 1 + .../layers/quantization/gptq_bitblas.py | 1 + .../layers/quantization/gptq_marlin.py | 1 + .../layers/quantization/gptq_marlin_24.py | 1 + .../layers/quantization/hqq_marlin.py | 1 + .../layers/quantization/ipex_quant.py | 1 + .../kernels/mixed_precision/MPLinearKernel.py | 1 + .../kernels/mixed_precision/__init__.py | 1 + .../kernels/mixed_precision/allspark.py | 1 + .../kernels/mixed_precision/bitblas.py | 1 + .../kernels/mixed_precision/exllama.py | 1 + .../kernels/mixed_precision/machete.py | 1 + .../kernels/mixed_precision/marlin.py | 1 + .../kernels/scaled_mm/ScaledMMLinearKernel.py | 1 + .../kernels/scaled_mm/__init__.py | 1 + .../quantization/kernels/scaled_mm/aiter.py | 1 + .../quantization/kernels/scaled_mm/cutlass.py | 1 + .../quantization/kernels/scaled_mm/triton.py | 1 + .../quantization/kernels/scaled_mm/xla.py | 1 + .../layers/quantization/kv_cache.py | 1 + .../layers/quantization/marlin.py | 1 + .../layers/quantization/modelopt.py | 1 + .../layers/quantization/moe_wna16.py | 1 + .../layers/quantization/neuron_quant.py | 1 + .../layers/quantization/ptpc_fp8.py | 1 + .../model_executor/layers/quantization/qqq.py | 1 + .../layers/quantization/quark/quark.py | 1 + .../layers/quantization/quark/quark_moe.py | 1 + .../quantization/quark/schemes/__init__.py | 1 + .../quark/schemes/quark_scheme.py | 1 + .../quark/schemes/quark_w4a4_mxfp4.py | 1 + .../quark/schemes/quark_w8a8_fp8.py | 1 + .../quark/schemes/quark_w8a8_int8.py | 1 + .../layers/quantization/quark/utils.py | 1 + .../layers/quantization/schema.py | 1 + .../layers/quantization/torchao.py | 22 +- .../layers/quantization/tpu_int8.py | 1 + .../layers/quantization/utils/__init__.py | 1 + .../quantization/utils/allspark_utils.py | 1 + .../quantization/utils/bitblas_utils.py | 1 + .../layers/quantization/utils/fp8_utils.py | 15 +- 
.../layers/quantization/utils/gptq_utils.py | 1 + .../layers/quantization/utils/int8_utils.py | 1 + .../layers/quantization/utils/layer_utils.py | 1 + .../quantization/utils/machete_utils.py | 1 + .../layers/quantization/utils/marlin_utils.py | 1 + .../quantization/utils/marlin_utils_fp4.py | 1 + .../quantization/utils/marlin_utils_fp8.py | 1 + .../quantization/utils/marlin_utils_test.py | 1 + .../utils/marlin_utils_test_24.py | 1 + .../utils/marlin_utils_test_qqq.py | 1 + .../layers/quantization/utils/mxfp4_utils.py | 1 + .../utils/nvfp4_emulation_utils.py | 1 + .../layers/quantization/utils/quant_utils.py | 1 + .../layers/quantization/utils/w8a8_utils.py | 1 + .../layers/rejection_sampler.py | 1 + vllm/model_executor/layers/resampler.py | 1 + .../model_executor/layers/rotary_embedding.py | 1 + vllm/model_executor/layers/sampler.py | 1 + .../layers/spec_decode_base_sampler.py | 1 + .../layers/typical_acceptance_sampler.py | 1 + vllm/model_executor/layers/utils.py | 14 +- .../layers/vocab_parallel_embedding.py | 1 + vllm/model_executor/model_loader/__init__.py | 1 + .../model_loader/base_loader.py | 1 + .../model_loader/bitsandbytes_loader.py | 1 + .../model_loader/default_loader.py | 1 + .../model_loader/dummy_loader.py | 1 + .../model_loader/gguf_loader.py | 1 + vllm/model_executor/model_loader/neuron.py | 1 + .../model_loader/neuronx_distributed.py | 1 + .../model_loader/runai_streamer_loader.py | 1 + .../model_loader/sharded_state_loader.py | 1 + .../model_executor/model_loader/tensorizer.py | 1 + .../model_loader/tensorizer_loader.py | 1 + vllm/model_executor/model_loader/utils.py | 1 + .../model_loader/weight_utils.py | 1 + vllm/model_executor/models/__init__.py | 1 + vllm/model_executor/models/adapters.py | 1 + vllm/model_executor/models/aimv2.py | 1 + vllm/model_executor/models/arctic.py | 1 + vllm/model_executor/models/aria.py | 1 + vllm/model_executor/models/aya_vision.py | 15 +- vllm/model_executor/models/baichuan.py | 1 + vllm/model_executor/models/bamba.py | 1 + vllm/model_executor/models/bart.py | 1 + vllm/model_executor/models/bert.py | 14 +- vllm/model_executor/models/bert_with_rope.py | 8 +- vllm/model_executor/models/blip.py | 1 + vllm/model_executor/models/blip2.py | 1 + vllm/model_executor/models/bloom.py | 1 + vllm/model_executor/models/chameleon.py | 1 + vllm/model_executor/models/chatglm.py | 1 + vllm/model_executor/models/clip.py | 1 + vllm/model_executor/models/commandr.py | 1 + .../models/constant_size_cache.py | 1 + vllm/model_executor/models/dbrx.py | 1 + vllm/model_executor/models/deepseek.py | 1 + vllm/model_executor/models/deepseek_mtp.py | 1 + vllm/model_executor/models/deepseek_v2.py | 1 + vllm/model_executor/models/deepseek_vl2.py | 1 + vllm/model_executor/models/eagle.py | 1 + vllm/model_executor/models/exaone.py | 1 + vllm/model_executor/models/fairseq2_llama.py | 1 + vllm/model_executor/models/falcon.py | 1 + vllm/model_executor/models/falcon_h1.py | 1 + vllm/model_executor/models/florence2.py | 1 + vllm/model_executor/models/fuyu.py | 1 + vllm/model_executor/models/gemma.py | 1 + vllm/model_executor/models/gemma2.py | 1 + vllm/model_executor/models/gemma3.py | 1 + vllm/model_executor/models/gemma3_mm.py | 1 + vllm/model_executor/models/glm.py | 1 + vllm/model_executor/models/glm4.py | 1 + vllm/model_executor/models/glm4v.py | 1 + vllm/model_executor/models/gpt2.py | 1 + vllm/model_executor/models/gpt_bigcode.py | 1 + vllm/model_executor/models/gpt_j.py | 1 + vllm/model_executor/models/gpt_neox.py | 1 + vllm/model_executor/models/granite.py | 1 + 
vllm/model_executor/models/granite_speech.py | 1 + vllm/model_executor/models/granitemoe.py | 1 + .../model_executor/models/granitemoehybrid.py | 1 + .../model_executor/models/granitemoeshared.py | 1 + vllm/model_executor/models/gritlm.py | 1 + vllm/model_executor/models/grok1.py | 1 + vllm/model_executor/models/h2ovl.py | 1 + .../models/idefics2_vision_model.py | 1 + vllm/model_executor/models/idefics3.py | 17 +- vllm/model_executor/models/interfaces.py | 1 + vllm/model_executor/models/interfaces_base.py | 1 + vllm/model_executor/models/intern_vit.py | 1 + vllm/model_executor/models/internlm2.py | 1 + vllm/model_executor/models/internlm2_ve.py | 1 + vllm/model_executor/models/internvl.py | 1 + vllm/model_executor/models/jais.py | 1 + vllm/model_executor/models/jamba.py | 1 + vllm/model_executor/models/kimi_vl.py | 1 + vllm/model_executor/models/llama.py | 1 + vllm/model_executor/models/llama4.py | 1 + vllm/model_executor/models/llama_eagle.py | 15 +- vllm/model_executor/models/llama_eagle3.py | 25 +- vllm/model_executor/models/llava.py | 1 + vllm/model_executor/models/llava_next.py | 1 + .../model_executor/models/llava_next_video.py | 1 + vllm/model_executor/models/llava_onevision.py | 1 + vllm/model_executor/models/mamba.py | 1 + vllm/model_executor/models/mamba2.py | 1 + vllm/model_executor/models/mamba_cache.py | 1 + vllm/model_executor/models/medusa.py | 1 + vllm/model_executor/models/mimo.py | 1 + vllm/model_executor/models/mimo_mtp.py | 1 + vllm/model_executor/models/minicpm.py | 1 + vllm/model_executor/models/minicpm3.py | 1 + vllm/model_executor/models/minicpm_eagle.py | 1 + vllm/model_executor/models/minicpmo.py | 1 + vllm/model_executor/models/minicpmv.py | 1 + vllm/model_executor/models/minimax_cache.py | 1 + vllm/model_executor/models/minimax_text_01.py | 1 + vllm/model_executor/models/minimax_vl_01.py | 1 + vllm/model_executor/models/mistral3.py | 1 + vllm/model_executor/models/mixtral.py | 1 + vllm/model_executor/models/mixtral_quant.py | 1 + vllm/model_executor/models/mllama.py | 1 + vllm/model_executor/models/mllama4.py | 1 + vllm/model_executor/models/mlp_speculator.py | 1 + vllm/model_executor/models/modernbert.py | 1 + vllm/model_executor/models/module_mapping.py | 1 + vllm/model_executor/models/molmo.py | 1 + vllm/model_executor/models/moonvit.py | 1 + vllm/model_executor/models/mpt.py | 1 + vllm/model_executor/models/nemotron.py | 1 + vllm/model_executor/models/nemotron_h.py | 573 ++++++++++++++++++ vllm/model_executor/models/nemotron_nas.py | 1 + vllm/model_executor/models/nvlm_d.py | 1 + vllm/model_executor/models/olmo.py | 1 + vllm/model_executor/models/olmo2.py | 1 + vllm/model_executor/models/olmoe.py | 1 + vllm/model_executor/models/opt.py | 1 + vllm/model_executor/models/orion.py | 1 + vllm/model_executor/models/ovis.py | 1 + vllm/model_executor/models/paligemma.py | 1 + vllm/model_executor/models/persimmon.py | 1 + vllm/model_executor/models/phi.py | 1 + vllm/model_executor/models/phi3.py | 1 + vllm/model_executor/models/phi3_small.py | 1 + vllm/model_executor/models/phi3v.py | 1 + vllm/model_executor/models/phi4mm.py | 1 + vllm/model_executor/models/phi4mm_audio.py | 1 + vllm/model_executor/models/phi4mm_utils.py | 1 + vllm/model_executor/models/phimoe.py | 1 + vllm/model_executor/models/pixtral.py | 1 + vllm/model_executor/models/plamo2.py | 1 + .../models/prithvi_geospatial_mae.py | 1 + vllm/model_executor/models/qwen.py | 1 + vllm/model_executor/models/qwen2.py | 1 + .../models/qwen2_5_omni_thinker.py | 1 + vllm/model_executor/models/qwen2_5_vl.py | 1 + 
vllm/model_executor/models/qwen2_audio.py | 1 + vllm/model_executor/models/qwen2_moe.py | 1 + vllm/model_executor/models/qwen2_rm.py | 1 + vllm/model_executor/models/qwen2_vl.py | 1 + vllm/model_executor/models/qwen3.py | 1 + vllm/model_executor/models/qwen3_moe.py | 1 + vllm/model_executor/models/qwen_vl.py | 1 + vllm/model_executor/models/registry.py | 2 + vllm/model_executor/models/roberta.py | 1 + vllm/model_executor/models/siglip.py | 1 + vllm/model_executor/models/skyworkr1v.py | 1 + vllm/model_executor/models/smolvlm.py | 1 + vllm/model_executor/models/solar.py | 1 + vllm/model_executor/models/stablelm.py | 1 + vllm/model_executor/models/starcoder2.py | 1 + vllm/model_executor/models/telechat2.py | 1 + vllm/model_executor/models/teleflm.py | 1 + vllm/model_executor/models/transformers.py | 1 + vllm/model_executor/models/ultravox.py | 1 + vllm/model_executor/models/utils.py | 1 + vllm/model_executor/models/vision.py | 1 + vllm/model_executor/models/whisper.py | 1 + vllm/model_executor/models/zamba2.py | 1 + vllm/model_executor/parameter.py | 1 + vllm/model_executor/pooling_metadata.py | 1 + vllm/model_executor/sampling_metadata.py | 1 + vllm/model_executor/utils.py | 1 + vllm/multimodal/__init__.py | 1 + vllm/multimodal/audio.py | 1 + vllm/multimodal/base.py | 1 + vllm/multimodal/hasher.py | 1 + vllm/multimodal/image.py | 1 + vllm/multimodal/inputs.py | 9 +- vllm/multimodal/parse.py | 1 + vllm/multimodal/processing.py | 1 + vllm/multimodal/profiling.py | 1 + vllm/multimodal/registry.py | 1 + vllm/multimodal/utils.py | 31 +- vllm/multimodal/video.py | 1 + vllm/outputs.py | 1 + vllm/platforms/__init__.py | 1 + vllm/platforms/cpu.py | 78 ++- vllm/platforms/cuda.py | 43 ++ vllm/platforms/hpu.py | 1 + vllm/platforms/interface.py | 30 + vllm/platforms/neuron.py | 1 + vllm/platforms/rocm.py | 1 + vllm/platforms/tpu.py | 1 + vllm/platforms/xpu.py | 1 + vllm/plugins/__init__.py | 1 + .../lora_resolvers/filesystem_resolver.py | 1 + vllm/pooling_params.py | 1 + vllm/profiler/layerwise_profile.py | 1 + vllm/profiler/utils.py | 1 + vllm/prompt_adapter/layers.py | 1 + vllm/prompt_adapter/models.py | 1 + vllm/prompt_adapter/request.py | 1 + vllm/prompt_adapter/utils.py | 1 + vllm/prompt_adapter/worker_manager.py | 1 + vllm/reasoning/__init__.py | 1 + vllm/reasoning/abs_reasoning_parsers.py | 1 + .../reasoning/deepseek_r1_reasoning_parser.py | 1 + vllm/reasoning/granite_reasoning_parser.py | 1 + vllm/reasoning/qwen3_reasoning_parser.py | 1 + vllm/sampling_params.py | 12 + vllm/scalar_type.py | 1 + vllm/scripts.py | 1 + vllm/sequence.py | 1 + vllm/spec_decode/batch_expansion.py | 1 + vllm/spec_decode/draft_model_runner.py | 2 +- vllm/spec_decode/interfaces.py | 1 + vllm/spec_decode/medusa_worker.py | 1 + vllm/spec_decode/metrics.py | 1 + vllm/spec_decode/mlp_speculator_worker.py | 1 + vllm/spec_decode/mqa_scorer.py | 1 + vllm/spec_decode/multi_step_worker.py | 1 + vllm/spec_decode/ngram_worker.py | 1 + vllm/spec_decode/proposer_worker_base.py | 1 + .../spec_decode/smaller_tp_proposer_worker.py | 1 + vllm/spec_decode/spec_decode_worker.py | 1 + vllm/spec_decode/target_model_runner.py | 1 + vllm/spec_decode/top1_proposer.py | 1 + vllm/spec_decode/util.py | 1 + vllm/test_utils.py | 1 + vllm/third_party/pynvml.py | 1 + vllm/tracing.py | 1 + vllm/transformers_utils/__init__.py | 1 + .../chat_templates/__init__.py | 1 + .../chat_templates/registry.py | 1 + vllm/transformers_utils/config.py | 1 + vllm/transformers_utils/configs/__init__.py | 3 + vllm/transformers_utils/configs/arctic.py | 1 + 
vllm/transformers_utils/configs/chatglm.py | 1 + vllm/transformers_utils/configs/cohere2.py | 1 + vllm/transformers_utils/configs/dbrx.py | 1 + .../configs/deepseek_vl2.py | 1 + vllm/transformers_utils/configs/eagle.py | 1 + vllm/transformers_utils/configs/exaone.py | 1 + vllm/transformers_utils/configs/falcon.py | 1 + vllm/transformers_utils/configs/h2ovl.py | 1 + vllm/transformers_utils/configs/internvl.py | 1 + vllm/transformers_utils/configs/jais.py | 1 + vllm/transformers_utils/configs/kimi_vl.py | 1 + vllm/transformers_utils/configs/medusa.py | 1 + .../configs/minimax_text_01.py | 1 + .../configs/minimax_vl_01.py | 1 + vllm/transformers_utils/configs/mllama.py | 1 + .../configs/mlp_speculator.py | 1 + vllm/transformers_utils/configs/moonvit.py | 1 + vllm/transformers_utils/configs/mpt.py | 1 + vllm/transformers_utils/configs/nemotron.py | 1 + vllm/transformers_utils/configs/nemotron_h.py | 258 ++++++++ vllm/transformers_utils/configs/nvlm_d.py | 1 + vllm/transformers_utils/configs/ovis.py | 1 + vllm/transformers_utils/configs/skyworkr1v.py | 1 + vllm/transformers_utils/configs/solar.py | 1 + vllm/transformers_utils/configs/telechat2.py | 1 + vllm/transformers_utils/configs/ultravox.py | 1 + vllm/transformers_utils/detokenizer.py | 1 + vllm/transformers_utils/detokenizer_utils.py | 1 + vllm/transformers_utils/processor.py | 1 + .../transformers_utils/processors/__init__.py | 1 + .../processors/deepseek_vl2.py | 1 + vllm/transformers_utils/processors/ovis.py | 1 + vllm/transformers_utils/s3_utils.py | 1 + vllm/transformers_utils/tokenizer.py | 1 + vllm/transformers_utils/tokenizer_base.py | 1 + vllm/transformers_utils/tokenizer_group.py | 1 + .../transformers_utils/tokenizers/__init__.py | 1 + vllm/transformers_utils/tokenizers/mistral.py | 3 + vllm/transformers_utils/utils.py | 1 + vllm/triton_utils/__init__.py | 1 + vllm/triton_utils/importing.py | 1 + vllm/usage/usage_lib.py | 1 + vllm/utils.py | 4 +- vllm/v1/attention/backends/cpu_attn.py | 163 +++++ vllm/v1/attention/backends/flash_attn.py | 82 ++- vllm/v1/attention/backends/flashinfer.py | 37 +- vllm/v1/attention/backends/flex_attention.py | 477 +++++++++++++++ vllm/v1/attention/backends/mla/common.py | 7 +- vllm/v1/attention/backends/mla/cutlass_mla.py | 96 +++ vllm/v1/attention/backends/mla/flashmla.py | 4 +- .../attention/backends/mla/rocm_aiter_mla.py | 4 +- vllm/v1/attention/backends/mla/triton_mla.py | 4 +- vllm/v1/attention/backends/pallas.py | 7 +- vllm/v1/attention/backends/triton_attn.py | 57 +- vllm/v1/attention/backends/utils.py | 34 ++ vllm/v1/core/block_pool.py | 73 ++- vllm/v1/core/encoder_cache_manager.py | 1 + vllm/v1/core/kv_cache_coordinator.py | 358 +++++++++++ vllm/v1/core/kv_cache_manager.py | 140 +++-- vllm/v1/core/kv_cache_utils.py | 379 ++++++++++-- vllm/v1/core/sched/interface.py | 1 + vllm/v1/core/sched/output.py | 1 + vllm/v1/core/sched/scheduler.py | 49 +- vllm/v1/core/sched/utils.py | 1 + vllm/v1/core/single_type_kv_cache_manager.py | 157 +++-- vllm/v1/engine/__init__.py | 2 + vllm/v1/engine/async_llm.py | 25 +- vllm/v1/engine/coordinator.py | 1 + vllm/v1/engine/core.py | 1 + vllm/v1/engine/core_client.py | 37 +- vllm/v1/engine/detokenizer.py | 1 + vllm/v1/engine/exceptions.py | 1 + vllm/v1/engine/llm_engine.py | 1 + vllm/v1/engine/logprobs.py | 1 + vllm/v1/engine/mm_input_cache.py | 1 + vllm/v1/engine/output_processor.py | 1 + vllm/v1/engine/parallel_sampling.py | 1 + vllm/v1/engine/processor.py | 9 + vllm/v1/executor/abstract.py | 1 + vllm/v1/executor/multiproc_executor.py | 1 + 
vllm/v1/executor/ray_distributed_executor.py | 1 + vllm/v1/kv_cache_interface.py | 34 +- vllm/v1/metrics/loggers.py | 1 + vllm/v1/metrics/prometheus.py | 1 + vllm/v1/metrics/ray_wrappers.py | 11 + vllm/v1/metrics/reader.py | 1 + vllm/v1/metrics/stats.py | 1 + vllm/v1/outputs.py | 1 + vllm/v1/request.py | 1 + vllm/v1/sample/metadata.py | 1 + vllm/v1/sample/ops/bad_words.py | 1 + vllm/v1/sample/ops/penalties.py | 1 + vllm/v1/sample/ops/topk_topp_sampler.py | 1 + vllm/v1/sample/rejection_sampler.py | 1 + vllm/v1/sample/sampler.py | 18 +- vllm/v1/sample/tpu/metadata.py | 1 + vllm/v1/sample/tpu/sampler.py | 1 + vllm/v1/serial_utils.py | 1 + vllm/v1/spec_decode/eagle.py | 12 +- vllm/v1/spec_decode/medusa.py | 1 + vllm/v1/spec_decode/metadata.py | 1 + vllm/v1/spec_decode/metrics.py | 1 + vllm/v1/spec_decode/ngram_proposer.py | 1 + vllm/v1/spec_decode/utils.py | 1 + vllm/v1/structured_output/__init__.py | 1 + vllm/v1/structured_output/backend_guidance.py | 1 + vllm/v1/structured_output/backend_types.py | 1 + vllm/v1/structured_output/backend_xgrammar.py | 1 + vllm/v1/structured_output/request.py | 1 + vllm/v1/structured_output/utils.py | 1 + vllm/v1/utils.py | 3 +- vllm/v1/worker/block_table.py | 4 +- vllm/v1/worker/cpu_model_runner.py | 86 +++ vllm/v1/worker/cpu_worker.py | 101 +++ vllm/v1/worker/gpu_input_batch.py | 19 +- vllm/v1/worker/gpu_model_runner.py | 242 ++++++-- vllm/v1/worker/gpu_worker.py | 58 +- vllm/v1/worker/lora_model_runner_mixin.py | 1 + vllm/v1/worker/tpu_model_runner.py | 72 ++- vllm/v1/worker/tpu_worker.py | 4 +- vllm/v1/worker/utils.py | 37 ++ vllm/v1/worker/worker_base.py | 1 + vllm/version.py | 1 + vllm/worker/cache_engine.py | 1 + vllm/worker/cpu_enc_dec_model_runner.py | 2 +- vllm/worker/cpu_model_runner.py | 2 +- vllm/worker/cpu_pooling_model_runner.py | 2 +- vllm/worker/cpu_worker.py | 1 + vllm/worker/enc_dec_model_runner.py | 2 +- vllm/worker/hpu_model_runner.py | 1 + vllm/worker/hpu_worker.py | 1 + vllm/worker/model_runner.py | 2 +- vllm/worker/model_runner_base.py | 1 + vllm/worker/multi_step_hpu_worker.py | 1 + vllm/worker/multi_step_model_runner.py | 1 + vllm/worker/multi_step_neuron_model_runner.py | 2 +- ...i_step_neuronx_distributed_model_runner.py | 2 +- vllm/worker/multi_step_tpu_worker.py | 1 + vllm/worker/multi_step_worker.py | 1 + vllm/worker/neuron_model_runner.py | 3 +- vllm/worker/neuron_worker.py | 1 + .../neuronx_distributed_model_runner.py | 1 + vllm/worker/pooling_model_runner.py | 2 +- vllm/worker/tpu_model_runner.py | 1 + vllm/worker/tpu_worker.py | 1 + vllm/worker/utils.py | 1 + vllm/worker/worker.py | 1 + vllm/worker/worker_base.py | 4 +- vllm/worker/xpu_model_runner.py | 2 +- vllm/worker/xpu_worker.py | 1 + 1513 files changed, 13047 insertions(+), 2268 deletions(-) create mode 100755 .buildkite/scripts/annotate-release.sh create mode 100755 .buildkite/scripts/tpu/cleanup_docker.sh create mode 100644 .buildkite/scripts/tpu/config_v6e_1.env create mode 100755 .buildkite/scripts/tpu/docker_run_bm.sh create mode 100755 .buildkite/scripts/tpu/run_bm.sh create mode 100644 csrc/sampler.cu create mode 100644 docs/contributing/ci-failures.md create mode 100644 examples/online_serving/multi_instance_data_parallel.py create mode 100644 examples/tool_chat_template_deepseekr1.jinja create mode 100644 tests/compile/test_config.py create mode 100644 tests/kernels/moe/__init__.py create mode 100644 tests/kernels/moe/deepep_utils.py create mode 100644 tests/kernels/moe/test_deepep_deepgemm_moe.py create mode 100644 tests/kernels/moe/test_deepep_moe.py create 
mode 100644 tests/kernels/moe/test_pplx_cutlass_moe.py
 create mode 100644 tests/kernels/test_apply_repetition_penalties.py
 create mode 100644 tests/kernels/test_flex_attention.py
 create mode 100644 tests/models/language/pooling/test_intfloat.py
 create mode 100644 tests/pplx_utils.py
 create mode 100644 vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
 create mode 100644 vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
 create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json
 create mode 100644 vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py
 create mode 100644 vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py
 create mode 100644 vllm/model_executor/models/nemotron_h.py
 create mode 100644 vllm/transformers_utils/configs/nemotron_h.py
 create mode 100644 vllm/v1/attention/backends/cpu_attn.py
 create mode 100644 vllm/v1/attention/backends/flex_attention.py
 create mode 100644 vllm/v1/attention/backends/mla/cutlass_mla.py
 create mode 100644 vllm/v1/core/kv_cache_coordinator.py
 create mode 100644 vllm/v1/worker/cpu_model_runner.py
 create mode 100644 vllm/v1/worker/cpu_worker.py

diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py
index e29881fcbac0..68aff793ae6a 100644
--- a/.buildkite/check-wheel-size.py
+++ b/.buildkite/check-wheel-size.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import os
 import sys
diff --git a/.buildkite/generate_index.py b/.buildkite/generate_index.py
index 270663c415c7..7045d8810493 100644
--- a/.buildkite/generate_index.py
+++ b/.buildkite/generate_index.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import argparse
 import os
diff --git a/.buildkite/lm-eval-harness/conftest.py b/.buildkite/lm-eval-harness/conftest.py
index 769d2efda4ad..c0d60dd5328f 100644
--- a/.buildkite/lm-eval-harness/conftest.py
+++ b/.buildkite/lm-eval-harness/conftest.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from pathlib import Path

 import pytest
diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py
index 409a6ca82008..930adfaf3e19 100644
--- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py
+++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
 LM eval harness on model to compare vs HF baseline computed offline.
 Configs are found in configs/$MODEL.yaml
diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py
index 7f2a2d8dc296..a4f1638c1adb 100644
--- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py
+++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import json
 import os
diff --git a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py
index 778a3a8d87f6..8532ff7ef798 100644
--- a/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py
+++ b/.buildkite/nightly-benchmarks/scripts/download-tokenizer.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import argparse

diff --git a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py
index 10a7a2f5a467..053fd52c35ae 100644
--- a/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py
+++ b/.buildkite/nightly-benchmarks/scripts/generate-nightly-markdown.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import argparse
 import json
diff --git a/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py b/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py
index e5f179a0f5b6..ddea1d2b1b1e 100644
--- a/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py
+++ b/.buildkite/nightly-benchmarks/scripts/get-lmdeploy-modelname.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from lmdeploy.serve.openai.api_client import APIClient

diff --git a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py
index 2a7b37991f31..fb3b9d5e34e0 100644
--- a/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py
+++ b/.buildkite/nightly-benchmarks/scripts/summary-nightly-results.py
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import datetime
 import json
diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml
index b3c27e2c99c2..16b5ad0297fe 100644
--- a/.buildkite/release-pipeline.yaml
+++ b/.buildkite/release-pipeline.yaml
@@ -1,5 +1,6 @@
 steps:
 - label: "Build wheel - CUDA 12.8"
+  id: build-wheel-cuda-12-8
   agents:
     queue: cpu_queue_postmerge
   commands:
@@ -11,6 +12,7 @@ steps:
     DOCKER_BUILDKIT: "1"

 - label: "Build wheel - CUDA 12.6"
+  id: build-wheel-cuda-12-6
   agents:
     queue: cpu_queue_postmerge
   commands:
@@ -28,6 +30,7 @@

 - label: "Build wheel - CUDA 11.8"
   # depends_on: block-build-cu118-wheel
+  id: build-wheel-cuda-11-8
   agents:
     queue: cpu_queue_postmerge
   commands:
@@ -44,6 +47,7 @@

 - label: "Build release image"
   depends_on: block-release-image-build
+  id: build-release-image
   agents:
     queue: cpu_queue_postmerge
   commands:
@@ -51,6 +55,18 @@ steps:
     - "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
     - "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"

+- label: "Annotate release workflow"
+  depends_on:
+    - build-release-image
+    - build-wheel-cuda-12-8
+    - build-wheel-cuda-12-6
+    - build-wheel-cuda-11-8
+  id: annotate-release-workflow
+  agents:
+    queue: cpu_queue_postmerge
+  commands:
+    - "bash .buildkite/scripts/annotate-release.sh"
+
 - label: "Build and publish TPU release image"
   depends_on: ~
   if: build.env("NIGHTLY") == "1"
@@ -70,9 +86,10 @@ steps:
     DOCKER_BUILDKIT: "1"

 - input: "Provide Release version here"
+  id: input-release-version
   fields:
     - text: "What is the release version?"
-      key: "release-version"
+      key: release-version

 - block: "Build CPU release image"
   key: block-cpu-release-image-build
diff --git a/.buildkite/scripts/annotate-release.sh b/.buildkite/scripts/annotate-release.sh
new file mode 100755
index 000000000000..94e0ac2398f3
--- /dev/null
+++ b/.buildkite/scripts/annotate-release.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+set -ex
+
+# Get release version and strip leading 'v' if present
+RELEASE_VERSION=$(buildkite-agent meta-data get release-version | sed 's/^v//')
+
+if [ -z "$RELEASE_VERSION" ]; then
+  echo "Error: RELEASE_VERSION is empty. 'release-version' metadata might not be set or is invalid."
+  exit 1
+fi
+
+buildkite-agent annotate --style 'info' --context 'release-workflow' << EOF
+To download the wheel:
+\`\`\`
+aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}/vllm-${RELEASE_VERSION}-cp38-abi3-manylinux1_x86_64.whl .
+aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu126/vllm-${RELEASE_VERSION}+cu126-cp38-abi3-manylinux1_x86_64.whl .
+aws s3 cp s3://vllm-wheels/${RELEASE_VERSION}+cu118/vllm-${RELEASE_VERSION}+cu118-cp38-abi3-manylinux1_x86_64.whl .
+\`\`\`
+
+To download and upload the image:
+
+\`\`\`
+docker pull public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT}
+docker tag public.ecr.aws/q9t5s3a7/vllm-release-repo:${BUILDKITE_COMMIT} vllm/vllm-openai
+docker tag vllm/vllm-openai vllm/vllm-openai:latest
+docker tag vllm/vllm-openai vllm/vllm-openai:v${RELEASE_VERSION}
+docker push vllm/vllm-openai:latest
+docker push vllm/vllm-openai:v${RELEASE_VERSION}
+\`\`\`
+EOF
\ No newline at end of file
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
index 077bd9914907..36bcb015d308 100755
--- a/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test-ppc64le.sh
@@ -7,6 +7,7 @@ set -ex
 # Setup cleanup
 remove_docker_container() {
   if [[ -n "$container_id" ]]; then
+    podman stop --all -t0
     podman rm -f "$container_id" || true
   fi
   podman system prune -f
@@ -37,7 +38,7 @@ function cpu_tests() {
     pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-facebook/opt-125m]
     pytest -v -s tests/models/language/generation/test_common.py::test_models[False-5-32-google/gemma-1.1-2b-it]
     pytest -v -s tests/models/language/pooling/test_classification.py::test_models[float-jason9693/Qwen2.5-1.5B-apeach]
-    pytest -v -s tests/models/language/pooling/test_embedding.py::test_models[half-BAAI/bge-base-en-v1.5]"
+    pytest -v -s tests/models/language/pooling/test_embedding.py -m cpu_model"
 }

 # All of CPU tests are expected to be finished less than 40 mins.
diff --git a/.buildkite/scripts/hardware_ci/run-cpu-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
index 0a11935607e2..61aa7df13b4d 100644
--- a/.buildkite/scripts/hardware_ci/run-cpu-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-cpu-test.sh
@@ -6,6 +6,7 @@ set -ex

 # allow to bind to different cores
 CORE_RANGE=${CORE_RANGE:-48-95}
+OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95}
 NUMA_NODE=${NUMA_NODE:-1}

 export CMAKE_BUILD_PARALLEL_LEVEL=32
@@ -23,10 +24,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
 numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .

 # Run the image, setting --shm-size=4g for tensor parallel.
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
-  --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
-docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus="$CORE_RANGE" \
-  --cpuset-mems="$NUMA_NODE" --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
+docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2

 function cpu_tests() {
   set -e
@@ -56,7 +55,7 @@ function cpu_tests() {
   # Run AWQ test
   docker exec cpu-test-"$NUMA_NODE" bash -c "
     set -e
-    pytest -s -v \
+    VLLM_USE_V1=0 pytest -s -v \
     tests/quantization/test_ipex_quant.py"

   # Run chunked-prefill and prefix-cache test
@@ -68,8 +67,6 @@ function cpu_tests() {
   # online serving
   docker exec cpu-test-"$NUMA_NODE" bash -c "
     set -e
-    export VLLM_CPU_KVCACHE_SPACE=10
-    export VLLM_CPU_OMP_THREADS_BIND=$1
     python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
     timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
     python3 benchmarks/benchmark_serving.py \
@@ -89,4 +86,4 @@ function cpu_tests() {

 # All of CPU tests are expected to be finished less than 40 mins.
 export -f cpu_tests
-timeout 40m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
+timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
diff --git a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
index 3212b660ec35..a2a5c2a02cbb 100755
--- a/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
+++ b/.buildkite/scripts/hardware_ci/run-tpu-v1-test.sh
@@ -150,7 +150,7 @@ run_and_track_test 9 "test_multimodal.py" \
 run_and_track_test 10 "test_pallas.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py"
 run_and_track_test 11 "test_struct_output_generate.py" \
-  "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py"
+  "python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
 run_and_track_test 12 "test_moe_pallas.py" \
   "python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
 run_and_track_test 13 "test_lora.py" \
diff --git a/.buildkite/scripts/tpu/cleanup_docker.sh b/.buildkite/scripts/tpu/cleanup_docker.sh
new file mode 100755
index 000000000000..209d9c4341cd
--- /dev/null
+++ b/.buildkite/scripts/tpu/cleanup_docker.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+set -euo pipefail
+
+docker_root=$(docker info -f '{{.DockerRootDir}}')
+if [ -z "$docker_root" ]; then
+  echo "Failed to determine Docker root directory."
+  exit 1
+fi
+echo "Docker root directory: $docker_root"
+# Check disk usage of the filesystem where Docker's root directory is located
+disk_usage=$(df "$docker_root" | tail -1 | awk '{print $5}' | sed 's/%//')
+# Define the threshold
+threshold=70
+if [ "$disk_usage" -gt "$threshold" ]; then
+  echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..."
+  # Remove dangling images (those that are not tagged and not used by any container)
+  docker image prune -f
+  # Remove unused volumes / force the system prune for old images as well.
+  docker volume prune -f && docker system prune --force --filter "until=72h" --all
+  echo "Docker images and volumes cleanup completed."
+else
+  echo "Disk usage is below $threshold%. No cleanup needed."
+fi
diff --git a/.buildkite/scripts/tpu/config_v6e_1.env b/.buildkite/scripts/tpu/config_v6e_1.env
new file mode 100644
index 000000000000..441758647347
--- /dev/null
+++ b/.buildkite/scripts/tpu/config_v6e_1.env
@@ -0,0 +1,14 @@
+# Environment config
+TEST_NAME=llama8b
+CONTAINER_NAME=vllm-tpu
+
+# vllm config
+MODEL=meta-llama/Llama-3.1-8B-Instruct
+MAX_NUM_SEQS=512
+MAX_NUM_BATCHED_TOKENS=512
+TENSOR_PARALLEL_SIZE=1
+MAX_MODEL_LEN=2048
+DOWNLOAD_DIR=/mnt/disks/persist
+EXPECTED_THROUGHPUT=8.0
+INPUT_LEN=1800
+OUTPUT_LEN=128
diff --git a/.buildkite/scripts/tpu/docker_run_bm.sh b/.buildkite/scripts/tpu/docker_run_bm.sh
new file mode 100755
index 000000000000..6705da03e3d7
--- /dev/null
+++ b/.buildkite/scripts/tpu/docker_run_bm.sh
@@ -0,0 +1,102 @@
+#!/bin/bash
+
+if [ ! -f "$1" ]; then
+  echo "Error: The env file '$1' does not exist."
+  exit 1  # Exit the script with a non-zero status to indicate an error
+fi
+
+ENV_FILE=$1
+
+# For testing on local vm, use `set -a` to export all variables
+source /etc/environment
+source $ENV_FILE
+
+remove_docker_container() {
+  docker rm -f tpu-test || true;
+  docker rm -f vllm-tpu || true;
+  docker rm -f $CONTAINER_NAME || true;
+}
+
+trap remove_docker_container EXIT
+
+# Remove the container that might not be cleaned up in the previous run.
+remove_docker_container + +# Build docker image. +# TODO: build the image outside the script and share the image with other +# tpu test if building time is too long. +DOCKER_BUILDKIT=1 docker build \ + --build-arg max_jobs=16 \ + --build-arg USE_SCCACHE=1 \ + --build-arg GIT_REPO_CHECK=0 \ + --tag vllm/vllm-tpu-bm \ + --progress plain -f docker/Dockerfile.tpu . + +LOG_ROOT=$(mktemp -d) +# If mktemp fails, set -e will cause the script to exit. +echo "Results will be stored in: $LOG_ROOT" + +if [ -z "$HF_TOKEN" ]; then + echo "Error: HF_TOKEN is not set or is empty." + exit 1 +fi + +# Make sure mounted disk or dir exists +if [ ! -d "$DOWNLOAD_DIR" ]; then + echo "Error: Folder $DOWNLOAD_DIR does not exist. This is usually a mounted drive. If no mounted drive, just create a folder." + exit 1 +fi + +echo "Run model $MODEL" +echo + +echo "starting docker...$CONTAINER_NAME" +echo +docker run \ + -v $DOWNLOAD_DIR:$DOWNLOAD_DIR \ + --env-file $ENV_FILE \ + -e HF_TOKEN="$HF_TOKEN" \ + -e TARGET_COMMIT=$BUILDKITE_COMMIT \ + -e MODEL=$MODEL \ + -e WORKSPACE=/workspace \ + --name $CONTAINER_NAME \ + -d \ + --privileged \ + --network host \ + -v /dev/shm:/dev/shm \ + vllm/vllm-tpu-bm tail -f /dev/null + +echo "run script..." +echo +docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/hardware_ci/run_bm.sh" + +echo "copy result back..." +VLLM_LOG="$LOG_ROOT/$TEST_NAME"_vllm_log.txt +BM_LOG="$LOG_ROOT/$TEST_NAME"_bm_log.txt +docker cp "$CONTAINER_NAME:/workspace/vllm_log.txt" "$VLLM_LOG" +docker cp "$CONTAINER_NAME:/workspace/bm_log.txt" "$BM_LOG" + +throughput=$(grep "Request throughput (req/s):" "$BM_LOG" | sed 's/[^0-9.]//g') +echo "throughput for $TEST_NAME at $BUILDKITE_COMMIT: $throughput" + +if [ "$BUILDKITE" = "true" ]; then + echo "Running inside Buildkite" + buildkite-agent artifact upload "$VLLM_LOG" + buildkite-agent artifact upload "$BM_LOG" +else + echo "Not running inside Buildkite" +fi + +# +# compare the throughput with EXPECTED_THROUGHPUT +# and assert meeting the expectation +# +if [[ -z "$throughput" || ! "$throughput" =~ ^[0-9]+([.][0-9]+)?$ ]]; then + echo "Failed to get the throughput" + exit 1 +fi + +if (( $(echo "$throughput < $EXPECTED_THROUGHPUT" | bc -l) )); then + echo "Error: throughput($throughput) is less than expected($EXPECTED_THROUGHPUT)" + exit 1 +fi diff --git a/.buildkite/scripts/tpu/run_bm.sh b/.buildkite/scripts/tpu/run_bm.sh new file mode 100755 index 000000000000..877669cd956a --- /dev/null +++ b/.buildkite/scripts/tpu/run_bm.sh @@ -0,0 +1,94 @@ +#!/bin/bash + +set -euo pipefail + +VLLM_LOG="$WORKSPACE/vllm_log.txt" +BM_LOG="$WORKSPACE/bm_log.txt" + +if [ -n "$TARGET_COMMIT" ]; then + head_hash=$(git rev-parse HEAD) + if [ "$TARGET_COMMIT" != "$head_hash" ]; then + echo "Error: target commit $TARGET_COMMIT does not match HEAD: $head_hash" + exit 1 + fi +fi + +echo "model: $MODEL" +echo + +# +# create a log folder +# +mkdir "$WORKSPACE/log" + +# TODO: Move to image building. +pip install pandas +pip install datasets + +# +# create sonnet_4x +# +echo "Create sonnet_4x.txt" +echo "" > benchmarks/sonnet_4x.txt +for _ in {1..4} + do + cat benchmarks/sonnet.txt >> benchmarks/sonnet_4x.txt +done + +# +# start vllm service in backend +# +echo "launching vllm..."
+echo "logging to $VLLM_LOG" +echo + +VLLM_USE_V1=1 vllm serve $MODEL \ + --seed 42 \ + --disable-log-requests \ + --max-num-seqs $MAX_NUM_SEQS \ + --max-num-batched-tokens $MAX_NUM_BATCHED_TOKENS \ + --tensor-parallel-size $TENSOR_PARALLEL_SIZE \ + --no-enable-prefix-caching \ + --download_dir $DOWNLOAD_DIR \ + --max-model-len $MAX_MODEL_LEN > "$VLLM_LOG" 2>&1 & + + +echo "wait for 20 minutes.." +echo +# sleep 1200 +# wait for up to 20 minutes... +for i in {1..120}; do + # TODO: detect other types of errors. + if grep -Fq "raise RuntimeError" "$VLLM_LOG"; then + echo "Detected RuntimeError, exiting." + exit 1 + elif grep -Fq "Application startup complete" "$VLLM_LOG"; then + echo "Application started" + break + else + echo "wait for 10 seconds..." + sleep 10 + fi +done + +# +# run test +# +echo "run benchmark test..." +echo "logging to $BM_LOG" +echo +python benchmarks/benchmark_serving.py \ + --backend vllm \ + --model $MODEL \ + --dataset-name sonnet \ + --dataset-path benchmarks/sonnet_4x.txt \ + --sonnet-input-len $INPUT_LEN \ + --sonnet-output-len $OUTPUT_LEN \ + --ignore-eos > "$BM_LOG" + +echo "completed..." +echo + +throughput=$(grep "Request throughput (req/s):" "$BM_LOG" | sed 's/[^0-9.]//g') +echo "throughput: $throughput" +echo diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 5fb8ceaace05..b739851cb905 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -145,6 +145,7 @@ steps: - examples/offline_inference/rlhf_colocate.py - tests/examples/offline_inference/data_parallel.py - tests/v1/test_async_llm_dp.py + - tests/v1/engine/test_engine_core_client.py commands: # test with tp=2 and external_dp=2 - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py @@ -154,6 +155,7 @@ steps: # test with internal dp - python3 ../examples/offline_inference/data_parallel.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py + - pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp - pytest -v -s distributed/test_utils.py - pytest -v -s compile/test_basic_correctness.py - pytest -v -s distributed/test_pynccl.py @@ -318,6 +320,7 @@ steps: # these tests need to be separated, cannot combine - pytest -v -s compile/piecewise/test_simple.py - pytest -v -s compile/piecewise/test_toy_llama.py + - pytest -v -s compile/piecewise/test_full_cudagraph.py - label: PyTorch Fullgraph Test # 18min mirror_hardwares: [amdexperimental, amdproduction] @@ -421,6 +424,9 @@ steps: - vllm/model_executor/layers/quantization - tests/quantization commands: + # temporary install here since we need nightly, will move to requirements/test.in + # after torchao 0.12 release + - pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126 - VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization - label: LM Eval Small Models # 53min diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 4452ce22d504..e98ccd035ee9 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -10,15 +10,17 @@ /vllm/worker/worker.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/sampler.py @zhuohan123 @youkaichao @alexm-redhat @comaniac @njhill /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth -/vllm/model_executor/guided_decoding @mgoin @russellb +/vllm/model_executor/guided_decoding @mgoin @russellb @aarnphm /vllm/multimodal @DarkLight1337 @ywang96 /vllm/vllm_flash_attn @LucasWilkinson /vllm/lora @jeejeelee +/vllm/reasoning @aarnphm
+/vllm/entrypoints @aarnphm CMakeLists.txt @tlrmchlsmth # vLLM V1 /vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat -/vllm/v1/structured_output @mgoin @russellb +/vllm/v1/structured_output @mgoin @russellb @aarnphm # Test ownership /.buildkite/lm-eval-harness @mgoin @simon-mo @@ -27,8 +29,8 @@ CMakeLists.txt @tlrmchlsmth /tests/distributed/test_multi_node_assignment.py @youkaichao /tests/distributed/test_pipeline_parallel.py @youkaichao /tests/distributed/test_same_node.py @youkaichao -/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo -/tests/entrypoints/llm/test_guided_generate.py @mgoin @russellb +/tests/entrypoints @DarkLight1337 @robertgshaw2-redhat @simon-mo @aarnphm +/tests/entrypoints/llm/test_guided_generate.py @mgoin @russellb @aarnphm /tests/kernels @tlrmchlsmth @WoosukKwon /tests/model_executor/test_guided_processors.py @mgoin @russellb /tests/models @DarkLight1337 @ywang96 @@ -38,11 +40,11 @@ CMakeLists.txt @tlrmchlsmth /tests/quantization @mgoin @robertgshaw2-redhat /tests/spec_decode @njhill @LiuXiaoxuanPKU /tests/test_inputs.py @DarkLight1337 @ywang96 -/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb -/tests/v1/structured_output @mgoin @russellb +/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm +/tests/v1/structured_output @mgoin @russellb @aarnphm /tests/weight_loading @mgoin @youkaichao /tests/lora @jeejeelee # Docs /docs @hmellor -mkdocs.yaml @hmellor \ No newline at end of file +mkdocs.yaml @hmellor diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 65be771b94fb..c1d1e07bf628 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,6 +1,15 @@ -FILL IN THE PR DESCRIPTION HERE +## Essential Elements of an Effective PR Description Checklist +- [ ] The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)". +- [ ] The test plan, such as providing test command. +- [ ] The test results, such as pasting the results comparison before and after, or e2e results -FIX #xxxx (*link existing issues this PR will resolve*) +PLEASE FILL IN THE PR DESCRIPTION HERE ENSURING ALL CHECKLIST ITEMS ABOVE HAVE BEEN CONSIDERED. + +## Purpose + +## Test Plan + +## Test Result **BEFORE SUBMITTING, PLEASE READ ** (anything written below this line will be removed by GitHub Actions) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 628782228e97..a105b0e14c4a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,6 +11,8 @@ repos: hooks: - id: yapf args: [--in-place, --verbose] + # Keep the same list from yapfignore here to avoid yapf failing without any inputs + exclude: '(.buildkite|benchmarks|build|examples)/.*' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.11.7 hooks: diff --git a/CMakeLists.txt b/CMakeLists.txt index 87aa23c080f5..afaed7cd1821 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -242,6 +242,7 @@ set(VLLM_EXT_SRC "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" "csrc/layernorm_quant_kernels.cu" + "csrc/sampler.cu" "csrc/cuda_view.cu" "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" @@ -542,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # CUTLASS MoE kernels # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works - # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible - # to compile MoE kernels that use its output. 
+ # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled + # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" diff --git a/README.md b/README.md index 67f6b957ec55..ec16d758327d 100644 --- a/README.md +++ b/README.md @@ -58,8 +58,8 @@ vLLM is fast with: - Efficient management of attention key and value memory with [**PagedAttention**](https://blog.vllm.ai/2023/06/20/vllm.html) - Continuous batching of incoming requests - Fast model execution with CUDA/HIP graph -- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516),INT4, INT8, and FP8. -- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer. +- Quantizations: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [AutoRound](https://arxiv.org/abs/2309.05516), INT4, INT8, and FP8 +- Optimized CUDA kernels, including integration with FlashAttention and FlashInfer - Speculative decoding - Chunked prefill @@ -72,14 +72,14 @@ vLLM is flexible and easy to use with: - Tensor parallelism and pipeline parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron. +- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron - Prefix caching support - Multi-LoRA support vLLM seamlessly supports most popular open-source models on HuggingFace, including: - Transformer-like LLMs (e.g., Llama) - Mixture-of-Expert LLMs (e.g., Mixtral, Deepseek-V2 and V3) -- Embedding Models (e.g. E5-Mistral) +- Embedding Models (e.g., E5-Mistral) - Multi-modal LLMs (e.g., LLaVA) Find the full list of supported models [here](https://docs.vllm.ai/en/latest/models/supported_models.html). @@ -162,4 +162,4 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs ## Media Kit -- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit). +- If you wish to use vLLM's logo, please refer to [our media kit repo](https://github.com/vllm-project/media-kit) diff --git a/benchmarks/auto_tune.sh b/benchmarks/auto_tune.sh index ea63c6f71a6c..1b01bbd61b62 100644 --- a/benchmarks/auto_tune.sh +++ b/benchmarks/auto_tune.sh @@ -10,11 +10,15 @@ # 3. Set variables (ALL REQUIRED) # BASE: your directory for vllm repo # MODEL: the model served by vllm +# TP: ways of tensor parallelism # DOWNLOAD_DIR: directory to download and load model weights. # INPUT_LEN: request input len # OUTPUT_LEN: request output len # MIN_CACHE_HIT_PCT: prefix cache rate # MAX_LATENCY_ALLOWED_MS: (e2e) latency requirement. If there's no latency requirement, set it to a large number like 1000000000 +# NUM_SEQS_LIST: a list of `max-num-seqs` you want to loop with. +# NUM_BATCHED_TOKENS_LIST: a list of `max-num-batched-tokens` you want to loop with. +# Note that the default NUM_SEQS_LIST and NUM_BATCHED_TOKENS_LIST are set for medium size input/output len, for extra short context (such as 20:20), you might need to include larger numbers in NUM_SEQS_LIST. # 4. Run the script, it might take a long time, you can use tmux to avoid the script stop if disconnection happens. # 5. 
The final result will be saved in RESULT file. @@ -30,31 +34,27 @@ TAG=$(date +"%Y_%m_%d_%H_%M") BASE="" MODEL="meta-llama/Llama-3.1-8B-Instruct" +TP=1 DOWNLOAD_DIR="" INPUT_LEN=4000 OUTPUT_LEN=16 -MIN_CACHE_HIT_PCT_PCT=0 +MIN_CACHE_HIT_PCT=0 MAX_LATENCY_ALLOWED_MS=100000000000 +NUM_SEQS_LIST="128 256" +NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096" LOG_FOLDER="$BASE/auto-benchmark/$TAG" RESULT="$LOG_FOLDER/result.txt" -echo "result file$ $RESULT" +echo "result file: $RESULT" echo "model: $MODEL" -echo rm -rf $LOG_FOLDER mkdir -p $LOG_FOLDER cd "$BASE/vllm" -# create sonnet-4x.txt so that we can sample 2048 tokens for input -echo "" > benchmarks/sonnet_4x.txt -for _ in {1..4} -do -cat benchmarks/sonnet.txt >> benchmarks/sonnet_4x.txt -done -pip install datasets +pip install -q datasets current_hash=$(git rev-parse HEAD) echo "hash:$current_hash" >> "$RESULT" @@ -64,53 +64,69 @@ best_throughput=0 best_max_num_seqs=0 best_num_batched_tokens=0 best_goodput=0 -run_benchmark() { - local max_num_seqs=$1 - local max_num_batched_tokens=$2 - echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" - local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" - echo "vllm_log: $vllm_log" - echo - rm -f $vllm_log - # start the server +start_server() { + local gpu_memory_utilization=$1 + local max_num_seqs=$2 + local max_num_batched_tokens=$3 + local vllm_log=$4 + + pkill -f vllm + VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \ --disable-log-requests \ --port 8004 \ - --gpu-memory-utilization 0.98 \ + --gpu-memory-utilization $gpu_memory_utilization \ --max-num-seqs $max_num_seqs \ --max-num-batched-tokens $max_num_batched_tokens \ - --tensor-parallel-size 1 \ + --tensor-parallel-size $TP \ --enable-prefix-caching \ --load-format dummy \ - --download-dir $DOWNLOAD_DIR \ + --download-dir "$DOWNLOAD_DIR" \ --max-model-len $(( INPUT_LEN+OUTPUT_LEN )) > "$vllm_log" 2>&1 & - echo "wait for 10 minutes.." - echo + # wait for 10 minutes... server_started=0 - for i in {1..60}; do - if grep -Fq "Application startup complete" "$vllm_log"; then - echo "Application started" + for i in {1..60}; do + RESPONSE=$(curl -s -X GET "http://0.0.0.0:8004/health" -w "%{http_code}" -o /dev/stdout) + STATUS_CODE=$(echo "$RESPONSE" | tail -n 1) + if [[ "$STATUS_CODE" -eq 200 ]]; then server_started=1 break else - # echo "wait for 10 seconds..." sleep 10 fi done - if (( ! server_started )); then - echo "server did not start within 10 minutes, terminate the benchmarking. Please check server log at $vllm_log" - echo "pkill -f vllm" - echo - pkill vllm - sleep 10 + echo "server did not start within 10 minutes. Please check server log at $vllm_log". return 1 + else + return 0 fi +} + +run_benchmark() { + local max_num_seqs=$1 + local max_num_batched_tokens=$2 + local gpu_memory_utilization=$3 + echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" + local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt" + echo "vllm_log: $vllm_log" + echo + rm -f $vllm_log + pkill -f vllm + + echo "starting server..." + start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log + result=$? + if [[ "$result" -eq 1 ]]; then + echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens" + else + echo "server started." + fi + echo echo "run benchmark test..." 
- echo meet_latency_requirement=0 # get a basic qps by using request-rate inf bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt" @@ -118,29 +134,29 @@ run_benchmark() { python benchmarks/benchmark_serving.py \ --backend vllm \ --model $MODEL \ - --dataset-name sonnet \ - --dataset-path benchmarks/sonnet_4x.txt \ - --sonnet-input-len $INPUT_LEN \ - --sonnet-output-len $OUTPUT_LEN \ + --dataset-name random \ + --random-input-len $INPUT_LEN \ + --random-output-len $OUTPUT_LEN \ --ignore-eos \ --disable-tqdm \ --request-rate inf \ --percentile-metrics ttft,tpot,itl,e2el \ --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ - --num-prompts 100 \ - --sonnet-prefix-len $prefix_len \ - --port 8004 > "$bm_log" - through_put=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + --num-prompts 1000 \ + --random-prefix-len $prefix_len \ + --port 8004 &> "$bm_log" + throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then meet_latency_requirement=1 + request_rate=inf fi if (( ! meet_latency_requirement )); then - # start from request-rate as int(through_put) + 1 - request_rate=$((${through_put%.*} + 1)) + # start from request-rate as int(throughput) + 1 + request_rate=$((${throughput%.*} + 1)) while ((request_rate > 0)); do # clear prefix cache curl -X POST http://0.0.0.0:8004/reset_prefix_cache @@ -149,19 +165,18 @@ run_benchmark() { python benchmarks/benchmark_serving.py \ --backend vllm \ --model $MODEL \ - --dataset-name sonnet \ - --dataset-path benchmarks/sonnet_4x.txt \ - --sonnet-input-len $INPUT_LEN \ - --sonnet-output-len $OUTPUT_LEN \ - --ignore_eos \ + --dataset-name random \ + --random-input-len $INPUT_LEN \ + --random-output-len $OUTPUT_LEN \ + --ignore-eos \ --disable-tqdm \ --request-rate $request_rate \ --percentile-metrics ttft,tpot,itl,e2el \ --goodput e2el:$MAX_LATENCY_ALLOWED_MS \ --num-prompts 100 \ - --sonnet-prefix-len $prefix_len \ - --port 8004 > "$bm_log" - through_put=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') + --random-prefix-len $prefix_len \ + --port 8004 &> "$bm_log" + throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}') goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g') if (( $(echo "$e2el <= $MAX_LATENCY_ALLOWED_MS" | bc -l) )); then @@ -173,10 +188,10 @@ run_benchmark() { fi # write the results and update the best result. 
if ((meet_latency_requirement)); then - echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, through put: $through_put, goodput: $goodput" - echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, through put: $through_put, goodput: $goodput" >> "$RESULT" - if (( $(echo "$through_put > $best_throughput" | bc -l) )); then - best_throughput=$through_put + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, throughput: $throughput, goodput: $goodput" + echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens, request_rate: $request_rate, e2el: $e2el, throughput: $throughput, goodput: $goodput" >> "$RESULT" + if (( $(echo "$throughput > $best_throughput" | bc -l) )); then + best_throughput=$throughput best_max_num_seqs=$max_num_seqs best_num_batched_tokens=$max_num_batched_tokens best_goodput=$goodput @@ -188,22 +203,39 @@ run_benchmark() { echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" - echo "pkill -f vllm" - echo pkill vllm sleep 10 - rm -f $vllm_log printf '=%.0s' $(seq 1 20) return 0 } +read -r -a num_seqs_list <<< "$NUM_SEQS_LIST" +read -r -a num_batched_tokens_list <<< "$NUM_BATCHED_TOKENS_LIST" + +# first find out the max gpu-memory-utilization without HBM OOM. +gpu_memory_utilization=0.98 +find_gpu_memory_utilization=0 +while (( $(echo "$gpu_memory_utilization >= 0.9" | bc -l) )); do + start_server $gpu_memory_utilization "${num_seqs_list[-1]}" "${num_batched_tokens_list[-1]}" "$LOG_FOLDER/vllm_log_gpu_memory_utilization_$gpu_memory_utilization.log" + result=$? + if [[ "$result" -eq 0 ]]; then + find_gpu_memory_utilization=1 + break + else + gpu_memory_utilization=$(echo "$gpu_memory_utilization - 0.01" | bc) + fi +done + +if [[ "$find_gpu_memory_utilization" -eq 1 ]]; then + echo "Using gpu_memory_utilization=$gpu_memory_utilization to serve model." +else + echo "Cannot find a proper gpu_memory_utilization over 0.9 to serve the model, please check logs in $LOG_FOLDER." + exit 1 +fi -num_seqs_list="128 256" -num_batched_tokens_list="512 1024 2048 4096" -for num_seqs in $num_seqs_list; do - for num_batched_tokens in $num_batched_tokens_list; do - run_benchmark $num_seqs $num_batched_tokens - exit 0 +for num_seqs in "${num_seqs_list[@]}"; do + for num_batched_tokens in "${num_batched_tokens_list[@]}"; do + run_benchmark $num_seqs $num_batched_tokens $gpu_memory_utilization done done echo "finish permutations" diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 85e6eda7f36f..ddb38e304cd6 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import io import json diff --git a/benchmarks/benchmark_dataset.py b/benchmarks/benchmark_dataset.py index d86bf045ea47..5d2a26cd443c 100644 --- a/benchmarks/benchmark_dataset.py +++ b/benchmarks/benchmark_dataset.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ This module defines a framework for sampling benchmark requests from various datasets. 
Each dataset subclass of BenchmarkDataset must implement sample @@ -864,7 +865,15 @@ def sample( for item in self.data: if len(sampled_requests) >= num_requests: break - prompt = f"{item['instruction']}:\n{item['input']}" + prompt = f"{item['input']}\n\n{item['instruction']} Just output \ + the code, do not include any explanation." + + # apply template + prompt = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True, + tokenize=False, + ) prompt_len = len(tokenizer(prompt).input_ids) sampled_requests.append( SampleRequest( diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index de62bf5c63c7..c06857247eee 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Benchmark the latency of processing a single batch of requests.""" import argparse diff --git a/benchmarks/benchmark_long_document_qa_throughput.py b/benchmarks/benchmark_long_document_qa_throughput.py index 109624c87789..00869fa94e71 100644 --- a/benchmarks/benchmark_long_document_qa_throughput.py +++ b/benchmarks/benchmark_long_document_qa_throughput.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Offline benchmark to test the long document QA throughput. diff --git a/benchmarks/benchmark_prefix_caching.py b/benchmarks/benchmark_prefix_caching.py index ffaa8035797c..3e4704f0b820 100644 --- a/benchmarks/benchmark_prefix_caching.py +++ b/benchmarks/benchmark_prefix_caching.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Benchmark the efficiency of prefix caching. diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index a05dd24dece8..5496703f23cc 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Benchmark offline prioritization.""" import argparse diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 6bd9f1b49c2e..81428fb7dae1 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project r"""Benchmark online serving throughput. On the server side, run one of the following commands: diff --git a/benchmarks/benchmark_serving_structured_output.py b/benchmarks/benchmark_serving_structured_output.py index 6a50f47d3951..c1501ad52c25 100644 --- a/benchmarks/benchmark_serving_structured_output.py +++ b/benchmarks/benchmark_serving_structured_output.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project r"""Benchmark online serving throughput with structured outputs. 
On the server side, run one of the following commands: @@ -11,7 +12,6 @@ --model \ --dataset json \ --structured-output-ratio 1.0 \ - --structured-output-backend auto \ --request-rate 10 \ --num-prompts 1000 diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 7a13babda9d1..d19753d40e49 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Benchmark offline inference throughput.""" import argparse diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py index b0c4fca92c3d..283f938df50a 100644 --- a/benchmarks/benchmark_utils.py +++ b/benchmarks/benchmark_utils.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import json @@ -65,4 +66,9 @@ def iterencode(self, o: Any, *args, **kwargs) -> Any: def write_to_json(filename: str, records: list) -> None: with open(filename, "w") as f: - json.dump(records, f, cls=InfEncoder) + json.dump( + records, + f, + cls=InfEncoder, + default=lambda o: f"<{type(o).__name__} object is not JSON serializable>", + ) diff --git a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py index da258f98e085..9ec270bbd2e9 100644 --- a/benchmarks/cutlass_benchmarks/sparse_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/sparse_benchmarks.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import copy diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py index 7e9f5a7fc0f4..b4f3c6bf94ed 100644 --- a/benchmarks/cutlass_benchmarks/utils.py +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Cutlass bench utils from collections.abc import Iterable diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 08e93837f7dd..cec422e8d597 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import copy diff --git a/benchmarks/cutlass_benchmarks/weight_shapes.py b/benchmarks/cutlass_benchmarks/weight_shapes.py index d31b623a1ee6..25b96ef56620 100644 --- a/benchmarks/cutlass_benchmarks/weight_shapes.py +++ b/benchmarks/cutlass_benchmarks/weight_shapes.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Weight Shapes are in the format # ([K, N], TP_SPLIT_DIM) diff --git a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py index fce156e1c96c..f62d8102e2d9 100644 --- a/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py +++ b/benchmarks/disagg_benchmarks/disagg_prefill_proxy_server.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os diff --git a/benchmarks/disagg_benchmarks/round_robin_proxy.py b/benchmarks/disagg_benchmarks/round_robin_proxy.py index fd19b40bf252..b1df2f255822 100644 --- 
a/benchmarks/disagg_benchmarks/round_robin_proxy.py +++ b/benchmarks/disagg_benchmarks/round_robin_proxy.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import itertools diff --git a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py index 484d0cb3cba7..74fa56d076cf 100644 --- a/benchmarks/disagg_benchmarks/visualize_benchmark_results.py +++ b/benchmarks/disagg_benchmarks/visualize_benchmark_results.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import json diff --git a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py index 37a9173a1a93..901524214469 100644 --- a/benchmarks/fused_kernels/layernorm_rms_benchmarks.py +++ b/benchmarks/fused_kernels/layernorm_rms_benchmarks.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import pickle as pkl import time diff --git a/benchmarks/kernels/bench_fp8_gemm.py b/benchmarks/kernels/bench_fp8_gemm.py index 36d03e40ef9a..b964ed242edf 100644 --- a/benchmarks/kernels/bench_fp8_gemm.py +++ b/benchmarks/kernels/bench_fp8_gemm.py @@ -1,14 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import copy import itertools import torch -import triton from weight_shapes import WEIGHT_SHAPES from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm from vllm._custom_ops import scaled_fp8_quant as vllm_scaled_fp8_quant +from vllm.triton_utils import triton @triton.testing.perf_report( diff --git a/benchmarks/kernels/benchmark_aqlm.py b/benchmarks/kernels/benchmark_aqlm.py index e9934aa479dd..42de062b08e4 100644 --- a/benchmarks/kernels/benchmark_aqlm.py +++ b/benchmarks/kernels/benchmark_aqlm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import os import sys diff --git a/benchmarks/kernels/benchmark_bitblas.py b/benchmarks/kernels/benchmark_bitblas.py index d40ab70ec539..97ee06034137 100644 --- a/benchmarks/kernels/benchmark_bitblas.py +++ b/benchmarks/kernels/benchmark_bitblas.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. diff --git a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py index d39d8a6e3aba..35c20ee41b9a 100644 --- a/benchmarks/kernels/benchmark_cutlass_fp4_moe.py +++ b/benchmarks/kernels/benchmark_cutlass_fp4_moe.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ Benchmark the performance of the cutlass_moe_fp4 kernel vs the triton_moe kernel. 
The cutlass_moe_fp4 kernel takes in fp4 quantized weights and 16-bit @@ -90,7 +91,7 @@ def bench_run( score = torch.randn((m, num_experts), device=device, dtype=dtype) - topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False) + topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False) quant_blocksize = 16 w1_blockscale = torch.empty( diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 2197bceabe6c..acabe6c1ddb0 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch import torch.utils.benchmark as benchmark @@ -6,8 +7,8 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.fused_moe import ( - cutlass_moe_fp8, fused_experts, fused_topk, ) @@ -69,18 +70,9 @@ def bench_run( w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) - ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) - for expert in range(num_experts): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) - w1_q_notransp = w1_q.clone() - w2_q_notransp = w2_q.clone() - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) score = torch.randn((m, num_experts), device="cuda", dtype=dtype) @@ -121,10 +113,6 @@ def run_cutlass_moe( w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, num_repeats: int, ): for _ in range(num_repeats): @@ -132,14 +120,10 @@ def run_cutlass_moe( a, w1, w2, - w1_scale, - w2_scale, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, + w1_scale, + w2_scale, a1_scale=a_scale, ) @@ -152,10 +136,6 @@ def run_cutlass_from_graph( w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, ): with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) @@ -164,14 +144,10 @@ def run_cutlass_from_graph( a, w1_q, w2_q, - w1_scale, - w2_scale, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, + w1_scale, + w2_scale, a1_scale=a_scale, ) @@ -217,10 +193,6 @@ def replay_graph(graph, num_repeats): w2_scale, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, ) torch.cuda.synchronize() @@ -229,8 +201,8 @@ def replay_graph(graph, num_repeats): with torch.cuda.graph(triton_graph, stream=triton_stream): run_triton_from_graph( a, - w1_q_notransp, - w2_q_notransp, + w1_q, + w2_q, topk_weights, topk_ids, w1_scale, @@ -249,18 +221,12 @@ def replay_graph(graph, num_repeats): "w2": w2, "score": score, "topk": 
topk, - "w1_q_notransp": w1_q_notransp, - "w2_q_notransp": w2_q_notransp, # Cutlass params "a_scale": a_scale, "w1_q": w1_q, "w2_q": w2_q, "w1_scale": w1_scale, "w2_scale": w2_scale, - "ab_strides1": ab_strides1, - "c_strides1": c_strides1, - "ab_strides2": ab_strides2, - "c_strides2": c_strides2, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -278,8 +244,8 @@ def replay_graph(graph, num_repeats): # Warmup run_triton_moe( a, - w1_q_notransp, - w2_q_notransp, + w1_q, + w2_q, topk_weights, topk_ids, w1_scale, @@ -290,7 +256,7 @@ def replay_graph(graph, num_repeats): results.append( benchmark.Timer( - stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 + stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -321,16 +287,12 @@ def replay_graph(graph, num_repeats): w2_scale, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, num_warmup, ) results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/benchmarks/kernels/benchmark_layernorm.py b/benchmarks/kernels/benchmark_layernorm.py index f21ca97eeb8a..69978ec6b23e 100644 --- a/benchmarks/kernels/benchmark_layernorm.py +++ b/benchmarks/kernels/benchmark_layernorm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time diff --git a/benchmarks/kernels/benchmark_lora.py b/benchmarks/kernels/benchmark_lora.py index 6c1284930c1e..3d38d4b3534e 100644 --- a/benchmarks/kernels/benchmark_lora.py +++ b/benchmarks/kernels/benchmark_lora.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import copy diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index f8f1db04790b..0f896f187ecb 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import copy diff --git a/benchmarks/kernels/benchmark_marlin.py b/benchmarks/kernels/benchmark_marlin.py index b17baff2e5f5..9ea1fddae2a3 100644 --- a/benchmarks/kernels/benchmark_marlin.py +++ b/benchmarks/kernels/benchmark_marlin.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch import torch.utils.benchmark as benchmark diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index c2f7660858f5..6cb55b35993e 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse import json diff --git a/benchmarks/kernels/benchmark_moe_permute_unpermute.py b/benchmarks/kernels/benchmark_moe_permute_unpermute.py index 333986fdf5ef..dba1f3943b96 100644 --- 
a/benchmarks/kernels/benchmark_moe_permute_unpermute.py +++ b/benchmarks/kernels/benchmark_moe_permute_unpermute.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse from typing import Any, TypedDict diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 54f05e723226..7e0376c18ecc 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import random import time diff --git a/benchmarks/kernels/benchmark_quant.py b/benchmarks/kernels/benchmark_quant.py index 2463dfebe83c..6ab26f5f1adf 100644 --- a/benchmarks/kernels/benchmark_quant.py +++ b/benchmarks/kernels/benchmark_quant.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time diff --git a/benchmarks/kernels/benchmark_rmsnorm.py b/benchmarks/kernels/benchmark_rmsnorm.py index d720083b6150..4cf633a81358 100644 --- a/benchmarks/kernels/benchmark_rmsnorm.py +++ b/benchmarks/kernels/benchmark_rmsnorm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools from typing import Optional, Union diff --git a/benchmarks/kernels/benchmark_rope.py b/benchmarks/kernels/benchmark_rope.py index 944024ca3572..b81baf17a8c6 100644 --- a/benchmarks/kernels/benchmark_rope.py +++ b/benchmarks/kernels/benchmark_rope.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project from itertools import accumulate from typing import Optional diff --git a/benchmarks/kernels/benchmark_shapes.py b/benchmarks/kernels/benchmark_shapes.py index 70190ba24d9d..18c459c31d3f 100644 --- a/benchmarks/kernels/benchmark_shapes.py +++ b/benchmarks/kernels/benchmark_shapes.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project WEIGHT_SHAPES = { "ideal": [[4 * 256 * 32, 256 * 32]], diff --git a/benchmarks/kernels/benchmark_w8a8_block_fp8.py b/benchmarks/kernels/benchmark_w8a8_block_fp8.py index 6315c1ee6cdd..4fcdbadd65ec 100644 --- a/benchmarks/kernels/benchmark_w8a8_block_fp8.py +++ b/benchmarks/kernels/benchmark_w8a8_block_fp8.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from sglang quantization/tuning_block_wise_kernel.py import argparse diff --git a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py index e37764825451..e67ce0545318 100644 --- a/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py +++ b/benchmarks/kernels/deepgemm/benchmark_fp8_block_dense_gemm.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # fmt: off # ruff: noqa: E501 import time diff --git a/benchmarks/kernels/graph_machete_bench.py b/benchmarks/kernels/graph_machete_bench.py index 0c86e4072957..9a4da0ef5a85 100644 --- a/benchmarks/kernels/graph_machete_bench.py +++ b/benchmarks/kernels/graph_machete_bench.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM 
project import math import pickle diff --git a/benchmarks/kernels/utils.py b/benchmarks/kernels/utils.py index 877a29feed9d..4bbb36bb4359 100644 --- a/benchmarks/kernels/utils.py +++ b/benchmarks/kernels/utils.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import dataclasses from collections.abc import Iterable diff --git a/benchmarks/kernels/weight_shapes.py b/benchmarks/kernels/weight_shapes.py index afe159ddda6e..a27f02394afb 100644 --- a/benchmarks/kernels/weight_shapes.py +++ b/benchmarks/kernels/weight_shapes.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Weight Shapes are in the format # ([K, N], TP_SPLIT_DIM) diff --git a/benchmarks/overheads/benchmark_hashing.py b/benchmarks/overheads/benchmark_hashing.py index d5701a8fbd6d..0957a9c65f06 100644 --- a/benchmarks/overheads/benchmark_hashing.py +++ b/benchmarks/overheads/benchmark_hashing.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import cProfile import pstats diff --git a/cmake/hipify.py b/cmake/hipify.py index a15577125eb1..55d378f5b111 100755 --- a/cmake/hipify.py +++ b/cmake/hipify.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # # A command line tool for running pytorch's hipify preprocessor on CUDA diff --git a/csrc/attention/mla/cutlass_mla_kernels.cu b/csrc/attention/mla/cutlass_mla_kernels.cu index 6743af0cf2db..f4b6b19f4b23 100644 --- a/csrc/attention/mla/cutlass_mla_kernels.cu +++ b/csrc/attention/mla/cutlass_mla_kernels.cu @@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options( {static_cast(out.data_ptr()), stride_O, static_cast(nullptr), stride_LSE}, hw_info, - -1, // split_kv + 1, // split_kv nullptr, // is_var_split_kv }; // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute diff --git a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py index d64f0d0a5c2a..1dd7101acc27 100644 --- a/csrc/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/csrc/cutlass_extensions/vllm_cutlass_library_extension.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import enum from typing import Union diff --git a/csrc/moe/marlin_moe_wna16/generate_kernels.py b/csrc/moe/marlin_moe_wna16/generate_kernels.py index 15f008d4f61e..49f33718a21e 100644 --- a/csrc/moe/marlin_moe_wna16/generate_kernels.py +++ b/csrc/moe/marlin_moe_wna16/generate_kernels.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project import glob import itertools import os diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index 8fda434d452f..c4faef731060 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -30,4 +30,8 @@ torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output, int64_t BLOCK_SIZE_K, int64_t bit); #endif -bool moe_permute_unpermute_supported(); \ No newline at end of file +bool moe_permute_unpermute_supported(); + +void shuffle_rows(const torch::Tensor& input_tensor, + const torch::Tensor& dst2src_map, + torch::Tensor& output_tensor); \ No newline at end of file diff --git a/csrc/moe/moe_permute_unpermute_op.cu 
index 9a7465261abf..68f429fac18a 100644
--- a/csrc/moe/moe_permute_unpermute_op.cu
+++ b/csrc/moe/moe_permute_unpermute_op.cu
@@ -130,6 +130,62 @@ void moe_unpermute(
       });
 }
+template
+__global__ void shuffleInputRowsKernel(const T* input,
+                                       const int32_t* dst2src_map, T* output,
+                                       int64_t num_src_rows,
+                                       int64_t num_dst_rows, int64_t num_cols) {
+  int64_t dest_row_idx = blockIdx.x;
+  int64_t const source_row_idx = dst2src_map[dest_row_idx];
+
+  if (blockIdx.x < num_dst_rows) {
+    // Load 128-bits per thread
+    constexpr int64_t ELEM_PER_THREAD = 128 / sizeof(T) / 8;
+    using DataElem = cutlass::Array;
+
+    // Duplicate and permute rows
+    auto const* source_row_ptr =
+        reinterpret_cast(input + source_row_idx * num_cols);
+    auto* dest_row_ptr =
+        reinterpret_cast(output + dest_row_idx * num_cols);
+
+    int64_t const start_offset = threadIdx.x;
+    int64_t const stride = blockDim.x;
+    int64_t const num_elems_in_col = num_cols / ELEM_PER_THREAD;
+
+    for (int elem_index = start_offset; elem_index < num_elems_in_col;
+         elem_index += stride) {
+      dest_row_ptr[elem_index] = source_row_ptr[elem_index];
+    }
+  }
+}
+
+void shuffle_rows(const torch::Tensor& input_tensor,
+                  const torch::Tensor& dst2src_map,
+                  torch::Tensor& output_tensor) {
+  TORCH_CHECK(input_tensor.scalar_type() == output_tensor.scalar_type(),
+              "Input and output tensors must have the same data type");
+
+  auto stream = at::cuda::getCurrentCUDAStream().stream();
+  int64_t const blocks = output_tensor.size(0);
+  int64_t const threads = 256;
+  int64_t const num_dest_rows = output_tensor.size(0);
+  int64_t const num_src_rows = input_tensor.size(0);
+  int64_t const num_cols = input_tensor.size(1);
+
+  TORCH_CHECK(!(num_cols % (128 / sizeof(input_tensor.scalar_type()) / 8)),
+              "num_cols must be divisible by 128 / "
+              "sizeof(input_tensor.scalar_type()) / 8");
+
+  MOE_DISPATCH(input_tensor.scalar_type(), [&] {
+    shuffleInputRowsKernel<<>>(
+        reinterpret_cast(input_tensor.data_ptr()),
+        dst2src_map.data_ptr(),
+        reinterpret_cast(output_tensor.data_ptr()), num_src_rows,
+        num_dest_rows, num_cols);
+  });
+}
+
 #else
 
 void moe_permute(const torch::Tensor& input,
                  const torch::Tensor& topk_weights,
diff --git a/csrc/moe/permute_unpermute_kernels/dispatch.h b/csrc/moe/permute_unpermute_kernels/dispatch.h
index 41932cdd85bc..d0f1ea4aded3 100644
--- a/csrc/moe/permute_unpermute_kernels/dispatch.h
+++ b/csrc/moe/permute_unpermute_kernels/dispatch.h
@@ -14,12 +14,13 @@
     __VA_ARGS__();                                            \
     break;                                                    \
   }
-#define MOE_DISPATCH_FLOAT_CASE(...)                          \
-  MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)       \
-  MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)        \
-  MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)    \
-  MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
-  MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
+#define MOE_DISPATCH_FLOAT_CASE(...)                          \
+  MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)       \
+  MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)        \
+  MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)    \
+  MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
+  MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__) \
+  MOE_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
 
 #define MOE_DISPATCH(TYPE, ...)                               \
   MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
@@ -39,6 +40,11 @@ template <>
 struct ScalarType2CudaType {
   using type = __nv_bfloat16;
 };
+// uint8 for packed fp4
+template
+struct ScalarType2CudaType {
+  using type = uint8_t;
+};
 
 // #if __CUDA_ARCH__ >= 890
 // fp8
diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu
index a9379032245d..10be47966f61 100644
--- a/csrc/moe/topk_softmax_kernels.cu
+++ b/csrc/moe/topk_softmax_kernels.cu
@@ -516,9 +516,8 @@ void topk_softmax(
             topk,
             stream);
     }
-    else
+    else if (topk_indices.scalar_type() == at::ScalarType::UInt32)
     {
-        assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
         vllm::moe::topkGatingSoftmaxKernelLauncher(
             gating_output.data_ptr(),
             topk_weights.data_ptr(),
@@ -530,4 +529,17 @@ void topk_softmax(
             topk,
             stream);
     }
+    else {
+        assert(topk_indices.scalar_type() == at::ScalarType::Int64);
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr(),
+            topk_weights.data_ptr(),
+            topk_indices.data_ptr(),
+            token_expert_indices.data_ptr(),
+            softmax_workspace.data_ptr(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
 }
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index 7d35ec79ead4..a74eb3720cf1 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -81,6 +81,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   m.def("moe_permute_unpermute_supported() -> bool");
   m.impl("moe_permute_unpermute_supported", &moe_permute_unpermute_supported);
 
+  // Row shuffle for MoE
+  m.def(
+      "shuffle_rows(Tensor input_tensor, Tensor dst2src_map, Tensor! "
+      "output_tensor) -> ()");
+  m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
+
 #endif
 }
 
diff --git a/csrc/ops.h b/csrc/ops.h
index 7044b4588b81..f02f5083ac19 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -92,6 +92,11 @@ void rms_norm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weight,
 void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
                         torch::Tensor& weight, double epsilon);
 
+void apply_repetition_penalties_(torch::Tensor& logits,
+                                 const torch::Tensor& prompt_mask,
+                                 const torch::Tensor& output_mask,
+                                 const torch::Tensor& repetition_penalties);
+
 void rms_norm_static_fp8_quant(torch::Tensor& out, torch::Tensor& input,
                                torch::Tensor& weight, torch::Tensor& scale,
                                double epsilon);
@@ -231,7 +236,8 @@ void cutlass_moe_mm(
     torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
     torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
     torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
-    torch::Tensor const& b_strides, torch::Tensor const& c_strides);
+    torch::Tensor const& b_strides, torch::Tensor const& c_strides,
+    bool per_act_token, bool per_out_ch);
 
 void cutlass_fp4_group_mm(
     torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
@@ -243,7 +249,16 @@ void get_cutlass_moe_mm_data(
     const torch::Tensor& topk_ids, torch::Tensor& expert_offsets,
     torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2,
     torch::Tensor& input_permutation, torch::Tensor& output_permutation,
-    const int64_t num_experts, const int64_t n, const int64_t k);
+    const int64_t num_experts, const int64_t n, const int64_t k,
+    const std::optional& blockscale_offsets);
+
+void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
+                                  torch::Tensor& problem_sizes1,
+                                  torch::Tensor& problem_sizes2,
+                                  const torch::Tensor& expert_num_tokens,
+                                  const int64_t num_local_experts,
+                                  const int64_t padded_m, const int64_t n,
+                                  const int64_t k);
 
 void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a,
                           torch::Tensor const& b,
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
index 84492553c02f..4a8a5ed02d6c 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
@@ -9,10 +9,6 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                            torch::Tensor const& b,
                                            torch::Tensor const& a_scales,
                                            torch::Tensor const& b_scales) {
-  TORCH_CHECK(
-      a.size(0) % 4 == 0,
-      "Input tensor must have a number of rows that is a multiple of 4. ",
-      "but got: ", a.size(0), " rows.");
   if (out.dtype() == torch::kBFloat16) {
     cutlass_gemm_blockwise_sm100_fp8_dispatch(
         out, a, b, a_scales, b_scales);
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
index ef324364c6d5..c841125dbb73 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
@@ -1,5 +1,6 @@
 #pragma once
 
+#include "cuda_utils.h"
 #include "cutlass/cutlass.h"
 #include "cutlass/numeric_types.h"
 
@@ -22,49 +23,49 @@ namespace vllm {
 
 using namespace cute;
 
-template
+// clang-format off
+template
 struct cutlass_3x_gemm_fp8_blockwise {
+  static constexpr bool swap_ab = swap_ab_;
   using ElementAB = cutlass::float_e4m3_t;
 
   using ElementA = ElementAB;
   using LayoutA = cutlass::layout::RowMajor;
+  using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type;
   static constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value;
 
   using ElementB = ElementAB;
   using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type;
   static constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value;
 
-  using ElementC = void;
   using ElementD = OutType;
   using LayoutD = cutlass::layout::RowMajor;
+  using LayoutD_Transpose = typename cutlass::layout::LayoutTranspose::type;
   static constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value;
 
+  using ElementC = void; // TODO: support bias
   using LayoutC = LayoutD;
+  using LayoutC_Transpose = LayoutD_Transpose;
   static constexpr int AlignmentC = AlignmentD;
 
   using ElementAccumulator = float;
   using ElementCompute = float;
   using ElementBlockScale = float;
 
-  // MMA and Cluster Tile Shapes
-  // Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
-  // Shape %2 == 0
   using MmaTileShape_MNK = Shape<_128,_128,_128>;
-  static constexpr int ScaleMsPerTile = size<0>(ScalesPerTile{});
-  static constexpr int ScaleGranularityM =
-      size<0>(MmaTileShape{}) / ScaleMsPerTile;
-  static constexpr int ScaleGranularityN =
-      size<1>(MmaTileShape{}) / size<1>(ScalesPerTile{});
-  static constexpr int ScaleGranularityK =
-      size<2>(MmaTileShape{}) / size<2>(ScalesPerTile{});
-
-  // Shape of the threadblocks in a cluster
-  using ClusterShape_MNK = ClusterShape;
-
-  using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
-      ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
-      cute::UMMA::Major::MN, cute::UMMA::Major::K>;
+  using ScaleConfig = conditional_t,
+      cutlass::detail::Sm100BlockwiseScaleConfig<
+          ScaleGranularityM, ScaleGranularityN, ScaleGranularityK,
+          cute::UMMA::Major::MN, cute::UMMA::Major::K>>;
+
+  // layout_SFA and layout_SFB cannot be swapped since they are deduced.
   using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
   using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
 
@@ -73,7 +74,6 @@ struct cutlass_3x_gemm_fp8_blockwise {
   static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
   using ElementScalar = float;
 
-  // clang-format off
   using DefaultOperation = cutlass::epilogue::fusion::LinearCombination;
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
@@ -84,33 +84,47 @@ struct cutlass_3x_gemm_fp8_blockwise {
       ElementAccumulator,
       ElementCompute,
       ElementC,
-      LayoutC,
+      conditional_t,
      AlignmentC,
      ElementD,
-      LayoutD,
+      conditional_t,
      AlignmentD,
      EpilogueScheduler,
      DefaultOperation
  >::CollectiveOp;
 
  using StageCountType = cutlass::gemm::collective::StageCountAuto;
-  using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
-      ArchTag,
-      OperatorClass,
-      ElementA,
-      cute::tuple,
-      AlignmentA,
-      ElementB,
-      cute::tuple,
-      AlignmentB,
-      ElementAccumulator,
-      MmaTileShape,
-      ClusterShape,
-
+  using CollectiveMainloop = conditional_t,
+          AlignmentB,
+          ElementA,
+          cute::tuple,
+          AlignmentA,
+          ElementAccumulator,
+          MmaTileShape,
+          ClusterShape,
           cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>,
-      MainloopScheduler
-  >::CollectiveOp;
-  // clang-format on
+          MainloopScheduler
+      >::CollectiveOp,
+      typename cutlass::gemm::collective::CollectiveBuilder<
+          ArchTag,
+          OperatorClass,
+          ElementA,
+          cute::tuple,
+          AlignmentA,
+          ElementB,
+          cute::tuple,
+          AlignmentB,
+          ElementAccumulator,
+          MmaTileShape,
+          ClusterShape,
+          cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+          MainloopScheduler
+      >::CollectiveOp>;
 
   using KernelType = enable_sm100_only,
       CollectiveMainloop, CollectiveEpilogue>>;
@@ -123,6 +137,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
                                    torch::Tensor const& b,
                                    torch::Tensor const& a_scales,
                                    torch::Tensor const& b_scales) {
+  static constexpr bool swap_ab = Gemm::swap_ab;
   using GemmKernel = typename Gemm::GemmKernel;
   using StrideA = typename Gemm::GemmKernel::StrideA;
   using StrideB = typename Gemm::GemmKernel::StrideB;
@@ -136,7 +151,6 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   using ElementD = typename Gemm::ElementD;
 
   int32_t m = a.size(0), n = b.size(1), k = a.size(1);
-  auto prob_shape = cute::make_shape(m, n, k, 1);
 
   StrideA a_stride;
   StrideB b_stride;
@@ -146,11 +160,13 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   b_stride =
       cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
   c_stride =
-      cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, 1));
+      cutlass::make_cute_packed_stride(StrideC{}, swap_ab ? cute::make_shape(n, m, 1) : cute::make_shape(m, n, 1));
 
-  LayoutSFA layout_SFA =
+  LayoutSFA layout_SFA = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFA(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFA(make_shape(m, n, k, 1));
-  LayoutSFB layout_SFB =
+  LayoutSFB layout_SFB = swap_ab ?
+      ScaleConfig::tile_atom_to_shape_SFB(make_shape(n, m, k, 1)) :
       ScaleConfig::tile_atom_to_shape_SFB(make_shape(m, n, k, 1));
 
   auto a_ptr = static_cast(a.data_ptr());
@@ -158,9 +174,22 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a,
   auto a_scales_ptr = static_cast(a_scales.data_ptr());
   auto b_scales_ptr = static_cast(b_scales.data_ptr());
 
-  typename GemmKernel::MainloopArguments mainloop_args{
-      a_ptr, a_stride, b_ptr, b_stride,
-      a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB};
+  auto mainloop_args = [&](){
+    // layout_SFA and layout_SFB cannot be swapped since they are deduced.
+    if (swap_ab) {
+      return typename GemmKernel::MainloopArguments{
+        b_ptr, b_stride, a_ptr, a_stride,
+        b_scales_ptr, layout_SFA, a_scales_ptr, layout_SFB
+      };
+    }
+    else {
+      return typename GemmKernel::MainloopArguments{
+        a_ptr, a_stride, b_ptr, b_stride,
+        a_scales_ptr, layout_SFA, b_scales_ptr, layout_SFB
+      };
+    }
+  }();
+  auto prob_shape = swap_ab ? cute::make_shape(n, m, k, 1) : cute::make_shape(m, n, k, 1);
 
   auto c_ptr = static_cast(out.data_ptr());
   typename GemmKernel::EpilogueArguments epilogue_args{
@@ -175,29 +204,74 @@ void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out,
                                                torch::Tensor const& b,
                                                torch::Tensor const& a_scales,
                                                torch::Tensor const& b_scales) {
-  auto m = a.size(0);
-  auto k = a.size(1);
-  auto n = b.size(1);
-  int sms;
+  int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms;
   cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device());
 
-  auto should_use_2sm = [&sms](int m, int n, int tile1SM = 128) {
-    return std::ceil(static_cast(m) / tile1SM) *
-               std::ceil(static_cast(n) / tile1SM) >=
-           sms;
-  };
-  bool use_2sm = should_use_2sm(m, n);
-  if (use_2sm) {
-    cutlass_gemm_caller_blockwise, Shape<_256, _1, _1>,
-        Shape<_2, _2, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
-        out, a, b, a_scales, b_scales);
+  constexpr int TILE_K = 128;
+  // TODO: better heuristics
+  bool swap_ab = (m < 16) || (m % 4 != 0);
+  bool use_tma_epilogue = (m * n) % 4 == 0;
+  if (!swap_ab) {
+    constexpr int TILE_N = 128;
+    int tile_m = 256;
+    if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 64) <= sms) {
+      tile_m = 64;
+    }
+    else if (cuda_utils::ceil_div(n, TILE_N) * cuda_utils::ceil_div(m, 128) <= sms) {
+      tile_m = 128;
+    }
+    if (tile_m == 64) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise, Int>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise, Int>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else if (tile_m == 128) {
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise, Int>,
+            Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise, Int>,
+            Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    } else { // tile_m == 256
+      if (use_tma_epilogue) {
+        cutlass_gemm_caller_blockwise, Int>,
+            Shape<_2, _1, _1>, cutlass::epilogue::TmaWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      } else {
+        cutlass_gemm_caller_blockwise, Int>,
+            Shape<_2, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized2Sm,
+            cutlass::gemm::KernelTmaWarpSpecializedBlockwise2SmSm100>>(
+            out, a, b, a_scales, b_scales);
+      }
+    }
   } else {
+    // TODO: Test more tile N configs
+    constexpr int TILE_M = 128;
+    constexpr int TILE_N = 16;
+    // TMA epilogue isn't compatible with Swap A/B
     cutlass_gemm_caller_blockwise, Shape<_128, _1, _1>,
-        Shape<_1, _1, _1>, cutlass::epilogue::TmaWarpSpecialized1Sm,
-        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100>>(
+        OutType, TILE_M, 1, TILE_K, Shape, Int, Int>,
+        Shape<_1, _1, _1>, cutlass::epilogue::NoSmemWarpSpecialized1Sm,
+        cutlass::gemm::KernelTmaWarpSpecializedBlockwise1SmSm100, true>>(
        out, a, b, a_scales, b_scales);
  }
 }
diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
index 468b77d9593b..6da2da634075 100644
--- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
+++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
@@ -15,6 +15,7 @@ using c3x::cutlass_gemm_caller;
 template
           typename Epilogue>
 struct sm100_fp8_config_default {
+  // M in (128, inf)
   static_assert(std::is_same());
   using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
   using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
@@ -25,6 +26,34 @@ struct sm100_fp8_config_default {
                                KernelSchedule, EpilogueSchedule>;
 };
 
+template
+          typename Epilogue>
+struct sm100_fp8_config_M128 {
+  // M in (64, 128]
+  static_assert(std::is_same());
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_128, _128, _64>;
+  using ClusterShape = Shape<_2, _2, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm_sm100;
+};
+
+template
+          typename Epilogue>
+struct sm100_fp8_config_M64 {
+  // M in [1, 64]
+  static_assert(std::is_same());
+  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
+  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
+  using TileShape = Shape<_64, _64, _256>;
+  using ClusterShape = Shape<_1, _8, _1>;
+  using Cutlass3xGemm =
+      cutlass_3x_gemm_sm100;
+};
+
 template
           typename Epilogue,
           typename... EpilogueArgs>
@@ -39,8 +68,28 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
   using Cutlass3xGemmDefault =
       typename sm100_fp8_config_default::Cutlass3xGemm;
 
-  return cutlass_gemm_caller(
-      out, a, b, std::forward(args)...);
+  using Cutlass3xGemmM64 =
+      typename sm100_fp8_config_M64::Cutlass3xGemm;
+  using Cutlass3xGemmM128 =
+      typename sm100_fp8_config_M128::Cutlass3xGemm;
+
+  uint32_t const m = a.size(0);
+  uint32_t const mp2 =
+      std::max(static_cast(64), next_pow_2(m));  // next power of 2
+
+  if (mp2 <= 64) {
+    // m in [1, 64]
+    return cutlass_gemm_caller(
+        out, a, b, std::forward(args)...);
+  } else if (mp2 <= 128) {
+    // m in (64, 128]
+    return cutlass_gemm_caller(
+        out, a, b, std::forward(args)...);
+  } else {
+    // m in (128, inf)
+    return cutlass_gemm_caller(
+        out, a, b, std::forward(args)...);
+  }
 }
 
 template