
Commit

fixed bugs, server crash on empty prompt
CNTRYROA committed Sep 20, 2024
1 parent e0660f7 commit 93b8676
Showing 1 changed file with 105 additions and 97 deletions.
202 changes: 105 additions & 97 deletions vllm/engine/llm_engine.py
@@ -115,9 +115,9 @@ def enable_output_validation(cls):

@classmethod
def validate_output(
cls,
output: object,
output_type: Type[_O],
) -> _O:
do_validate = cls.DO_VALIDATE_OUTPUT

@@ -130,9 +130,9 @@ def validate_output(

@classmethod
def validate_outputs(
cls,
outputs: GenericSequence[object],
output_type: Type[_O],
) -> List[_O]:
do_validate = cls.DO_VALIDATE_OUTPUT

@@ -153,23 +153,23 @@ def validate_outputs(
tokenizer: Optional[BaseTokenizerGroup]

def __init__(
self,
model_config: ModelConfig,
cache_config: CacheConfig,
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
device_config: DeviceConfig,
load_config: LoadConfig,
lora_config: Optional[LoRAConfig],
multimodal_config: Optional[MultiModalConfig],
speculative_config: Optional[SpeculativeConfig],
decoding_config: Optional[DecodingConfig],
observability_config: Optional[ObservabilityConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
executor_class: Type[ExecutorBase],
log_stats: bool,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> None:
logger.info(
"Initializing an LLM engine (v%s) with config: "
@@ -272,31 +272,31 @@ def __init__(
extra_kvs={
# Common configuration
"dtype":
str(model_config.dtype),
str(model_config.dtype),
"tensor_parallel_size":
parallel_config.tensor_parallel_size,
parallel_config.tensor_parallel_size,
"block_size":
cache_config.block_size,
cache_config.block_size,
"gpu_memory_utilization":
cache_config.gpu_memory_utilization,
cache_config.gpu_memory_utilization,

# Quantization
"quantization":
model_config.quantization,
model_config.quantization,
"kv_cache_dtype":
str(cache_config.cache_dtype),
str(cache_config.cache_dtype),

# Feature flags
"enable_lora":
bool(lora_config),
bool(lora_config),
"enable_prompt_adapter":
bool(prompt_adapter_config),
bool(prompt_adapter_config),
"enable_prefix_caching":
cache_config.enable_prefix_caching,
cache_config.enable_prefix_caching,
"enforce_eager":
model_config.enforce_eager,
model_config.enforce_eager,
"disable_custom_all_reduce":
parallel_config.disable_custom_all_reduce,
parallel_config.disable_custom_all_reduce,
})

if self.tokenizer:
@@ -320,13 +320,13 @@ def __init__(
else:
self.stat_loggers = {
"logging":
LoggingStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
"prometheus":
PrometheusStatLogger(
local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
labels=dict(model_name=model_config.served_model_name),
max_model_len=self.model_config.max_model_len),
}
self.stat_loggers["prometheus"].info("cache_config",
self.cache_config)
@@ -432,10 +432,10 @@ def _get_executor_cls(cls,

@classmethod
def from_engine_args(
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
@@ -475,8 +475,8 @@ def get_tokenizer_group(
return self.tokenizer

def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(lora_request)

@@ -512,15 +512,16 @@ def _get_eos_token_id(
return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id

def _add_processed_request(
self,
request_id: str,
processed_inputs: LLMInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
) -> None:
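# Reject empty or missing prompts before any sequences are created
# (see _validate_model_inputs at the end of this diff).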
self._validate_model_inputs(processed_inputs)
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
Expand Down Expand Up @@ -563,11 +564,11 @@ def stop_remote_worker_execution_loop(self) -> None:
self.model_executor.stop_remote_worker_execution_loop()

def process_model_inputs(
self,
request_id: str,
inputs: PromptInputs,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> LLMInputs:
if isinstance(inputs, str):
inputs = {"prompt": inputs}
@@ -584,8 +585,8 @@ def process_model_inputs(

if prompt_adapter_request:
prompt_token_ids = \
[0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens \
+ prompt_token_ids

llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
prompt=inputs.get("prompt"),
@@ -594,14 +595,14 @@ def process_model_inputs(
return self.input_processor(llm_inputs)

def add_request(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Add a request to the engine's request pool.
@@ -668,21 +669,21 @@ def add_request(
)

def _create_sequence_group_with_sampling(
self,
request_id: str,
seq: Sequence,
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
if (sampling_params.logprobs
and sampling_params.logprobs > max_logprobs) or (
sampling_params.prompt_logprobs
and sampling_params.prompt_logprobs > max_logprobs):
raise ValueError(f"Cannot request more than "
f"{max_logprobs} logprobs.")

@@ -706,13 +707,13 @@ def _create_sequence_group_with_sampling(
return seq_group

def _create_sequence_group_with_pooling(
self,
request_id: str,
seq: Sequence,
pooling_params: PoolingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> SequenceGroup:
"""Creates a SequenceGroup with PoolingParams."""
# Defensive copy of PoolingParams, which are used by the pooler
@@ -785,9 +786,9 @@ def has_unfinished_requests_for_virtual_engine(
return self.scheduler[virtual_engine].has_unfinished_seqs()

def _process_sequence_group_outputs(
self,
seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput],
) -> None:
seq_group.embeddings = outputs[0].embeddings

@@ -797,11 +798,11 @@ def _process_sequence_group_outputs(
return

def _process_model_outputs(
self,
output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
scheduled_seq_groups: List[ScheduledSequenceGroup],
ignored_seq_groups: List[SequenceGroup],
seq_group_metadata_list: List[SequenceGroupMetadata],
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Apply the model output to the sequences in the scheduled seq groups.
Expand Down Expand Up @@ -1088,8 +1089,8 @@ def _get_stats(
# + num_generation_tokens_from_prefill_groups (since we generate
# one token on prefills on iters where the prefill finishes).
num_generation_tokens_iter = (
scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
num_generation_tokens_from_prefill_groups)

# Spec decode, if enabled, emits specialized metrics from the worker in
# sampler output.
@@ -1214,3 +1215,10 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
seq_span.set_attribute(
SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)

def _validate_model_inputs(self, inputs: Union[LLMInputs]):

prompt_ids = inputs.get("prompt_token_ids")

if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")
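This new guard is the substance of the fix: _add_processed_request now calls _validate_model_inputs before any sequences are created, so a request whose tokenized prompt is missing or empty is rejected with a ValueError instead of propagating into the engine and crashing the server. The standalone sketch below mirrors that check outside the engine; validate_model_inputs and the plain dicts are illustrative stand-ins for the engine method and its LLMInputs argument, not part of vLLM's public API.

from typing import Mapping, Optional, Sequence

def validate_model_inputs(inputs: Mapping[str, Optional[Sequence[int]]]) -> None:
    # Same condition as the commit: a missing or zero-length token list is invalid.
    prompt_ids = inputs.get("prompt_token_ids")
    if prompt_ids is None or len(prompt_ids) == 0:
        raise ValueError("Prompt cannot be empty")

# Hypothetical usage: an empty prompt now fails fast at request-admission time.
validate_model_inputs({"prompt_token_ids": [1, 2, 3]})  # accepted
try:
    validate_model_inputs({"prompt_token_ids": []})      # rejected
except ValueError as exc:
    print(exc)  # Prompt cannot be empty

Inside the engine the same ValueError surfaces from add_request, which a serving layer can catch and turn into a client error response rather than letting the bad request take the process down.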
