@@ -102,11 +102,12 @@ def has_unfinished_seqs(self) -> bool:
def get_num_unfinished_seq_groups(self) -> int:
    """Return how many sequence groups are still held by the scheduler.

    The count is the combined length of the waiting, running and swapped
    queues.
    """
    queues = (self.waiting, self.running, self.swapped)
    return sum(len(queue) for queue in queues)
104104
105- def _schedule (self ) -> Tuple [SchedulerOutputs , List [str ]]:
105+ def _schedule (self ) -> Tuple [SchedulerOutputs , List [str ], List [ SequenceGroup ] ]:
106106 # Blocks that need to be swapped or copied before model execution.
107107 blocks_to_swap_in : Dict [int , int ] = {}
108108 blocks_to_swap_out : Dict [int , int ] = {}
109109 blocks_to_copy : Dict [int , List [int ]] = {}
110+ ignored_seq_groups : List [SequenceGroup ] = []
110111
111112 # Fix the current time.
112113 now = time .time ()
@@ -187,12 +188,24 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
187188 # If the sequence group has been preempted in this step, stop.
188189 if seq_group in preempted :
189190 break
191+
192+ num_prompt_tokens = seq_group .get_seqs ()[0 ].get_len ()
193+ if num_prompt_tokens >= self .scheduler_config .max_seq_len :
194+ logger .warn (
195+ f"Input prompt ({ num_prompt_tokens } tokens) is too long"
196+ " and exceeds limit of "
197+ f"{ self .scheduler_config .max_seq_len } " )
198+ for seq in seq_group .get_seqs ():
199+ seq .status = SequenceStatus .FINISHED_IGNORED
200+ ignored_seq_groups .append (seq_group )
201+ self .waiting .pop (0 )
202+ break
203+
190204 # If the sequence group cannot be allocated, stop.
191205 if not self .block_manager .can_allocate (seq_group ):
192206 break
193207
194208 # If the number of batched tokens exceeds the limit, stop.
195- num_prompt_tokens = seq_group .get_seqs ()[0 ].get_len ()
196209 if (num_batched_tokens + num_prompt_tokens
197210 > self .scheduler_config .max_num_batched_tokens ):
198211 break
@@ -218,7 +231,7 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
218231 blocks_to_copy = blocks_to_copy ,
219232 )
220233 if not self .log_stats :
221- return scheduler_outputs , prompt_group_ids
234+ return scheduler_outputs , prompt_group_ids , ignored_seq_groups
222235
223236 # TODO(woosuk): Move the below code to the engine.
224237 now = time .time ()
@@ -258,13 +271,13 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
258271 f"Pending: { len (self .waiting )} reqs, "
259272 f"GPU KV cache usage: { gpu_cache_usage * 100 :.1f} %, "
260273 f"CPU KV cache usage: { cpu_cache_usage * 100 :.1f} %" )
261- return scheduler_outputs , prompt_group_ids
274+ return scheduler_outputs , prompt_group_ids , ignored_seq_groups
262275
263- def schedule (self ) -> Tuple [List [SequenceGroupMetadata ], SchedulerOutputs ]:
276+ def schedule (self ) -> Tuple [List [SequenceGroupMetadata ], SchedulerOutputs , List [ SequenceGroup ] ]:
264277 # Schedule sequence groups.
265278 # This function call changes the internal states of the scheduler
266279 # such as self.running, self.swapped, and self.waiting.
267- scheduler_outputs , prompt_group_ids = self ._schedule ()
280+ scheduler_outputs , prompt_group_ids , ignored_seq_groups = self ._schedule ()
268281
269282 # Create input data structures.
270283 seq_group_metadata_list : List [SequenceGroupMetadata ] = []
@@ -286,7 +299,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
286299 block_tables = block_tables ,
287300 )
288301 seq_group_metadata_list .append (seq_group_metadata )
289- return seq_group_metadata_list , scheduler_outputs
302+ return seq_group_metadata_list , scheduler_outputs , ignored_seq_groups
290303
291304 def update (
292305 self ,
0 commit comments