diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 3747f93..28d6940 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -115,9 +115,9 @@ def enable_output_validation(cls):
 
     @classmethod
     def validate_output(
-            cls,
-            output: object,
-            output_type: Type[_O],
+        cls,
+        output: object,
+        output_type: Type[_O],
     ) -> _O:
         do_validate = cls.DO_VALIDATE_OUTPUT
 
@@ -130,9 +130,9 @@ def validate_output(
 
     @classmethod
     def validate_outputs(
-            cls,
-            outputs: GenericSequence[object],
-            output_type: Type[_O],
+        cls,
+        outputs: GenericSequence[object],
+        output_type: Type[_O],
     ) -> List[_O]:
         do_validate = cls.DO_VALIDATE_OUTPUT
 
@@ -153,23 +153,23 @@ def validate_outputs(
     tokenizer: Optional[BaseTokenizerGroup]
 
     def __init__(
-            self,
-            model_config: ModelConfig,
-            cache_config: CacheConfig,
-            parallel_config: ParallelConfig,
-            scheduler_config: SchedulerConfig,
-            device_config: DeviceConfig,
-            load_config: LoadConfig,
-            lora_config: Optional[LoRAConfig],
-            multimodal_config: Optional[MultiModalConfig],
-            speculative_config: Optional[SpeculativeConfig],
-            decoding_config: Optional[DecodingConfig],
-            observability_config: Optional[ObservabilityConfig],
-            prompt_adapter_config: Optional[PromptAdapterConfig],
-            executor_class: Type[ExecutorBase],
-            log_stats: bool,
-            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-            stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+        self,
+        model_config: ModelConfig,
+        cache_config: CacheConfig,
+        parallel_config: ParallelConfig,
+        scheduler_config: SchedulerConfig,
+        device_config: DeviceConfig,
+        load_config: LoadConfig,
+        lora_config: Optional[LoRAConfig],
+        multimodal_config: Optional[MultiModalConfig],
+        speculative_config: Optional[SpeculativeConfig],
+        decoding_config: Optional[DecodingConfig],
+        observability_config: Optional[ObservabilityConfig],
+        prompt_adapter_config: Optional[PromptAdapterConfig],
+        executor_class: Type[ExecutorBase],
+        log_stats: bool,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> None:
         logger.info(
             "Initializing an LLM engine (v%s) with config: "
@@ -272,31 +272,31 @@ def __init__(
                 extra_kvs={
                     # Common configuration
                     "dtype":
-                        str(model_config.dtype),
+                    str(model_config.dtype),
                     "tensor_parallel_size":
-                        parallel_config.tensor_parallel_size,
+                    parallel_config.tensor_parallel_size,
                     "block_size":
-                        cache_config.block_size,
+                    cache_config.block_size,
                     "gpu_memory_utilization":
-                        cache_config.gpu_memory_utilization,
+                    cache_config.gpu_memory_utilization,
                     # Quantization
                     "quantization":
-                        model_config.quantization,
+                    model_config.quantization,
                     "kv_cache_dtype":
-                        str(cache_config.cache_dtype),
+                    str(cache_config.cache_dtype),
                     # Feature flags
                     "enable_lora":
-                        bool(lora_config),
+                    bool(lora_config),
                     "enable_prompt_adapter":
-                        bool(prompt_adapter_config),
+                    bool(prompt_adapter_config),
                     "enable_prefix_caching":
-                        cache_config.enable_prefix_caching,
+                    cache_config.enable_prefix_caching,
                     "enforce_eager":
-                        model_config.enforce_eager,
+                    model_config.enforce_eager,
                     "disable_custom_all_reduce":
-                        parallel_config.disable_custom_all_reduce,
+                    parallel_config.disable_custom_all_reduce,
                 })
 
         if self.tokenizer:
@@ -320,13 +320,13 @@ def __init__(
         else:
             self.stat_loggers = {
                 "logging":
-                    LoggingStatLogger(
-                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
+                LoggingStatLogger(
+                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC),
                 "prometheus":
-                    PrometheusStatLogger(
-                        local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
-                        labels=dict(model_name=model_config.served_model_name),
-                        max_model_len=self.model_config.max_model_len),
+                PrometheusStatLogger(
+                    local_interval=_LOCAL_LOGGING_INTERVAL_SEC,
+                    labels=dict(model_name=model_config.served_model_name),
+                    max_model_len=self.model_config.max_model_len),
             }
             self.stat_loggers["prometheus"].info("cache_config",
                                                  self.cache_config)
@@ -432,10 +432,10 @@ def _get_executor_cls(cls,
 
     @classmethod
     def from_engine_args(
-            cls,
-            engine_args: EngineArgs,
-            usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
-            stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
+        cls,
+        engine_args: EngineArgs,
+        usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
+        stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
     ) -> "LLMEngine":
         """Creates an LLM engine from the engine arguments."""
         # Create the engine configs.
@@ -475,8 +475,8 @@ def get_tokenizer_group(
         return self.tokenizer
 
     def get_tokenizer(
-            self,
-            lora_request: Optional[LoRARequest] = None,
+        self,
+        lora_request: Optional[LoRARequest] = None,
     ) -> AnyTokenizer:
         return self.get_tokenizer_group().get_lora_tokenizer(lora_request)
 
@@ -512,15 +512,16 @@ def _get_eos_token_id(
         return self.tokenizer.get_lora_tokenizer(lora_request).eos_token_id
 
     def _add_processed_request(
-            self,
-            request_id: str,
-            processed_inputs: LLMInputs,
-            params: Union[SamplingParams, PoolingParams],
-            arrival_time: float,
-            lora_request: Optional[LoRARequest],
-            prompt_adapter_request: Optional[PromptAdapterRequest],
-            trace_headers: Optional[Mapping[str, str]] = None,
+        self,
+        request_id: str,
+        processed_inputs: LLMInputs,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> None:
+        self._validate_model_inputs(processed_inputs)
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
@@ -563,11 +564,11 @@ def stop_remote_worker_execution_loop(self) -> None:
         self.model_executor.stop_remote_worker_execution_loop()
 
     def process_model_inputs(
-            self,
-            request_id: str,
-            inputs: PromptInputs,
-            lora_request: Optional[LoRARequest] = None,
-            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        self,
+        request_id: str,
+        inputs: PromptInputs,
+        lora_request: Optional[LoRARequest] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> LLMInputs:
         if isinstance(inputs, str):
             inputs = {"prompt": inputs}
@@ -584,8 +585,8 @@ def process_model_inputs(
 
         if prompt_adapter_request:
             prompt_token_ids = \
-                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\
-                    + prompt_token_ids
+                [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens \
+                + prompt_token_ids
 
         llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids,
                                prompt=inputs.get("prompt"),
@@ -594,14 +595,14 @@ def process_model_inputs(
         return self.input_processor(llm_inputs)
 
     def add_request(
-            self,
-            request_id: str,
-            inputs: PromptInputs,
-            params: Union[SamplingParams, PoolingParams],
-            arrival_time: Optional[float] = None,
-            lora_request: Optional[LoRARequest] = None,
-            trace_headers: Optional[Mapping[str, str]] = None,
-            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        self,
+        request_id: str,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         """Add a request to the engine's request pool.
 
@@ -668,21 +669,21 @@ def add_request(
         )
 
     def _create_sequence_group_with_sampling(
-            self,
-            request_id: str,
-            seq: Sequence,
-            sampling_params: SamplingParams,
-            arrival_time: float,
-            lora_request: Optional[LoRARequest],
-            trace_headers: Optional[Mapping[str, str]] = None,
-            prompt_adapter_request: Optional[PromptAdapterRequest] = None,
+        self,
+        request_id: str,
+        seq: Sequence,
+        sampling_params: SamplingParams,
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""
         max_logprobs = self.get_model_config().max_logprobs
         if (sampling_params.logprobs
-                    and sampling_params.logprobs > max_logprobs) or (
-                        sampling_params.prompt_logprobs
-                        and sampling_params.prompt_logprobs > max_logprobs):
+                and sampling_params.logprobs > max_logprobs) or (
+                    sampling_params.prompt_logprobs
+                    and sampling_params.prompt_logprobs > max_logprobs):
             raise ValueError(f"Cannot request more than "
                              f"{max_logprobs} logprobs.")
 
@@ -706,13 +707,13 @@ def _create_sequence_group_with_sampling(
         return seq_group
 
     def _create_sequence_group_with_pooling(
-            self,
-            request_id: str,
-            seq: Sequence,
-            pooling_params: PoolingParams,
-            arrival_time: float,
-            lora_request: Optional[LoRARequest],
-            prompt_adapter_request: Optional[PromptAdapterRequest],
+        self,
+        request_id: str,
+        seq: Sequence,
+        pooling_params: PoolingParams,
+        arrival_time: float,
+        lora_request: Optional[LoRARequest],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
     ) -> SequenceGroup:
         """Creates a SequenceGroup with PoolingParams."""
         # Defensive copy of PoolingParams, which are used by the pooler
@@ -785,9 +786,9 @@ def has_unfinished_requests_for_virtual_engine(
         return self.scheduler[virtual_engine].has_unfinished_seqs()
 
     def _process_sequence_group_outputs(
-            self,
-            seq_group: SequenceGroup,
-            outputs: List[EmbeddingSequenceGroupOutput],
+        self,
+        seq_group: SequenceGroup,
+        outputs: List[EmbeddingSequenceGroupOutput],
     ) -> None:
         seq_group.embeddings = outputs[0].embeddings
 
@@ -797,11 +798,11 @@ def _process_sequence_group_outputs(
         return
 
     def _process_model_outputs(
-            self,
-            output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
-            scheduled_seq_groups: List[ScheduledSequenceGroup],
-            ignored_seq_groups: List[SequenceGroup],
-            seq_group_metadata_list: List[SequenceGroupMetadata],
+        self,
+        output: GenericSequence[Union[SamplerOutput, PoolerOutput]],
+        scheduled_seq_groups: List[ScheduledSequenceGroup],
+        ignored_seq_groups: List[SequenceGroup],
+        seq_group_metadata_list: List[SequenceGroupMetadata],
     ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Apply the model output to the sequences in the scheduled seq groups.
 
@@ -1088,8 +1089,8 @@ def _get_stats(
             # + num_generation_tokens_from_prefill_groups (since we generate
             # one token on prefills on iters where the prefill finishes).
             num_generation_tokens_iter = (
-                    scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
-                    num_generation_tokens_from_prefill_groups)
+                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
+                num_generation_tokens_from_prefill_groups)
 
             # Spec decode, if enabled, emits specialized metrics from the worker in
             # sampler output.
@@ -1214,3 +1215,10 @@ def create_trace_span(self, seq_group: SequenceGroup) -> None:
                 seq_span.set_attribute(
                     SpanAttributes.LLM_LATENCY_TIME_TO_FIRST_TOKEN, ttft)
             seq_span.set_attribute(SpanAttributes.LLM_LATENCY_E2E, e2e_time)
+
+    def _validate_model_inputs(self, inputs: Union[LLMInputs]):
+
+        prompt_ids = inputs.get("prompt_token_ids")
+
+        if prompt_ids is None or len(prompt_ids) == 0:
+            raise ValueError("Prompt cannot be empty")
\ No newline at end of file
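
The only functional change in this patch is the new _validate_model_inputs hook, which _add_processed_request now calls before any sequences are created, so a request whose processed inputs contain no prompt tokens fails fast with ValueError("Prompt cannot be empty") instead of failing deeper inside the engine; the remaining hunks are whitespace-only reformatting. The snippet below is an illustrative sketch of that behaviour, not part of the patch: it assumes a vllm build that includes this change, uses facebook/opt-125m only as an arbitrary small example model, and requires pytest. A plain empty string is not a reliable trigger, since most tokenizers still emit a BOS token for "", so the token-ID prompt form is used.

# Illustrative only, not part of the diff: exercises the new
# _validate_model_inputs() check added above. Model name and max_tokens
# are arbitrary example choices.
import pytest

from vllm import LLM, SamplingParams


def test_empty_prompt_is_rejected():
    llm = LLM(model="facebook/opt-125m")
    # A tokens-style prompt dict with an empty token list reaches
    # _add_processed_request(), where the new validation call raises early.
    with pytest.raises(ValueError, match="Prompt cannot be empty"):
        llm.generate({"prompt_token_ids": []}, SamplingParams(max_tokens=1))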