From dbeee10ce9f5d1ac2deaac823926c087b888ea91 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 28 Jun 2024 09:37:19 +0000 Subject: [PATCH 1/2] Support `eos_token_id` from `config.json` --- tests/tokenization/test_get_eos.py | 29 +++++++++++++++++++++++++++++ vllm/engine/llm_engine.py | 22 ++++++++++++---------- vllm/transformers_utils/config.py | 22 +++++++++++++++++++++- 3 files changed, 62 insertions(+), 11 deletions(-) create mode 100644 tests/tokenization/test_get_eos.py diff --git a/tests/tokenization/test_get_eos.py b/tests/tokenization/test_get_eos.py new file mode 100644 index 0000000000000..cdca5d04df095 --- /dev/null +++ b/tests/tokenization/test_get_eos.py @@ -0,0 +1,29 @@ +""" +This test file includes some cases where it is inappropriate to +only get the `eos_token_id` from the tokenizer as defined by +:meth:`vllm.LLMEngine._get_eos_token_id`. +""" +from vllm.transformers_utils.config import try_get_generation_config +from vllm.transformers_utils.tokenizer import get_tokenizer + + +def test_get_llama3_eos_token(): + MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct" + + tokenizer = get_tokenizer(MODEL_NAME) + assert tokenizer.eos_token_id == 128009 + + generation_config = try_get_generation_config(MODEL_NAME) + assert generation_config is not None + assert generation_config.eos_token_id == [128001, 128009] + + +def test_get_blip2_eos_token(): + MODEL_NAME = "Salesforce/blip2-opt-2.7b" + + tokenizer = get_tokenizer(MODEL_NAME) + assert tokenizer.eos_token_id == 2 + + generation_config = try_get_generation_config(MODEL_NAME) + assert generation_config is not None + assert generation_config.eos_token_id == 50118 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4b427b1fb2f22..7081801efdd11 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,10 +1,10 @@ import time from contextlib import contextmanager -from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional from typing import Sequence as GenericSequence from typing import Set, Type, TypeVar, Union -from transformers import GenerationConfig, PreTrainedTokenizer +from transformers import PreTrainedTokenizer from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, @@ -33,6 +33,7 @@ SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) +from vllm.transformers_utils.config import try_get_generation_config from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -45,16 +46,17 @@ _LOCAL_LOGGING_INTERVAL_SEC = 5 -def _load_generation_config_dict(model_config: ModelConfig): - try: - return GenerationConfig.from_pretrained( - model_config.model, - revision=model_config.revision, - ).to_diff_dict() - except OSError: - # Not found. +def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: + config = try_get_generation_config( + model_config.model, + revision=model_config.revision, + ) + + if config is None: return {} + return config.to_diff_dict() + _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 60fc756a12e3d..4c904dc27e7db 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,7 +1,7 @@ import contextlib from typing import Dict, Optional, Type -from transformers import PretrainedConfig +from transformers import GenerationConfig, PretrainedConfig from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger @@ -80,3 +80,23 @@ def get_hf_text_config(config: PretrainedConfig): return config.text_config else: return config + + +def try_get_generation_config( + model: str, + revision: Optional[str] = None, +) -> Optional[GenerationConfig]: + try: + return GenerationConfig.from_pretrained( + model, + revision=revision, + ) + except OSError: # Not found + try: + return GenerationConfig.from_model_config( + AutoConfig.from_pretrained( + model, + revision=revision, + ), ) + except OSError: # Not found + return None From 161585af4fa362af6f85b6ba1274895b6c1b73d6 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Fri, 28 Jun 2024 09:50:09 +0000 Subject: [PATCH 2/2] Add `trust_remote_code` --- tests/tokenization/test_get_eos.py | 14 ++++++++------ vllm/engine/llm_engine.py | 1 + vllm/transformers_utils/config.py | 12 +++++++----- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/tests/tokenization/test_get_eos.py b/tests/tokenization/test_get_eos.py index cdca5d04df095..875ca19d3b4b7 100644 --- a/tests/tokenization/test_get_eos.py +++ b/tests/tokenization/test_get_eos.py @@ -8,22 +8,24 @@ def test_get_llama3_eos_token(): - MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct" + model_name = "meta-llama/Meta-Llama-3-8B-Instruct" - tokenizer = get_tokenizer(MODEL_NAME) + tokenizer = get_tokenizer(model_name) assert tokenizer.eos_token_id == 128009 - generation_config = try_get_generation_config(MODEL_NAME) + generation_config = try_get_generation_config(model_name, + trust_remote_code=False) assert generation_config is not None assert generation_config.eos_token_id == [128001, 128009] def test_get_blip2_eos_token(): - MODEL_NAME = "Salesforce/blip2-opt-2.7b" + model_name = "Salesforce/blip2-opt-2.7b" - tokenizer = get_tokenizer(MODEL_NAME) + tokenizer = get_tokenizer(model_name) assert tokenizer.eos_token_id == 2 - generation_config = try_get_generation_config(MODEL_NAME) + generation_config = try_get_generation_config(model_name, + trust_remote_code=False) assert generation_config is not None assert generation_config.eos_token_id == 50118 diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7081801efdd11..8ff351f9c99d5 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -49,6 +49,7 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: config = try_get_generation_config( model_config.model, + trust_remote_code=model_config.trust_remote_code, revision=model_config.revision, ) diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4c904dc27e7db..5e2fe116db9c6 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -84,6 +84,7 @@ def get_hf_text_config(config: PretrainedConfig): def try_get_generation_config( model: str, + trust_remote_code: bool, revision: Optional[str] = None, ) -> Optional[GenerationConfig]: try: @@ -93,10 +94,11 @@ def try_get_generation_config( ) except OSError: # Not found try: - return GenerationConfig.from_model_config( - AutoConfig.from_pretrained( - model, - revision=revision, - ), ) + config = get_config( + model, + trust_remote_code=trust_remote_code, + revision=revision, + ) + return GenerationConfig.from_model_config(config) except OSError: # Not found return None