Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions tests/v1/kv_connector/unit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ def create_model_runner_output(
sampled_token = EOS_TOKEN_ID if use_eos else 0
sampled_token_ids = [[sampled_token] for _ in req_ids]

kv_connector_output = None if (
finished_sending is None
and finished_recving is None) else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
)

# Make output data structure.
return ModelRunnerOutput(
req_ids=req_ids,
Expand All @@ -188,10 +195,7 @@ def create_model_runner_output(
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
kv_connector_output=KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
),
kv_connector_output=kv_connector_output,
)


Expand Down
4 changes: 2 additions & 2 deletions vllm/v1/core/sched/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1151,8 +1151,8 @@ def _update_from_kv_xfer_finished(self,
schedule the request during the next step.
"""

assert self.connector is not None
self.connector.update_connector_output(kv_connector_output)
if self.connector is not None:
self.connector.update_connector_output(kv_connector_output)

# KV Connector: update recv and send status from last step.
for req_id in (kv_connector_output.finished_recving or ()):
Expand Down
13 changes: 9 additions & 4 deletions vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1138,6 +1138,13 @@ def concat_lists(input_lists):
i, target_slice] = valid_sampled_token_ids[i]
req_state.output_token_ids.extend(valid_sampled_token_ids[i])

kv_connector_output = None if (
finished_sending is None
and finished_recving is None) else KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
)

model_runner_output = ModelRunnerOutput(
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
Expand All @@ -1146,10 +1153,8 @@ def concat_lists(input_lists):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
kv_connector_output=KVConnectorOutput(
finished_sending=finished_sending,
finished_recving=finished_recving,
))
kv_connector_output=kv_connector_output,
)

# Check there are no new graphs compiled - all the graphs should be
# captured and compiled during warm up.
Expand Down