
Commit 8fd5399

Potabk authored and MengqingCao committed
support logitsprocessor
Signed-off-by: wangli <wangli858794774@gmail.com>
1 parent 9e7c168 commit 8fd5399

File tree

3 files changed: +177, -141 lines changed


vllm_ascend/worker/model_runner_v1.py

Lines changed: 46 additions & 64 deletions
@@ -36,7 +36,7 @@
 import torch.nn as nn
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
-from vllm.config import CompilationLevel, VllmConfig
+from vllm.config import CompilationLevel, CUDAGraphMode, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.kv_transfer import (get_kv_transfer_group,
                                            has_kv_transfer_group)
@@ -58,12 +58,13 @@
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import GenerationTask, SupportedTask
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
-                        LazyLoader, cdiv)
+                        LazyLoader, cdiv, is_pin_memory_available)
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
                              ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
+from vllm.v1.sample.logits_processor import build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
@@ -156,6 +157,8 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
         self.parallel_config = vllm_config.parallel_config
+        self.compilation_config = vllm_config.compilation_config
+        self.pin_memory = is_pin_memory_available()
         self.scheduler_config = vllm_config.scheduler_config
         self.speculative_config = vllm_config.speculative_config
         self.block_size = vllm_config.cache_config.block_size
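
The two new attributes cache the compilation config and probe pinned-memory support once at init instead of hard-coding it later. A minimal sketch of the second point, assuming only that `is_pin_memory_available` (imported in the hunk above) and PyTorch are on the path; the buffer size is arbitrary:

```python
import torch
from vllm.utils import is_pin_memory_available

# Only request a pinned host allocation when the platform actually supports it.
pin = is_pin_memory_available()
cpu_buf = torch.empty(1024, dtype=torch.int32, pin_memory=pin)
print("pinned host buffer:", cpu_buf.is_pinned())
```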
@@ -335,9 +338,11 @@ def __init__(self, vllm_config: VllmConfig, device: torch.device):
                              == CompilationLevel.PIECEWISE
                              and not self.model_config.enforce_eager and
                              not ascend_config.torchair_graph_config.enabled)
-        self.aclgraph_batch_sizes = list(
-            reversed(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+        self.aclgraph_batch_sizes = []
+        if self.compilation_config.cudagraph_capture_sizes and \
+            self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
+            self.aclgraph_batch_sizes = list(
+                reversed(self.compilation_config.cudagraph_capture_sizes))

         self.new_kv_cache_bytes = -1
         self.torchair_compiled_model = None  # type: ignore
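
The hunk above only builds `aclgraph_batch_sizes` when capture sizes are configured and graph capture is not disabled. A standalone sketch of that gating, using a made-up stand-in for `compilation_config` and a local copy of the `CUDAGraphMode` values (the real enum lives in `vllm.config`):

```python
from dataclasses import dataclass, field
from enum import Enum


class CUDAGraphMode(Enum):  # stand-in for vllm.config.CUDAGraphMode
    NONE = 0
    PIECEWISE = 1


@dataclass
class CompilationCfg:  # hypothetical stand-in for compilation_config
    cudagraph_capture_sizes: list = field(default_factory=list)
    cudagraph_mode: CUDAGraphMode = CUDAGraphMode.NONE


def aclgraph_batch_sizes(cfg: CompilationCfg) -> list:
    # Capture sizes are used (largest first) only when graph capture is enabled.
    if cfg.cudagraph_capture_sizes and cfg.cudagraph_mode != CUDAGraphMode.NONE:
        return list(reversed(cfg.cudagraph_capture_sizes))
    return []


print(aclgraph_batch_sizes(CompilationCfg([1, 2, 4, 8], CUDAGraphMode.PIECEWISE)))  # [8, 4, 2, 1]
print(aclgraph_batch_sizes(CompilationCfg([1, 2, 4, 8], CUDAGraphMode.NONE)))       # []
```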
@@ -405,12 +410,6 @@ def check_batch_sizes_consistency(self) -> None:
         )

     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
-        """Update the cached states and the persistent batch with the scheduler
-        output.
-
-        The SamplingMetadata is updated and copied to the NPU if there is a
-        new/resumed/paused/finished request in the batch.
-        """
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
@@ -421,11 +420,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # then resubmitted with the same ID. In this case, we treat them as two
         # distinct requests - clearing the cached states for the first request
         # and handling the second as a new request.
-        removed_req_indices: List[int] = []
         for req_id in scheduler_output.finished_req_ids:
-            req_index = self.input_batch.remove_request(req_id)
-            if req_index is not None:
-                removed_req_indices.append(req_index)
+            self.input_batch.remove_request(req_id)

         # Free the cached encoder outputs.
         for req_id, input_id in scheduler_output.free_encoder_input_ids:
@@ -448,16 +444,15 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         # have low request overlap (e.g., alternating between two distinct
         # sets of requests), this optimization becomes very inefficient.
         for req_id in unscheduled_req_ids:
-            req_index = self.input_batch.remove_request(req_id)
-            assert req_index is not None
-            removed_req_indices.append(req_index)
+            self.input_batch.remove_request(req_id)

-        req_ids_to_add: List[str] = []
+        req_ids_to_add: list[str] = []
         # Add new requests to the cached states.
         for new_req_data in scheduler_output.scheduled_new_reqs:
             req_id = new_req_data.req_id
             sampling_params = new_req_data.sampling_params
             pooling_params = new_req_data.pooling_params
+
             if sampling_params and \
                 sampling_params.sampling_type == SamplingType.RANDOM_SEED:
                 generator = torch.Generator(device=self.device)
@@ -468,7 +463,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             if pooling_params:
                 assert (task := pooling_params.task) is not None, (
                     "You did not set `task` in the API")
-                model = cast(VllmModelForPooling, self.model)
+
+                model = cast(VllmModelForPooling, self.get_model())
                 to_update = model.pooler.get_pooling_updates(task)
                 to_update.apply(pooling_params)

@@ -478,7 +474,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 mm_kwargs=new_req_data.mm_kwargs,
                 mm_positions=new_req_data.mm_positions,
                 sampling_params=sampling_params,
-                pooling_params=new_req_data.pooling_params,
+                pooling_params=pooling_params,
                 generator=generator,
                 block_ids=new_req_data.block_ids,
                 num_computed_tokens=new_req_data.num_computed_tokens,
@@ -493,9 +489,8 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 second_per_grid_ts = []
                 audio_feature_lengths = []
                 use_audio_in_video = False
-
-                for item in self.requests[req_id].mm_kwargs:
-                    mm_input = item.require_data()
+                for mm_item in self.requests[req_id].mm_kwargs:
+                    mm_input = mm_item.get_data()
                     if mm_input.get("image_grid_thw") is not None:
                         image_grid_thw.append(
                             mm_input["image_grid_thw"].tolist())
@@ -528,19 +523,24 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             req_ids_to_add.append(req_id)

         # Update the states of the running/resumed requests.
-        req_data = scheduler_output.scheduled_cached_reqs
         is_last_rank = get_pp_group().is_last_rank
+        req_data = scheduler_output.scheduled_cached_reqs
         for i, req_id in enumerate(req_data.req_ids):
             req_state = self.requests[req_id]
             num_computed_tokens = req_data.num_computed_tokens[i]
             new_block_ids = req_data.new_block_ids[i]
             resumed_from_preemption = req_data.resumed_from_preemption[i]

+            # Update the cached states.
             req_state.num_computed_tokens = num_computed_tokens
+
             if not is_last_rank:
+                # When using PP, the scheduler sends the sampled tokens back,
+                # because there's no direct communication between the first-
+                # stage worker and the last-stage worker.
                 new_token_ids = req_data.new_token_ids[i]
                 # Add the sampled token(s) from the previous step (if any).
-                # This doesn't include "unverified" tokens like spec decode tokens.
+                # This doesn't include "unverified" tokens like spec tokens.
                 num_new_tokens = (num_computed_tokens + len(new_token_ids) -
                                   req_state.num_tokens)
                 if num_new_tokens == 1:
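
A worked example of the token bookkeeping in this hunk (pipeline-parallel path): the scheduler resends previously sampled tokens, and only the ones not yet recorded in the cached request state are appended. The numbers below are made up:

```python
num_computed_tokens = 10      # tokens already processed for this request
new_token_ids = [101, 102]    # tokens reported back by the scheduler
req_state_num_tokens = 11     # tokens already stored in the cached state

# Same formula as the diff: how many of the reported tokens are actually new.
num_new_tokens = num_computed_tokens + len(new_token_ids) - req_state_num_tokens
print(num_new_tokens)         # 1 -> only new_token_ids[-1:] gets appended
```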
@@ -549,11 +549,12 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                 elif num_new_tokens > 0:
                     req_state.output_token_ids.extend(
                         new_token_ids[-num_new_tokens:])
+
             # Update the block IDs.
             if not resumed_from_preemption:
                 # Append the new blocks to the existing block IDs.
-                for block_ids, new_ids in zip(  # type: ignore[call-overload]
-                        req_state.block_ids, new_block_ids):
+                for block_ids, new_ids in zip(req_state.block_ids,
+                                              new_block_ids):
                     block_ids.extend(new_ids)
             else:
                 # The request is resumed from preemption.
@@ -571,9 +572,10 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
             # Update the persistent batch.
             self.input_batch.num_computed_tokens_cpu[req_index] = (
                 num_computed_tokens)
-
             self.input_batch.block_table.append_row(new_block_ids, req_index)

+            # For the last rank, we don't need to update the token_ids_cpu
+            # because the sampled tokens are already cached.
             if not is_last_rank:
                 # Add new_token_ids to token_ids_cpu.
                 start_token_index = num_computed_tokens
@@ -583,9 +585,11 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                     start_token_index:end_token_index] = new_token_ids
                 self.input_batch.num_tokens_no_spec[
                     req_index] = end_token_index
+                self.input_batch.num_tokens[req_index] = end_token_index
+
                 # Add spec_token_ids to token_ids_cpu.
-                spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                    req_id, ())
+                spec_token_ids = (
+                    scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
                 if spec_token_ids:
                     num_spec_tokens = len(spec_token_ids)
                     start_index = self.input_batch.num_tokens_no_spec[req_index]
@@ -595,39 +599,17 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
                     # NOTE(woosuk): `num_tokens` here may include spec tokens.
                     self.input_batch.num_tokens[req_index] += num_spec_tokens

-        # Check if the batch has changed. If not, we can skip copying the
-        # sampling metadata from CPU to GPU.
-        batch_changed = len(removed_req_indices) > 0 or len(req_ids_to_add) > 0
-
         # Add the new or resumed requests to the persistent batch.
         # The smaller empty indices are filled first.
-        removed_req_indices.sort(reverse=True)
         for req_id in req_ids_to_add:
             req_state = self.requests[req_id]
-            if removed_req_indices:
-                # Fill the empty index.
-                req_index = removed_req_indices.pop()
-            else:
-                # Append to the end.
-                req_index = None
-            self.input_batch.add_request(req_state, req_index)
-            spec_token_ids = scheduler_output.scheduled_spec_decode_tokens.get(
-                req_id, ())
-            if spec_token_ids:
-                req_index = self.input_batch.num_reqs - 1
-                start_index = len(req_state.prompt_token_ids) + len(
-                    req_state.output_token_ids)
-                end_token_index = start_index + len(spec_token_ids)
-                self.input_batch.token_ids_cpu[
-                    req_index, start_index:end_token_index] = spec_token_ids
-                self.input_batch.num_tokens[req_index] = end_token_index
+            self.input_batch.add_request(req_state)

-        # Condense the batched states if there are empty indices.
-        if removed_req_indices:
-            self.input_batch.condense(removed_req_indices)
+        # Condense the batched states if there are gaps left by removed requests
+        self.input_batch.condense()

-        if batch_changed:
-            self.input_batch.refresh_sampling_metadata()
+        # Refresh batch metadata with any pending updates.
+        self.input_batch.refresh_metadata()

     def _get_forward_metadata_across_dp(
             self, num_tokens: int, with_prefill: bool,
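
With the manual index tracking gone, `_update_states` now leans on the persistent batch itself: remove finished requests, append new ones, `condense()` the gaps, then `refresh_metadata()`. A toy illustration of that flow, with a made-up `ToyBatch` standing in for vLLM's `InputBatch` (only the method names mirror the diff; the internals here are invented for the example):

```python
class ToyBatch:
    def __init__(self):
        self.rows = []                # one entry per request slot, None = gap

    def remove_request(self, req_id):
        for i, r in enumerate(self.rows):
            if r == req_id:
                self.rows[i] = None   # leave a gap; no index returned to the caller

    def add_request(self, req_id):
        self.rows.append(req_id)      # always append; gaps are fixed later

    def condense(self):
        self.rows = [r for r in self.rows if r is not None]

    def refresh_metadata(self):
        pass                          # placeholder for deferred metadata updates


batch = ToyBatch()
for rid in ("a", "b", "c"):
    batch.add_request(rid)
batch.remove_request("b")
batch.add_request("d")
batch.condense()
batch.refresh_metadata()
print(batch.rows)                     # ['a', 'c', 'd']
```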
@@ -1063,11 +1045,6 @@ def _process_reqs(
                                                      num_input_tokens)
         num_input_tokens += num_pad

-        modified_batch = self.attn_metadata_builder.reorder_batch(
-            self.input_batch, scheduler_output)
-        if modified_batch:
-            self.input_batch.refresh_sampling_metadata()
-
         # OPTIMIZATION: Start copying the block table first.
         # This way, we can overlap the copy with the following CPU operations.
         self.input_batch.block_table.commit_block_table(num_reqs)
@@ -2199,10 +2176,15 @@ def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
             max_model_len=self.model_config.max_model_len,
             max_num_batched_tokens=self.max_num_tokens,
             device=self.device,
-            pin_memory=True,
+            pin_memory=self.pin_memory,
             vocab_size=self.model_config.get_vocab_size(),
             block_sizes=[self.block_size],
             is_spec_decode=bool(self.vllm_config.speculative_config),
+            logitsprocs=build_logitsprocs(
+                self.vllm_config, self.device, self.pin_memory,
+                self.is_pooling_model,
+                self.vllm_config.model_config.logits_processors),
+            is_pooling_model=self.is_pooling_model,
         )

         kv_cache_sizes = {}
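
The `logitsprocs` argument wires the output of `build_logitsprocs` into the persistent batch so per-request logits transforms can run during sampling. As a loose illustration of what such a transform does, here is a toy batched logit-bias function in plain PyTorch; it is not vLLM's `LogitsProcessor` interface, and all names below are made up:

```python
import torch


def apply_logit_bias(logits: torch.Tensor,
                     bias: dict[int, float]) -> torch.Tensor:
    """Add a fixed bias to selected token ids for every row in the batch."""
    out = logits.clone()
    for token_id, value in bias.items():
        out[:, token_id] += value
    return out


if __name__ == "__main__":
    batch_logits = torch.zeros(2, 8)                 # 2 requests, vocab of 8
    biased = apply_logit_bias(batch_logits, {3: 5.0, 5: -5.0})
    print(biased.argmax(dim=-1))                     # token 3 wins for both rows
```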
