File tree Expand file tree Collapse file tree 3 files changed +19
-10
lines changed
tests/v1/kv_connector/unit Expand file tree Collapse file tree 3 files changed +19
-10
lines changed Original file line number Diff line number Diff line change @@ -179,6 +179,13 @@ def create_model_runner_output(
179179 sampled_token = EOS_TOKEN_ID if use_eos else 0
180180 sampled_token_ids = [[sampled_token ] for _ in req_ids ]
181181
182+ kv_connector_output = None if (
183+ finished_sending is None
184+ and finished_recving is None ) else KVConnectorOutput (
185+ finished_sending = finished_sending ,
186+ finished_recving = finished_recving ,
187+ )
188+
182189 # Make output data structure.
183190 return ModelRunnerOutput (
184191 req_ids = req_ids ,
@@ -188,10 +195,7 @@ def create_model_runner_output(
188195 logprobs = None ,
189196 prompt_logprobs_dict = {},
190197 pooler_output = None ,
191- kv_connector_output = KVConnectorOutput (
192- finished_sending = finished_sending ,
193- finished_recving = finished_recving ,
194- ),
198+ kv_connector_output = kv_connector_output ,
195199 )
196200
197201
Original file line number Diff line number Diff line change @@ -1151,8 +1151,8 @@ def _update_from_kv_xfer_finished(self,
11511151 scheduler the request during the next step.
11521152 """
11531153
1154- assert self .connector is not None
1155- self .connector .update_connector_output (kv_connector_output )
1154+ if self .connector is not None :
1155+ self .connector .update_connector_output (kv_connector_output )
11561156
11571157 # KV Connector:: update recv and send status from last step.
11581158 for req_id in (kv_connector_output .finished_recving or ()):
Original file line number Diff line number Diff line change @@ -1138,6 +1138,13 @@ def concat_lists(input_lists):
11381138 i , target_slice ] = valid_sampled_token_ids [i ]
11391139 req_state .output_token_ids .extend (valid_sampled_token_ids [i ])
11401140
1141+ kv_connector_output = None if (
1142+ finished_sending is None
1143+ and finished_recving is None ) else KVConnectorOutput (
1144+ finished_sending = finished_sending ,
1145+ finished_recving = finished_recving ,
1146+ )
1147+
11411148 model_runner_output = ModelRunnerOutput (
11421149 req_ids = req_ids ,
11431150 req_id_to_index = self .input_batch .req_id_to_index ,
@@ -1146,10 +1153,8 @@ def concat_lists(input_lists):
11461153 logprobs = logprobs_lists ,
11471154 prompt_logprobs_dict = prompt_logprobs_dict ,
11481155 pooler_output = [],
1149- kv_connector_output = KVConnectorOutput (
1150- finished_sending = finished_sending ,
1151- finished_recving = finished_recving ,
1152- ))
1156+ kv_connector_output = kv_connector_output ,
1157+ )
11531158
11541159 # Check there are no new graphs compiled - all the graphs should be
11551160 # captured and compiled during warm up.
You can’t perform that action at this time.
0 commit comments