[V1][Perf] Reduce scheduling overhead in model runner after cuda sync #12094

Merged on Jan 26, 2025 (6 commits)
2 changes: 1 addition & 1 deletion vllm/v1/outputs.py
@@ -8,7 +8,7 @@
class SamplerOutput:

# [num_reqs]
sampled_token_ids: List[int]
Comment from a Member:
Is this necessary? IIRC, @tlrmchlsmth used List[int] because it is cheaper to serialize, which would benefit the tensor-parallel case, where we need to pass the sampled ids across processes.

Comment from a Collaborator:

This is true, though I didn't look at how it impacts the non-TP case.

Comment from @robertgshaw2-redhat (Collaborator, Jan 25, 2025):

The ModelRunnerOutput is what we serialize for TP; we don't serialize the SamplerOutput directly, so this is not a concern.

Comment from @tlrmchlsmth (Collaborator, Jan 25, 2025):

Ah, yep, that's right. I did change this line in #9856, but that was just downstream of changing sampled_token_ids to a List in the ModelRunnerOutput. This looks good to me since that's left as-is!
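
For illustration only, here is a rough micro-benchmark sketch (not part of the PR; the sizes and loop counts are made up) of the serialization point raised above: a plain Python list of token ids typically pickles with less overhead than an equivalent torch.Tensor, which is why List[int] was preferred where the object crosses process boundaries.

```python
import pickle
import time

import torch

# Hypothetical micro-benchmark (my own sketch): compare pickling a plain Python
# list of token ids against pickling an equivalent torch.Tensor. The exact
# numbers depend on the environment; the point is only to make the reviewer's
# serialization argument concrete.
token_ids = list(range(256))
candidates = [
    ("List[int]", token_ids),
    ("torch.Tensor", torch.tensor(token_ids, dtype=torch.int32)),
]

for name, obj in candidates:
    start = time.perf_counter()
    for _ in range(1_000):
        payload = pickle.dumps(obj)
    elapsed = time.perf_counter() - start
    print(f"{name}: {len(payload)} bytes, {elapsed / 1_000 * 1e6:.1f} us per dump")
```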

sampled_token_ids: torch.Tensor

# [num_reqs, max_num_logprobs + 1]
logprob_token_ids: Optional[torch.Tensor]
3 changes: 1 addition & 2 deletions vllm/v1/sample/sampler.py
@@ -50,9 +50,8 @@ def forward(
# Use int32 to reduce the tensor size.
sampled = sampled.to(torch.int32)

# NOTE: CPU-GPU synchronization happens here.
sampler_output = SamplerOutput(
sampled_token_ids=sampled.tolist(),
sampled_token_ids=sampled,
logprob_token_ids=topk_indices,
logprobs=topk_logprobs,
prompt_logprob_token_ids=None,
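As context for the sampler change above, here is a standalone sketch (my own, assuming a CUDA device is available) of why dropping .tolist() inside the sampler matters: converting a CUDA tensor to a Python list forces the host to wait for all outstanding work on the stream, so deferring the conversion lets CPU-side bookkeeping overlap with GPU execution.

```python
import time

import torch

# Sketch only: .tolist() on a CUDA tensor implies a device synchronization,
# so the host blocks until previously enqueued kernels finish.
if torch.cuda.is_available():
    sampled = torch.randint(0, 32_000, (256,), dtype=torch.int32, device="cuda")
    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    _ = a @ b  # enqueue some GPU work that may still be running

    start = time.perf_counter()
    token_ids = sampled.tolist()  # implicit sync: waits for the matmul too
    waited_ms = (time.perf_counter() - start) * 1e3
    print(f".tolist() blocked for {waited_ms:.2f} ms")
```
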
29 changes: 19 additions & 10 deletions vllm/v1/worker/gpu_model_runner.py
@@ -775,10 +775,10 @@ def execute_model(
sampling_metadata=sampling_metadata,
)

sampled_token_ids = sampler_output.sampled_token_ids
# TODO(woosuk): The following loop can be slow since it iterates over
# the requests one by one. Optimize.
num_reqs = self.input_batch.num_reqs
request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
assert req_id is not None
req_state = self.requests[req_id]
@@ -787,10 +787,10 @@
assert seq_len <= req_state.num_tokens
if seq_len == req_state.num_tokens:
# Append the sampled token to the output token ids.
token_id = sampled_token_ids[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
self.input_batch.num_tokens[i] += 1
req_state.output_token_ids.append(token_id)
# OPTIMIZATION: Priming the state updates for later updates.
req_state.output_token_ids.append(0)
request_seq_lens.append((i, req_state, seq_len))
else:
# Ignore the sampled token from the partial request.
# Rewind the generator state as if the token was not sampled.
@@ -799,6 +799,21 @@
# This relies on cuda-specific torch-internal impl details
generator.set_offset(generator.get_offset() - 4)

# num_reqs entries should be non-None
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])

# NOTE: GPU -> CPU Sync happens here.
Comment from a Collaborator:

Just for the record: if top-p or top-k sampling is used (with the FlashInfer kernel), CPU-GPU synchronization happens inside the sampler at:

# NOTE: CPU-GPU synchronization happens here.
if not success.all():

Comment from a Collaborator:

Do you think we can avoid this in a follow-up PR?

Comment from a Collaborator:

I don't think we can. This is a fundamental limitation of the kernel (or the algorithm itself). The rejection sampling method cannot guarantee success 100% of the time.
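
To make this point concrete, here is a generic illustration (not vLLM's actual sampler code; the "kernel" and its success mask are invented) of why checking a per-row success flag forces a CPU-GPU sync: deciding whether a fallback pass is needed requires reading the mask on the host.

```python
import torch

def sample_with_fallback(probs: torch.Tensor) -> torch.Tensor:
    """Toy stand-in for a rejection-style sampling kernel with a success mask."""
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    # Pretend the kernel occasionally fails for some rows.
    success = torch.rand(probs.shape[0], device=probs.device) > 0.01

    # NOTE: reading success.all() on the host is the unavoidable sync point.
    if not bool(success.all()):
        failed = ~success
        sampled[failed] = torch.multinomial(probs[failed],
                                            num_samples=1).squeeze(-1)
    return sampled

probs = torch.softmax(torch.randn(8, 32_000), dim=-1)
print(sample_with_fallback(probs))
```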

# Move as many CPU operations as possible before this sync point.
sampled_token_ids = sampler_output.sampled_token_ids.tolist()
Comment from @robertgshaw2-redhat (Collaborator, Jan 25, 2025):

It might be faster to do `sampler_output.sampled_token_ids.cpu()` and then `sampler_output.sampled_token_ids[i].item()` in the inner loop.

Comment from a Collaborator:

In my experience, `item()` took considerable time, so it should be avoided.
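
A quick sketch (my own, not from the PR) of the trade-off discussed here: one .tolist() on the whole tensor versus a .cpu() copy followed by per-element .item() calls. The per-element loop pays Python and dispatcher overhead once per request, which is why it was avoided.

```python
import time

import torch

num_reqs = 1_024
device = "cuda" if torch.cuda.is_available() else "cpu"
sampled = torch.randint(0, 32_000, (num_reqs,), dtype=torch.int32, device=device)

start = time.perf_counter()
ids_bulk = sampled.tolist()  # single conversion (and single sync on CUDA)
bulk_ms = (time.perf_counter() - start) * 1e3

cpu_tensor = sampled.cpu()
start = time.perf_counter()
ids_loop = [cpu_tensor[i].item() for i in range(num_reqs)]  # per-element calls
loop_ms = (time.perf_counter() - start) * 1e3

assert ids_bulk == ids_loop
print(f".tolist(): {bulk_ms:.3f} ms, per-element .item(): {loop_ms:.3f} ms")
```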

# Update with the actual token ids
for i, req_state, seq_len in request_seq_lens:
token_id = sampled_token_ids[i]
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids[-1] = token_id

if sampler_output.logprob_token_ids is None:
logprob_token_ids = None
else:
@@ -808,12 +823,6 @@
else:
logprobs = sampler_output.logprobs.cpu()

# num_reqs entries should be non-None
assert all(
req_id is not None for req_id in
self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])

model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
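Putting the pieces of the diff together, the pattern is: do the per-request Python bookkeeping with a placeholder token before the sync, pay a single .tolist() sync, then patch in the real token ids. Below is a minimal restatement of that pattern with simplified, made-up names (not the actual execute_model code).

```python
from typing import List, Tuple

import torch

def finalize_sampled_tokens(sampled_gpu: torch.Tensor,
                            output_token_ids: List[List[int]]) -> None:
    """Sketch of the defer-the-sync pattern, using simplified structures."""
    pending: List[Tuple[int, List[int]]] = []
    for i, out_ids in enumerate(output_token_ids):
        out_ids.append(0)  # placeholder; the real id is filled in after the sync
        pending.append((i, out_ids))

    token_ids = sampled_gpu.tolist()  # the single GPU -> CPU sync point
    for i, out_ids in pending:
        out_ids[-1] = token_ids[i]
```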