diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index ca788d37c8f40..a7537c8761481 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1,7 +1,7 @@
 import time
 from typing import Iterable, List, Optional, Type, Union

-from transformers import PreTrainedTokenizer
+from transformers import GenerationConfig, PreTrainedTokenizer

 import vllm
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
@@ -34,6 +34,17 @@
 _LOCAL_LOGGING_INTERVAL_SEC = 5


+def _load_generation_config_dict(model_config: ModelConfig):
+    try:
+        return GenerationConfig.from_pretrained(
+            model_config.model,
+            revision=model_config.revision,
+        ).to_diff_dict()
+    except OSError:
+        # Not found.
+        return {}
+
+
 class LLMEngine:
     """An LLM engine that receives requests and generates texts.

@@ -126,6 +137,8 @@ def __init__(
         self._init_tokenizer()
         self.detokenizer = Detokenizer(self.tokenizer)
         self.seq_counter = Counter()
+        self.generation_config_fields = _load_generation_config_dict(
+            model_config)

         self.model_executor = executor_class(
             model_config=model_config,
@@ -393,6 +406,8 @@ def add_request(
         # inject the eos token id into the sampling_params to support min_tokens
         # processing
         sampling_params.eos_token_id = seq.eos_token_id
+        sampling_params.update_from_generation_config(
+            self.generation_config_fields)

         # Create the sequence group.
         seq_group = SequenceGroup(request_id, [seq], sampling_params,
@@ -437,7 +452,7 @@ def _process_model_outputs(
             scheduled_seq_groups: List[SequenceGroup],
             ignored_seq_groups: List[SequenceGroup]) -> List[RequestOutput]:
         """Apply the model output to the sequences in the scheduled seq groups.
-        
+
         Returns RequestOutputs that can be returned to the client.
         """

diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index 53a38b25bfdac..dc0e60344d858 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -2,7 +2,7 @@
 import copy
 from enum import IntEnum
 from functools import cached_property
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from pydantic import Field
@@ -271,6 +271,18 @@ def _verify_greedy_sampling(self) -> None:
             raise ValueError("best_of must be 1 when using greedy sampling."
                              f"Got {self.best_of}.")

+    def update_from_generation_config(
+            self, generation_config: Dict[str, Any]) -> None:
+        """Update if there are non-default values from generation_config"""
+        # Update eos_token_id for generation
+        if eos_ids := generation_config.get("eos_token_id"):
+            # it can be either int or list of int
+            if isinstance(eos_ids, int):
+                eos_ids = [eos_ids]
+            original_stop_token_ids = set(self.stop_token_ids)
+            original_stop_token_ids.update(eos_ids)
+            self.stop_token_ids = list(original_stop_token_ids)
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.use_beam_search:
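
As a rough usage sketch (not part of the diff itself, and assuming a vllm build that includes this change), the snippet below shows how eos_token_id values taken from a model's generation_config are merged into SamplingParams.stop_token_ids; the token ids are made-up illustrative values.

# Sketch: merge eos_token_id entries from a generation_config dict into
# stop_token_ids. Token ids below are illustrative, not from a real model.
from vllm import SamplingParams

params = SamplingParams(stop_token_ids=[32000])

# The dict shape mirrors GenerationConfig.to_diff_dict(); eos_token_id may be
# a single int or a list of ints.
params.update_from_generation_config({"eos_token_id": [2, 32007]})

print(sorted(params.stop_token_ids))  # [2, 32000, 32007]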