[V1][Perf] Reduce scheduling overhead in model runner after cuda sync #12094
Changes from all commits: ff21f9e, 41dba06, 9ce3d6e, 8ca382d, 1cc6492, f35e80b
```diff
@@ -775,10 +775,10 @@ def execute_model(
             sampling_metadata=sampling_metadata,
         )
 
-        sampled_token_ids = sampler_output.sampled_token_ids
         # TODO(woosuk): The following loop can be slow since it iterates over
         # the requests one by one. Optimize.
         num_reqs = self.input_batch.num_reqs
+        request_seq_lens: List[Tuple[int, CachedRequestState, int]] = []
         for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
             assert req_id is not None
             req_state = self.requests[req_id]
```
```diff
@@ -787,10 +787,10 @@ def execute_model(
             assert seq_len <= req_state.num_tokens
             if seq_len == req_state.num_tokens:
                 # Append the sampled token to the output token ids.
-                token_id = sampled_token_ids[i]
-                self.input_batch.token_ids_cpu[i, seq_len] = token_id
                 self.input_batch.num_tokens[i] += 1
-                req_state.output_token_ids.append(token_id)
+                # OPTIMIZATION: Priming the state updates for later updates.
+                req_state.output_token_ids.append(0)
+                request_seq_lens.append((i, req_state, seq_len))
             else:
                 # Ignore the sampled token from the partial request.
                 # Rewind the generator state as if the token was not sampled.
```
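The hunk above is the core of the optimization: instead of reading each sampled token id from the GPU inside the per-request loop, the loop now appends a placeholder `0` and records `(index, request state, seq_len)` in `request_seq_lens`; the real token ids are patched in after a single GPU -> CPU sync (see the next hunk). A minimal standalone sketch of this prime-then-patch pattern, using a hypothetical `Req` stand-in rather than vLLM's actual `CachedRequestState`:

```python
# Minimal sketch of the "prime now, patch after the sync" pattern.
# `Req` is a hypothetical stand-in, not vLLM's CachedRequestState.
from dataclasses import dataclass, field
from typing import List, Tuple

import torch


@dataclass
class Req:
    output_token_ids: List[int] = field(default_factory=list)


def finish_step(reqs: List[Req], sampled: torch.Tensor) -> None:
    # Phase 1: do all CPU-side bookkeeping without touching GPU values.
    pending: List[Tuple[int, Req]] = []
    for i, req in enumerate(reqs):
        req.output_token_ids.append(0)  # placeholder, patched below
        pending.append((i, req))

    # Phase 2: a single device -> host transfer for the whole batch.
    token_ids = sampled.tolist()

    # Phase 3: patch the placeholders with the real sampled token ids.
    for i, req in pending:
        req.output_token_ids[-1] = token_ids[i]


reqs = [Req(), Req(), Req()]
finish_step(reqs, torch.tensor([11, 22, 33]))
print([r.output_token_ids for r in reqs])  # [[11], [22], [33]]
```

The sketch leaves out the partial-request branch and the `token_ids_cpu` buffer; it only shows why the loop can run entirely on the CPU before the sync point.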
```diff
@@ -799,6 +799,21 @@ def execute_model(
                     # This relies on cuda-specific torch-internal impl details
                     generator.set_offset(generator.get_offset() - 4)
 
+        # num_reqs entries should be non-None
+        assert all(
+            req_id is not None for req_id in
+            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
+        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
+
+        # NOTE: GPU -> CPU Sync happens here.
+        # Move as many CPU operations as possible before this sync point.
+        sampled_token_ids = sampler_output.sampled_token_ids.tolist()
+        # Update with the actual token ids
+        for i, req_state, seq_len in request_seq_lens:
+            token_id = sampled_token_ids[i]
+            self.input_batch.token_ids_cpu[i, seq_len] = token_id
+            req_state.output_token_ids[-1] = token_id
+
         if sampler_output.logprob_token_ids is None:
             logprob_token_ids = None
         else:
```

Review comments on the `# NOTE: GPU -> CPU Sync happens here.` line:

- Just for the record: if top-p or top-k sampling is used (with the FlashInfer kernel), CPU-GPU synchronization happens inside the sampler, at `vllm/v1/sample/ops/topk_topp_sampler.py`, lines 193 to 194 (at 324960a).
- Reply: Do you think we can avoid this in a follow-up PR?
- Reply: I don't think we can. This is a fundamental limitation of the kernel (or the algorithm itself). The rejection sampling method cannot 100% guarantee success.

Review comments on `sampled_token_ids = sampler_output.sampled_token_ids.tolist()`:

- It might be faster to do `sampler_output.sampled_token_ids.cpu()`.
- Reply: In my experience, …
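For context on that question, a rough micro-benchmark along these lines can compare the two transfer paths; it is illustrative only and not from the PR, and the numbers depend on batch size and on whatever work is already queued on the CUDA stream:

```python
# Rough micro-benchmark: one small GPU int tensor per decoding step,
# transferred either via .tolist() or via .cpu() followed by .tolist().
import time

import torch


def bench(fn, tensor, iters: int = 100) -> float:
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn(tensor)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters


if torch.cuda.is_available():
    sampled = torch.randint(0, 32_000, (256,), device="cuda")
    print("tolist():       ", bench(lambda t: t.tolist(), sampled))
    print(".cpu().tolist():", bench(lambda t: t.cpu().tolist(), sampled))
```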
```diff
@@ -808,12 +823,6 @@ def execute_model(
         else:
             logprobs = sampler_output.logprobs.cpu()
 
-        # num_reqs entries should be non-None
-        assert all(
-            req_id is not None for req_id in
-            self.input_batch.req_ids[:num_reqs]), "req_ids contains None"
-        req_ids = cast(List[str], self.input_batch.req_ids[:num_reqs])
-
         model_runner_output = ModelRunnerOutput(
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
```
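This last hunk removes the `req_ids` assert/`cast` block from its old position after the logprobs transfer, since the previous hunk moved it above the sync point. The rationale, per the `# Move as many CPU operations as possible before this sync point.` comment, is that CPU-side work done before the blocking device-to-host copy overlaps with kernels still running on the GPU instead of extending the critical path afterwards. A standalone sketch of that effect (not vLLM code; timings will vary by hardware):

```python
# Compare doing CPU-side bookkeeping before vs. after the GPU -> CPU sync.
import time

import torch


def run(cpu_work_before_sync: bool) -> float:
    if not torch.cuda.is_available():
        return float("nan")  # illustration only; requires a GPU
    a = torch.randn(4096, 4096, device="cuda")
    _ = a @ a                      # warm-up (cuBLAS init, etc.)
    torch.cuda.synchronize()

    start = time.perf_counter()
    out = a @ a
    for _ in range(8):             # queue several kernels asynchronously
        out = out @ a
    if cpu_work_before_sync:
        _ = [i * 2 for i in range(200_000)]   # stand-in for CPU bookkeeping
    _ = out[0, :8].tolist()        # forces the GPU -> CPU sync
    if not cpu_work_before_sync:
        _ = [i * 2 for i in range(200_000)]
    torch.cuda.synchronize()
    return time.perf_counter() - start


print("CPU work before sync:", run(True))
print("CPU work after sync: ", run(False))
```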
Review discussion on the `List[int]` conversion:

- Is this necessary? IIRC, @tlrmchlsmth used `List[int]` because it is cheaper to serialize, which benefits the tensor-parallel case, where we need to pass the output across processes.
- Reply: This is true; I didn't look at how it impacts the non-TP case though.
- Reply: The `ModelRunnerOutput` is what we serialize for TP; we don't serialize the `SamplerOutput` directly, so this is not a concern.
- Reply: Ah, yep, that's right. I did change this line in #9856, but that was just downstream of changing `sampled_token_ids` to a `List` in the `ModelRunnerOutput`. This looks good to me since that's left as-is!
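On the serialization point, an illustrative check only (vLLM's actual cross-process serialization path may differ; this just uses `pickle`): a plain `List[int]` can be handed to a generic serializer directly and needs no `torch` on the receiving side, whereas a `torch.Tensor` carries tensor/storage metadata and requires `torch` to deserialize.

```python
# Compare the pickled footprint of a List[int] vs. the equivalent tensor.
import pickle

import torch

sampled_token_ids_list = list(range(256))      # e.g. one token id per request
sampled_token_ids_tensor = torch.arange(256)   # same data as a tensor

print("list bytes:  ", len(pickle.dumps(sampled_token_ids_list)))
print("tensor bytes:", len(pickle.dumps(sampled_token_ids_tensor)))
```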