[V1][Perf] Reduce scheduling overhead in model runner after cuda sync (vllm-project#12094)

Signed-off-by: Keyun Tong <tongkeyun@gmail.com>
youngkent authored and NickLucche committed Feb 7, 2025
1 parent 1d3625e commit 42bfed0
Showing 3 changed files with 21 additions and 13 deletions.
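
The idea behind the change: `tolist()` on a CUDA tensor blocks the CPU until the GPU finishes and then copies the result to the host. Previously the sampler called it immediately, so all of the model runner's per-request bookkeeping sat behind that stall; now the sampled ids stay on the GPU and the copy is deferred until the CPU work that does not need the actual token values is already done. A minimal sketch of the reordering (not vLLM code; `logits` and the bookkeeping comments are placeholders):

from typing import List

import torch

def old_order(logits: torch.Tensor) -> List[int]:
    sampled = logits.argmax(dim=-1).to(torch.int32)
    token_ids = sampled.tolist()  # GPU -> CPU sync happens here, up front
    # ...per-request CPU bookkeeping only starts after the stall...
    return token_ids

def new_order(logits: torch.Tensor) -> List[int]:
    sampled = logits.argmax(dim=-1).to(torch.int32)
    # ...per-request CPU bookkeeping runs first, overlapping the GPU work...
    token_ids = sampled.tolist()  # single sync point, as late as possible
    return token_ids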
2 changes: 1 addition & 1 deletion vllm/v1/outputs.py
@@ -8,7 +8,7 @@
 class SamplerOutput:
 
     # [num_reqs]
-    sampled_token_ids: List[int]
+    sampled_token_ids: torch.Tensor
 
     # [num_reqs, max_num_logprobs + 1]
     logprob_token_ids: Optional[torch.Tensor]
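
For context, the dataclass around this hunk looks roughly as follows after the change (a sketch reconstructed from the diff; fields outside the hunk are omitted):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class SamplerOutput:

    # [num_reqs]
    # Now a GPU tensor: constructing SamplerOutput no longer implies a
    # device-to-host copy; the caller decides when to convert to a list.
    sampled_token_ids: torch.Tensor

    # [num_reqs, max_num_logprobs + 1]
    logprob_token_ids: Optional[torch.Tensor]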
3 changes: 1 addition & 2 deletions vllm/v1/sample/sampler.py
@@ -50,9 +50,8 @@ def forward(
         # Use int32 to reduce the tensor size.
         sampled = sampled.to(torch.int32)
 
-        # NOTE: CPU-GPU synchronization happens here.
         sampler_output = SamplerOutput(
-            sampled_token_ids=sampled.tolist(),
+            sampled_token_ids=sampled,
             logprob_token_ids=topk_indices,
             logprobs=topk_logprobs,
             prompt_logprob_token_ids=None,
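
The line removed from the sampler was the hidden cost: `.tolist()` (like `.cpu()` or `.item()`) implicitly waits for every kernel queued on the stream before copying. A small standalone timing sketch that makes the implicit sync visible (assumes a CUDA device is present; otherwise it just prints a message):

import time

import torch


def main() -> None:
    if not torch.cuda.is_available():
        print("CUDA not available; nothing to measure.")
        return
    x = torch.randn(4096, 4096, device="cuda")
    y = x
    for _ in range(50):
        y = y @ x  # queues GPU work; these calls return almost immediately
    sampled = y.argmax(dim=-1).to(torch.int32)
    t0 = time.perf_counter()
    _ = sampled.tolist()  # blocks until all queued matmuls have finished
    print(f"tolist() waited ~{time.perf_counter() - t0:.3f}s for the GPU")


if __name__ == "__main__":
    main()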
29 changes: 19 additions & 10 deletions vllm/v1/worker/gpu_model_runner.py
@@ -775,10 +775,10 @@ def execute_model(
             sampling_metadata=sampling_metadata,
         )
 
-        sampled_token_ids = sampler_output.sampled_token_ids
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
+        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
         for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
             assert req_id is not None
             req_state = self.requests[req_id]
@@ -787,10 +787,10 @@ def execute_model(
             assert seq_len <= req_state.num_tokens
             if seq_len == req_state.num_tokens:
                 # Append the sampled token to the output token ids.
-                token_id = sampled_token_ids[i]
-                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 self.input_batch.num_tokens[i] += 1
-                req_state.output_token_ids.append(token_id)
+                # OPTIMIZATION: Priming the state updates for later updates.
+                req_state.output_token_ids.append(0)
+                request_seq_lens.append((i, req_state, seq_len))
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
@@ -799,6 +799,21 @@
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
 
+        # num_reqs entries should be non-None
+        assert all(
+            req_id is not None for req_id in
+            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
+        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
+
+        # NOTE: GPU -> CPU Sync happens here.
+        # Move as many CPU operations as possible before this sync point.
+        sampled_token_ids = sampler_output.sampled_token_ids.tolist()
+        # Update with the actual token ids
+        for i, req_state, seq_len in request_seq_lens:
+            token_id = sampled_token_ids[i]
+            self.input_batch.token_ids_cpu[i, seq_len] = token_id
+            req_state.output_token_ids[-1] = token_id
+
         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
         else:
@@ -808,12 +823,6 @@
         else:
             logprobs = sampler_output.logprobs.cpu()
 
-        # num_reqs entries should be non-None
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
-
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
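
What the model-runner hunks add up to is a prime-then-patch pattern: the per-request loop appends a placeholder token and records which slots need fixing, the single `.tolist()` sync happens afterwards, and the placeholders are then overwritten with the real ids. A standalone sketch of the pattern with a toy `RequestState` (not the vLLM classes):

from dataclasses import dataclass, field
from typing import List, Tuple

import torch


@dataclass
class RequestState:
    output_token_ids: List[int] = field(default_factory=list)


def finish_step(sampled: torch.Tensor, requests: List[RequestState]) -> None:
    # Phase 1 (pure CPU, no sync): prime each request with a placeholder and
    # remember which entries must be patched once real ids are available.
    pending: List[Tuple[int, RequestState]] = []
    for i, req_state in enumerate(requests):
        req_state.output_token_ids.append(0)  # placeholder
        pending.append((i, req_state))

    # Phase 2: the one GPU -> CPU sync point.
    token_ids = sampled.tolist()

    # Phase 3: patch the placeholders with the actual sampled tokens.
    for i, req_state in pending:
        req_state.output_token_ids[-1] = token_ids[i]

In the real code, requests whose partial prefill should not emit a token are simply never primed, which is what the `request_seq_lens` list tracks together with each request's sequence length.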
