Merged
6 changes: 3 additions & 3 deletions vllm/v1/worker/gpu_model_runner.py
@@ -341,13 +341,13 @@ def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
         num_reqs = self.input_batch.num_reqs

-        pooling_params = self.input_batch.pooling_metadata.pooling_params
-
-        num_pooling_reqs = len(pooling_params)
+        num_pooling_reqs = len(self.input_batch.pooling_params)

         if num_pooling_reqs == 0:
             return model_kwargs

+        pooling_params = self.input_batch.pooling_metadata.pooling_params
+
         assert num_pooling_reqs == num_reqs
Comment on lines +344 to 351
Contributor

high

While the optimization is correct, this implementation introduces a potential for future bugs. num_pooling_reqs is derived from self.input_batch.pooling_params, but the pooling_params variable that is used in the rest of the function is derived from self.input_batch.pooling_metadata. If the logic inside the pooling_metadata property changes in the future (e.g., to filter some requests), len(pooling_params) might no longer be equal to the originally computed num_pooling_reqs, which could lead to subtle issues.

To make the code more robust, derive num_pooling_reqs from the pooling_params list after it has been created, so that the count and the list always come from the same source. The early-return check can then simply test whether self.input_batch.pooling_params is empty.

Suggested change
-        num_pooling_reqs = len(self.input_batch.pooling_params)
-        if num_pooling_reqs == 0:
-            return model_kwargs
-        pooling_params = self.input_batch.pooling_metadata.pooling_params
-        assert num_pooling_reqs == num_reqs
+        if not self.input_batch.pooling_params:
+            return model_kwargs
+        pooling_params = self.input_batch.pooling_metadata.pooling_params
+        num_pooling_reqs = len(pooling_params)
+        assert num_pooling_reqs == num_reqs
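
The ordering in the suggestion can be tried outside the runner with a minimal sketch. ToyInputBatch, ToyPoolingMetadata, metadata_builds, and collect_pooling_params are hypothetical stand-ins invented for illustration, not vLLM classes or functions. The sketch assumes that pooling_params on the batch is a dict and that the pooling_metadata property does its work on every access.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ToyPoolingMetadata:
    # Hypothetical stand-in; not vLLM's real pooling metadata class.
    pooling_params: list[Any]


@dataclass
class ToyInputBatch:
    # Hypothetical stand-in for the runner's input batch.
    num_reqs: int = 0
    pooling_params: dict[str, Any] = field(default_factory=dict)
    metadata_builds: int = 0  # counts accesses to the property below

    @property
    def pooling_metadata(self) -> ToyPoolingMetadata:
        # Assumed to be the costly step that the early return avoids.
        self.metadata_builds += 1
        return ToyPoolingMetadata(list(self.pooling_params.values()))


def collect_pooling_params(batch: ToyInputBatch) -> list[Any]:
    # Cheap emptiness check first: no metadata is built for
    # batches without pooling requests.
    if not batch.pooling_params:
        return []

    # Count the list that is actually used afterwards, so the assert
    # guards the same object the rest of the function works with.
    pooling_params = batch.pooling_metadata.pooling_params
    num_pooling_reqs = len(pooling_params)
    assert num_pooling_reqs == batch.num_reqs
    return pooling_params
```

With an empty batch, the function returns before the property is touched. With a populated batch, the property is accessed once and the count is taken from the resulting list, so a later change that filters requests inside the property would trip the assert instead of passing unnoticed.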


token_type_id_requests = dict[int, Any]()