From 86d5e5d7ffac901468140591abe0aa1fcea6eeae Mon Sep 17 00:00:00 2001
From: Cyrus Leung
Date: Sat, 29 Jun 2024 19:19:02 +0800
Subject: [PATCH] [Bugfix] Support `eos_token_id` from `config.json` (#5954)

---
 tests/tokenization/test_get_eos.py | 31 +++++++++++++++++++++++++++++++
 vllm/engine/llm_engine.py          | 23 +++++++++++++----------
 vllm/transformers_utils/config.py  | 24 +++++++++++++++++++++++-
 3 files changed, 67 insertions(+), 11 deletions(-)
 create mode 100644 tests/tokenization/test_get_eos.py

diff --git a/tests/tokenization/test_get_eos.py b/tests/tokenization/test_get_eos.py
new file mode 100644
index 0000000000000..875ca19d3b4b7
--- /dev/null
+++ b/tests/tokenization/test_get_eos.py
@@ -0,0 +1,31 @@
+"""
+This test file includes some cases where it is inappropriate to
+only get the `eos_token_id` from the tokenizer as defined by
+:meth:`vllm.LLMEngine._get_eos_token_id`.
+"""
+from vllm.transformers_utils.config import try_get_generation_config
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+def test_get_llama3_eos_token():
+    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+    tokenizer = get_tokenizer(model_name)
+    assert tokenizer.eos_token_id == 128009
+
+    generation_config = try_get_generation_config(model_name,
+                                                  trust_remote_code=False)
+    assert generation_config is not None
+    assert generation_config.eos_token_id == [128001, 128009]
+
+
+def test_get_blip2_eos_token():
+    model_name = "Salesforce/blip2-opt-2.7b"
+
+    tokenizer = get_tokenizer(model_name)
+    assert tokenizer.eos_token_id == 2
+
+    generation_config = try_get_generation_config(model_name,
+                                                  trust_remote_code=False)
+    assert generation_config is not None
+    assert generation_config.eos_token_id == 50118
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 808a639f5dc9e..f7e38c0e6b948 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1,10 +1,10 @@
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union
 
-from transformers import GenerationConfig, PreTrainedTokenizer
+from transformers import PreTrainedTokenizer
 
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
                          LoRAConfig, ModelConfig, ObservabilityConfig,
@@ -34,6 +34,7 @@
                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
+from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -46,16 +47,18 @@
 
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
 
-def _load_generation_config_dict(model_config: ModelConfig):
-    try:
-        return GenerationConfig.from_pretrained(
-            model_config.model,
-            revision=model_config.revision,
-        ).to_diff_dict()
-    except OSError:
-        # Not found.
+def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
+    config = try_get_generation_config(
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        revision=model_config.revision,
+    )
+
+    if config is None:
         return {}
 
+    return config.to_diff_dict()
+
 
 _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index 60fc756a12e3d..5e2fe116db9c6 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -1,7 +1,7 @@
 import contextlib
 from typing import Dict, Optional, Type
 
-from transformers import PretrainedConfig
+from transformers import GenerationConfig, PretrainedConfig
 
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
@@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig):
         return config.text_config
     else:
         return config
+
+
+def try_get_generation_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+) -> Optional[GenerationConfig]:
+    try:
+        return GenerationConfig.from_pretrained(
+            model,
+            revision=revision,
+        )
+    except OSError:  # Not found
+        try:
+            config = get_config(
+                model,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+            )
+            return GenerationConfig.from_model_config(config)
+        except OSError:  # Not found
+            return None
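
Note on the fallback added above: `try_get_generation_config` first tries to load
`generation_config.json`, and only if that file is absent does it build a
`GenerationConfig` from the model's `config.json` via `get_config` (this is how
BLIP-2 exposes its EOS id, as the new test shows). The sketch below is
illustrative only and not part of the patch; the helper name
`collect_eos_token_ids` is hypothetical. It shows how a caller such as the one
referenced in the test docstring might merge the tokenizer's EOS id with any
EOS ids declared in the generation config:

    # Hypothetical helper, not part of this patch: gather every EOS id the
    # model may emit, combining the tokenizer's value with the generation
    # config's, as the tests above exercise.
    from typing import List

    from vllm.transformers_utils.config import try_get_generation_config
    from vllm.transformers_utils.tokenizer import get_tokenizer


    def collect_eos_token_ids(model: str) -> List[int]:
        eos_ids = set()

        tokenizer = get_tokenizer(model)
        if tokenizer.eos_token_id is not None:
            eos_ids.add(tokenizer.eos_token_id)

        generation_config = try_get_generation_config(model,
                                                      trust_remote_code=False)
        if generation_config is not None:
            config_eos = generation_config.eos_token_id
            # GenerationConfig.eos_token_id may be None, an int, or a list.
            if isinstance(config_eos, int):
                eos_ids.add(config_eos)
            elif config_eos is not None:
                eos_ids.update(config_eos)

        return sorted(eos_ids)

Per the tests above, this would yield [128001, 128009] for
meta-llama/Meta-Llama-3-8B-Instruct and [2, 50118] for
Salesforce/blip2-opt-2.7b, rather than the single id the tokenizer reports.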