[Bugfix] Support eos_token_id from config.json #5954

Merged · 2 commits · Jun 29, 2024
31 changes: 31 additions & 0 deletions tests/tokenization/test_get_eos.py
@@ -0,0 +1,31 @@
+"""
+This test file includes some cases where it is inappropriate to
+only get the `eos_token_id` from the tokenizer as defined by
+:meth:`vllm.LLMEngine._get_eos_token_id`.
+"""
+from vllm.transformers_utils.config import try_get_generation_config
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+def test_get_llama3_eos_token():
+    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
+
+    tokenizer = get_tokenizer(model_name)
+    assert tokenizer.eos_token_id == 128009
+
+    generation_config = try_get_generation_config(model_name,
+                                                   trust_remote_code=False)
+    assert generation_config is not None
+    assert generation_config.eos_token_id == [128001, 128009]
+
+
+def test_get_blip2_eos_token():
+    model_name = "Salesforce/blip2-opt-2.7b"
+
+    tokenizer = get_tokenizer(model_name)
+    assert tokenizer.eos_token_id == 2
+
+    generation_config = try_get_generation_config(model_name,
+                                                   trust_remote_code=False)
+    assert generation_config is not None
+    assert generation_config.eos_token_id == 50118
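Note: these assertions pin down why reading `eos_token_id` from the tokenizer alone is insufficient. For Llama 3, the tokenizer reports only <|eot_id|> (128009), while generation_config.json also lists <|end_of_text|> (128001); both must act as stop tokens. A minimal sketch of the mismatch (not part of this PR), reusing the helpers the tests import; Llama 3 is a gated repo, so this assumes HF access:

from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.tokenizer import get_tokenizer

model = "meta-llama/Meta-Llama-3-8B-Instruct"  # gated repo; assumes HF access

tokenizer = get_tokenizer(model)
gen_config = try_get_generation_config(model, trust_remote_code=False)
assert gen_config is not None

# The tokenizer advertises a single EOS id, but the generation config
# lists two; stopping on tokenizer.eos_token_id alone would never
# terminate on <|end_of_text|>.
print(tokenizer.eos_token_id)   # 128009 (<|eot_id|>)
print(gen_config.eos_token_id)  # [128001, 128009]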
23 changes: 13 additions & 10 deletions vllm/engine/llm_engine.py
@@ -1,10 +1,10 @@
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union
 
-from transformers import GenerationConfig, PreTrainedTokenizer
+from transformers import PreTrainedTokenizer
 
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig,
                          LoRAConfig, ModelConfig, ObservabilityConfig,
@@ -33,6 +33,7 @@
                           SequenceStatus)
 from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context,
                           init_tracer)
+from vllm.transformers_utils.config import try_get_generation_config
 from vllm.transformers_utils.detokenizer import Detokenizer
 from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup,
                                                      get_tokenizer_group)
@@ -45,16 +46,18 @@
 _LOCAL_LOGGING_INTERVAL_SEC = 5
 
 
-def _load_generation_config_dict(model_config: ModelConfig):
-    try:
-        return GenerationConfig.from_pretrained(
-            model_config.model,
-            revision=model_config.revision,
-        ).to_diff_dict()
-    except OSError:
-        # Not found.
-        return {}
+def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]:
+    config = try_get_generation_config(
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        revision=model_config.revision,
+    )
+
+    if config is None:
+        return {}
+
+    return config.to_diff_dict()
 
 
 _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput)
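Because `eos_token_id` in a generation config may be a single int (as for BLIP-2) or a list (as for Llama 3), consumers of the dict returned by `_load_generation_config_dict` have to normalize it. A hedged sketch of that normalization (`extract_eos_ids` is a hypothetical helper, not part of this PR):

from typing import Any, Dict, List

def extract_eos_ids(generation_config_fields: Dict[str, Any]) -> List[int]:
    """Normalize eos_token_id, which may be absent, an int, or a list."""
    eos = generation_config_fields.get("eos_token_id")
    if eos is None:
        return []
    return [eos] if isinstance(eos, int) else list(eos)

# Values taken from the tests above:
assert extract_eos_ids({"eos_token_id": [128001, 128009]}) == [128001, 128009]
assert extract_eos_ids({"eos_token_id": 50118}) == [50118]
assert extract_eos_ids({}) == []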
24 changes: 23 additions & 1 deletion vllm/transformers_utils/config.py
@@ -1,7 +1,7 @@
 import contextlib
 from typing import Dict, Optional, Type
 
-from transformers import PretrainedConfig
+from transformers import GenerationConfig, PretrainedConfig
 
 from vllm.envs import VLLM_USE_MODELSCOPE
 from vllm.logger import init_logger
@@ -80,3 +80,25 @@ def get_hf_text_config(config: PretrainedConfig):
         return config.text_config
     else:
         return config
+
+
+def try_get_generation_config(
+    model: str,
+    trust_remote_code: bool,
+    revision: Optional[str] = None,
+) -> Optional[GenerationConfig]:
+    try:
+        return GenerationConfig.from_pretrained(
+            model,
+            revision=revision,
+        )
+    except OSError:  # Not found
+        try:
+            config = get_config(
+                model,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+            )
+            return GenerationConfig.from_model_config(config)
+        except OSError:  # Not found
+            return None
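The helper's resolution order is: generation_config.json via GenerationConfig.from_pretrained, then the model's config.json via GenerationConfig.from_model_config (the case the PR title refers to), and None only if both fail. A short usage sketch, with the expected value taken from the BLIP-2 test above:

from vllm.transformers_utils.config import try_get_generation_config

config = try_get_generation_config("Salesforce/blip2-opt-2.7b",
                                   trust_remote_code=False)
if config is None:
    # Neither generation_config.json nor config.json could be loaded;
    # callers like LLMEngine then return {} and rely on the tokenizer's EOS.
    print("no generation config")
else:
    print(config.eos_token_id)  # 50118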