Skip to content

Commit dafd924

Browse files
Raise error for long prompt (#273)
1 parent 598dc4b commit dafd924

File tree

5 files changed

+42
-11
lines changed

5 files changed

+42
-11
lines changed

vllm/config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,14 +186,18 @@ class SchedulerConfig:
186186
a single iteration.
187187
max_num_seqs: Maximum number of sequences to be processed in a single
188188
iteration.
189+
max_seq_len: Maximum length of a sequence (including prompt
190+
and generated text).
189191
"""
190192
def __init__(
191193
self,
192194
max_num_batched_tokens: int,
193195
max_num_seqs: int,
196+
max_seq_len: int
194197
) -> None:
195198
self.max_num_batched_tokens = max_num_batched_tokens
196199
self.max_num_seqs = max_num_seqs
200+
self.max_seq_len = max_seq_len
197201

198202

199203
_STR_DTYPE_TO_TORCH_DTYPE = {

vllm/core/scheduler.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,12 @@ def has_unfinished_seqs(self) -> bool:
102102
def get_num_unfinished_seq_groups(self) -> int:
103103
return len(self.waiting) + len(self.running) + len(self.swapped)
104104

105-
def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
105+
def _schedule(self) -> Tuple[SchedulerOutputs, List[str], List[SequenceGroup]]:
106106
# Blocks that need to be swapped or copied before model execution.
107107
blocks_to_swap_in: Dict[int, int] = {}
108108
blocks_to_swap_out: Dict[int, int] = {}
109109
blocks_to_copy: Dict[int, List[int]] = {}
110+
ignored_seq_groups: List[SequenceGroup] = []
110111

111112
# Fix the current time.
112113
now = time.time()
@@ -187,12 +188,24 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
187188
# If the sequence group has been preempted in this step, stop.
188189
if seq_group in preempted:
189190
break
191+
192+
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
193+
if num_prompt_tokens >= self.scheduler_config.max_seq_len:
194+
logger.warn(
195+
f"Input prompt ({num_prompt_tokens} tokens) is too long"
196+
" and exceeds limit of "
197+
f"{self.scheduler_config.max_seq_len}")
198+
for seq in seq_group.get_seqs():
199+
seq.status = SequenceStatus.FINISHED_IGNORED
200+
ignored_seq_groups.append(seq_group)
201+
self.waiting.pop(0)
202+
break
203+
190204
# If the sequence group cannot be allocated, stop.
191205
if not self.block_manager.can_allocate(seq_group):
192206
break
193207

194208
# If the number of batched tokens exceeds the limit, stop.
195-
num_prompt_tokens = seq_group.get_seqs()[0].get_len()
196209
if (num_batched_tokens + num_prompt_tokens
197210
> self.scheduler_config.max_num_batched_tokens):
198211
break
@@ -218,7 +231,7 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
218231
blocks_to_copy=blocks_to_copy,
219232
)
220233
if not self.log_stats:
221-
return scheduler_outputs, prompt_group_ids
234+
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
222235

223236
# TODO(woosuk): Move the below code to the engine.
224237
now = time.time()
@@ -258,13 +271,13 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
258271
f"Pending: {len(self.waiting)} reqs, "
259272
f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, "
260273
f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%")
261-
return scheduler_outputs, prompt_group_ids
274+
return scheduler_outputs, prompt_group_ids, ignored_seq_groups
262275

263-
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
276+
def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, List[SequenceGroup]]:
264277
# Schedule sequence groups.
265278
# This function call changes the internal states of the scheduler
266279
# such as self.running, self.swapped, and self.waiting.
267-
scheduler_outputs, prompt_group_ids = self._schedule()
280+
scheduler_outputs, prompt_group_ids, ignored_seq_groups = self._schedule()
268281

269282
# Create input data structures.
270283
seq_group_metadata_list: List[SequenceGroupMetadata] = []
@@ -286,7 +299,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
286299
block_tables=block_tables,
287300
)
288301
seq_group_metadata_list.append(seq_group_metadata)
289-
return seq_group_metadata_list, scheduler_outputs
302+
return seq_group_metadata_list, scheduler_outputs, ignored_seq_groups
290303

291304
def update(
292305
self,

vllm/engine/arg_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,12 @@ def create_engine_configs(
123123
parallel_config = ParallelConfig(self.pipeline_parallel_size,
124124
self.tensor_parallel_size,
125125
self.worker_use_ray)
126+
max_seq_len = min(
127+
self.max_num_batched_tokens,
128+
getattr(model_config.hf_config, "max_position_embeddings",
129+
float("inf")))
126130
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
127-
self.max_num_seqs)
131+
self.max_num_seqs, max_seq_len)
128132
return model_config, cache_config, parallel_config, scheduler_config
129133

130134

vllm/engine/llm_engine.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,8 @@ def step(self) -> List[RequestOutput]:
226226
and updates the scheduler with the model outputs. Finally, it decodes
227227
the sequences and returns the newly generated results.
228228
"""
229-
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
230-
if (not seq_group_metadata_list) and scheduler_outputs.is_empty():
229+
seq_group_metadata_list, scheduler_outputs, ignored_seq_groups = self.scheduler.schedule()
230+
if (not seq_group_metadata_list) and scheduler_outputs.is_empty() and (not ignored_seq_groups):
231231
# Nothing to do.
232232
return []
233233

@@ -251,7 +251,7 @@ def step(self) -> List[RequestOutput]:
251251

252252
# Create the outputs.
253253
request_outputs: List[RequestOutput] = []
254-
for seq_group in seq_groups:
254+
for seq_group in seq_groups + ignored_seq_groups:
255255
request_output = RequestOutput.from_seq_group(seq_group)
256256
request_outputs.append(request_output)
257257
return request_outputs
@@ -288,6 +288,12 @@ def _stop_sequences(self, seq_groups: List[SequenceGroup]) -> None:
288288
if stopped:
289289
continue
290290

291+
# Check if the sequence has reached max_seq_len.
292+
if (seq.get_len() >=
293+
self.scheduler.scheduler_config.max_seq_len):
294+
self.scheduler.free_seq(
295+
seq, SequenceStatus.FINISHED_LENGTH_CAPPED)
296+
continue
291297
# Check if the sequence has reached max_tokens.
292298
if seq.get_output_len() == sampling_params.max_tokens:
293299
self.scheduler.free_seq(

vllm/sequence.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,15 @@ class SequenceStatus(enum.Enum):
1313
FINISHED_STOPPED = enum.auto()
1414
FINISHED_LENGTH_CAPPED = enum.auto()
1515
FINISHED_ABORTED = enum.auto()
16+
FINISHED_IGNORED = enum.auto()
1617

1718
@staticmethod
1819
def is_finished(status: "SequenceStatus") -> bool:
1920
return status in [
2021
SequenceStatus.FINISHED_STOPPED,
2122
SequenceStatus.FINISHED_LENGTH_CAPPED,
2223
SequenceStatus.FINISHED_ABORTED,
24+
SequenceStatus.FINISHED_IGNORED
2325
]
2426

2527
@staticmethod
@@ -30,6 +32,8 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
3032
finish_reason = "length"
3133
elif status == SequenceStatus.FINISHED_ABORTED:
3234
finish_reason = "abort"
35+
elif status == SequenceStatus.FINISHED_IGNORED:
36+
finish_reason = "length"
3337
else:
3438
finish_reason = None
3539
return finish_reason

0 commit comments

Comments
 (0)