Skip to content

Commit

Permalink
[Bugfix] Support eos_token_id from config.json (vllm-project#5954)
Browse files Browse the repository at this point in the history
Signed-off-by: Alvant <alvasian@yandex.ru>
  • Loading branch information
DarkLight1337 authored and Alvant committed Oct 26, 2024
1 parent a5587b2 commit cf1d8b7
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 11 deletions.
31 changes: 31 additions & 0 deletions tests/tokenization/test_get_eos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
This test file includes some cases where it is inappropriate to
only get the `eos_token_id` from the tokenizer as defined by
:meth:`vllm.LLMEngine._get_eos_token_id`.
"""
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer


def test_get_llama3_eos_token():
    """Llama-3-Instruct declares one EOS id on the tokenizer but two stop
    ids in its generation config — the engine must consult the latter."""
    model = "meta-llama/Meta-Llama-3-8B-Instruct"

    # Tokenizer alone only exposes the <|eot_id|> token.
    assert get_tokenizer(model).eos_token_id == 128009

    gen_cfg = try_get_generation_config(model, trust_remote_code=False)
    assert gen_cfg is not None
    # generation_config.json carries both <|end_of_text|> and <|eot_id|>.
    assert gen_cfg.eos_token_id == [128001, 128009]


def test_get_blip2_eos_token():
    """BLIP-2: the tokenizer's EOS id (2) differs from the stop id in the
    model's generation config (50118), so the config must win."""
    model = "Salesforce/blip2-opt-2.7b"

    assert get_tokenizer(model).eos_token_id == 2

    gen_cfg = try_get_generation_config(model, trust_remote_code=False)
    assert gen_cfg is not None
    assert gen_cfg.eos_token_id == 50118
23 changes: 13 additions & 10 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union

from transformers import GenerationConfig, PreTrainedTokenizer
from transformers import PreTrainedTokenizer

from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
LoRAConfig, ModelConfig, ObservabilityConfig,
Expand Down Expand Up @@ -34,6 +34,7 @@
SequenceStatus)
from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
get_tokenizer_group)
Expand All @@ -46,16 +47,18 @@
_LOCAL_LOGGING_INTERVAL_SEC = 5


def _load_generation_config_dict(model_config: ModelConfig):
try:
return GenerationConfig.from_pretrained(
model_config.model,
revision=model_config.revision,
).to_diff_dict()
except OSError:
# Not found.
def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
    """Return the model's generation config as a diff-dict.

    Falls back to an empty dict when the model provides no generation
    config at all (``try_get_generation_config`` returned ``None``).
    """
    generation_config = try_get_generation_config(
        model_config.model,
        trust_remote_code=model_config.trust_remote_code,
        revision=model_config.revision,
    )
    return {} if generation_config is None else generation_config.to_diff_dict()


_O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)

Expand Down
24 changes: 23 additions & 1 deletion vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import contextlib
from typing import Dict, Optional, Type

from transformers import PretrainedConfig
from transformers import GenerationConfig, PretrainedConfig

from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
Expand Down Expand Up @@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig):
return config.text_config
else:
return config


def try_get_generation_config(
    model: str,
    trust_remote_code: bool,
    revision: Optional[str] = None,
) -> Optional[GenerationConfig]:
    """Best-effort load of a model's :class:`GenerationConfig`.

    First tries the standalone ``generation_config.json``; if that is
    missing, derives a generation config from the model's main config
    (which is how e.g. BLIP-2 stores its ``eos_token_id``). Returns
    ``None`` when neither source is available.
    """
    try:
        return GenerationConfig.from_pretrained(model, revision=revision)
    except OSError:  # no generation_config.json available
        try:
            model_config = get_config(
                model,
                trust_remote_code=trust_remote_code,
                revision=revision,
            )
            return GenerationConfig.from_model_config(model_config)
        except OSError:  # model config not found either
            return None

0 comments on commit cf1d8b7

Please sign in to comment.