diff --git a/tests/entrypoints/openai/correctness/test_mteb.py b/tests/entrypoints/openai/correctness/test_mteb.py
index ebf2f829b583..44d7ac193760 100644
--- a/tests/entrypoints/openai/correctness/test_mteb.py
+++ b/tests/entrypoints/openai/correctness/test_mteb.py
@@ -4,6 +4,7 @@
 import pytest
 
 from tests.models.language.pooling.mteb_utils import (MTEB_EMBED_TASKS,
+                                                      MTEB_EMBED_TOL,
                                                       OpenAIClientMtebEncoder,
                                                       run_mteb_embed_task,
                                                       run_mteb_embed_task_st)
@@ -38,4 +39,4 @@ def test_mteb(server):
     print("SentenceTransformer main score: ", st_main_score)
     print("Difference: ", st_main_score - vllm_main_score)
 
-    assert st_main_score == pytest.approx(vllm_main_score, rel=1e-4)
+    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py
index 1019bfd58936..81ca65b6541a 100644
--- a/tests/entrypoints/openai/test_embedding.py
+++ b/tests/entrypoints/openai/test_embedding.py
@@ -11,7 +11,8 @@
 from vllm.entrypoints.openai.protocol import EmbeddingResponse
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-from ...models.utils import run_embedding_correctness_test
+from ...models.language.pooling.embed_utils import (
+    run_embedding_correctness_test)
 from ...utils import RemoteOpenAIServer
 
 MODEL_NAME = "intfloat/multilingual-e5-small"
diff --git a/tests/entrypoints/openai/test_embedding_dimensions.py b/tests/entrypoints/openai/test_embedding_dimensions.py
index 332fa332a4a4..341defae0b31 100644
--- a/tests/entrypoints/openai/test_embedding_dimensions.py
+++ b/tests/entrypoints/openai/test_embedding_dimensions.py
@@ -11,7 +11,9 @@
 from vllm.entrypoints.openai.protocol import EmbeddingResponse
 
 from ...conftest import HfRunner
-from ...models.utils import EmbedModelInfo, run_embedding_correctness_test
+from ...models.language.pooling.embed_utils import (
+    run_embedding_correctness_test)
+from ...models.utils import EmbedModelInfo
 from ...utils import RemoteOpenAIServer
 
 MODELS = [
diff --git a/tests/models/language/pooling/embed_utils.py b/tests/models/language/pooling/embed_utils.py
new file mode 100644
index 000000000000..0c8ac2ab1b9e
--- /dev/null
+++ b/tests/models/language/pooling/embed_utils.py
@@ -0,0 +1,72 @@
+# SPDX-License-Identifier: Apache-2.0
+from collections.abc import Sequence
+from typing import Optional
+
+import pytest
+
+from tests.conftest import HfRunner
+from tests.models.utils import (EmbedModelInfo, check_embeddings_close,
+                                matryoshka_fy)
+
+
+def run_embedding_correctness_test(
+    hf_model: "HfRunner",
+    inputs: list[str],
+    vllm_outputs: Sequence[list[float]],
+    dimensions: Optional[int] = None,
+):
+    hf_outputs = hf_model.encode(inputs)
+    if dimensions:
+        hf_outputs = matryoshka_fy(hf_outputs, dimensions)
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
+
+
+def correctness_test_embed_models(hf_runner,
+                                  vllm_runner,
+                                  model_info: EmbedModelInfo,
+                                  example_prompts,
+                                  vllm_extra_kwargs=None,
+                                  hf_model_callback=None):
+    if not model_info.enable_test:
+        # A model family has many models with the same architecture,
+        # and we don't need to test each one.
+        pytest.skip("Skipping test.")
+
+    # The example_prompts has ending "\n", for example:
+    # "Write a short story about a robot that dreams for the first time.\n"
+    # sentence_transformers will strip the input texts, see:
+    # https://github.com/UKPLab/sentence-transformers/blob/v3.1.1/sentence_transformers/models/Transformer.py#L159
+    # This makes the input_ids different between hf_model and vllm_model.
+    # So we need to strip the input texts to avoid test failing.
+    example_prompts = [str(s).strip() for s in example_prompts]
+
+    vllm_extra_kwargs = vllm_extra_kwargs or {}
+    vllm_extra_kwargs["dtype"] = model_info.dtype
+
+    with vllm_runner(model_info.name,
+                     task="embed",
+                     max_model_len=None,
+                     **vllm_extra_kwargs) as vllm_model:
+        vllm_outputs = vllm_model.encode(example_prompts)
+        vllm_dtype = vllm_model.model.llm_engine.model_config.dtype
+        model_dtype = getattr(
+            vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
+            vllm_dtype)
+
+    with hf_runner(
+            model_info.name,
+            dtype=model_dtype,
+            is_sentence_transformer=True,
+    ) as hf_model:
+
+        if hf_model_callback is not None:
+            hf_model_callback(hf_model)
+
+        run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py
index f83c9940d524..f4837ae952c3 100644
--- a/tests/models/language/pooling/mteb_utils.py
+++ b/tests/models/language/pooling/mteb_utils.py
@@ -80,18 +80,19 @@ def run_mteb_embed_task_st(model_name, tasks):
 def mteb_test_embed_models(hf_runner,
                            vllm_runner,
                            model_info: EmbedModelInfo,
-                           vllm_extra_kwargs=None):
+                           vllm_extra_kwargs=None,
+                           hf_model_callback=None):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
         pytest.skip("Skipping test.")
 
     vllm_extra_kwargs = vllm_extra_kwargs or {}
+    vllm_extra_kwargs["dtype"] = model_info.dtype
 
     with vllm_runner(model_info.name,
                      task="embed",
                      max_model_len=None,
-                     dtype=model_info.dtype,
                      **vllm_extra_kwargs) as vllm_model:
 
         if model_info.architecture:
@@ -108,10 +109,14 @@ def mteb_test_embed_models(hf_runner,
     with set_default_torch_dtype(model_dtype) and hf_runner(
             model_info.name, is_sentence_transformer=True,
             dtype=model_dtype) as hf_model:
+
+        if hf_model_callback is not None:
+            hf_model_callback(hf_model)
+
         st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
 
     print("VLLM:", vllm_dtype, vllm_main_score)
     print("SentenceTransformer:", model_dtype, st_main_score)
     print("Difference:", st_main_score - vllm_main_score)
 
-    assert st_main_score == pytest.approx(vllm_main_score, rel=MTEB_EMBED_TOL)
+    assert st_main_score == pytest.approx(vllm_main_score, abs=MTEB_EMBED_TOL)
diff --git a/tests/models/language/pooling/test_baai.py b/tests/models/language/pooling/test_baai.py
new file mode 100644
index 000000000000..fc0e8207954f
--- /dev/null
+++ b/tests/models/language/pooling/test_baai.py
@@ -0,0 +1,71 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+
+from .embed_utils import EmbedModelInfo, correctness_test_embed_models
+from .mteb_utils import mteb_test_embed_models
+
+MODELS = [
+    ########## BertModel
+    EmbedModelInfo("BAAI/bge-base-en",
+                   architecture="BertModel",
+                   enable_test=True),
+    EmbedModelInfo("BAAI/bge-base-zh",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-small-en",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-small-zh",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-large-en",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-large-zh",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-large-zh-noinstruct",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-base-en-v1.5",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-base-zh-v1.5",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-small-en-v1.5",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-small-zh-v1.5",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-large-en-v1.5",
+                   architecture="BertModel",
+                   enable_test=False),
+    EmbedModelInfo("BAAI/bge-large-zh-v1.5",
+                   architecture="BertModel",
+                   enable_test=False),
+    ########## XLMRobertaModel
+    EmbedModelInfo("BAAI/bge-m3",
+                   architecture="XLMRobertaModel",
+                   enable_test=True),
+    ########## Qwen2Model
+    EmbedModelInfo("BAAI/bge-code-v1",
+                   architecture="Qwen2Model",
+                   dtype="float32",
+                   enable_test=True),
+]
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_embed_models_mteb(hf_runner, vllm_runner,
+                           model_info: EmbedModelInfo) -> None:
+    mteb_test_embed_models(hf_runner, vllm_runner, model_info)
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_embed_models_correctness(hf_runner, vllm_runner,
+                                  model_info: EmbedModelInfo,
+                                  example_prompts) -> None:
+    correctness_test_embed_models(hf_runner, vllm_runner, model_info,
+                                  example_prompts)
diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py
index 91d10f529cd6..18b27a688146 100644
--- a/tests/models/language/pooling/test_gte.py
+++ b/tests/models/language/pooling/test_gte.py
@@ -3,7 +3,8 @@
 
 import pytest
 
-from ...utils import EmbedModelInfo, run_embedding_correctness_test
+from .embed_utils import EmbedModelInfo, correctness_test_embed_models
+from .mteb_utils import mteb_test_embed_models
 
 MODELS = [
     ########## BertModel
@@ -53,9 +54,8 @@
 
 
 @pytest.mark.parametrize("model_info", MODELS)
-def test_models_mteb(hf_runner, vllm_runner,
-                     model_info: EmbedModelInfo) -> None:
-    from .mteb_utils import mteb_test_embed_models
+def test_embed_models_mteb(hf_runner, vllm_runner,
+                           model_info: EmbedModelInfo) -> None:
 
     vllm_extra_kwargs: dict[str, Any] = {}
     if model_info.architecture == "GteNewModel":
@@ -66,28 +66,13 @@ def test_models_mteb(hf_runner, vllm_runner,
 
 
 @pytest.mark.parametrize("model_info", MODELS)
-def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
-                            example_prompts) -> None:
-    if not model_info.enable_test:
-        pytest.skip("Skipping test.")
-
-    # ST will strip the input texts, see test_embedding.py
-    example_prompts = [str(s).strip() for s in example_prompts]
+def test_embed_models_correctness(hf_runner, vllm_runner,
+                                  model_info: EmbedModelInfo,
+                                  example_prompts) -> None:
 
     vllm_extra_kwargs: dict[str, Any] = {}
     if model_info.architecture == "GteNewModel":
         vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
 
-    with vllm_runner(model_info.name,
-                     task="embed",
-                     dtype=model_info.dtype,
-                     max_model_len=None,
-                     **vllm_extra_kwargs) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
-
-    with hf_runner(
-            model_info.name,
-            dtype=model_info.dtype,
-            is_sentence_transformer=True,
-    ) as hf_model:
-        run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
+    correctness_test_embed_models(hf_runner, vllm_runner, model_info,
+                                  example_prompts, vllm_extra_kwargs)
diff --git a/tests/models/language/pooling/test_jina.py b/tests/models/language/pooling/test_jina.py
index 0ddff2146caa..0403a20a445a 100644
--- a/tests/models/language/pooling/test_jina.py
+++ b/tests/models/language/pooling/test_jina.py
@@ -1,9 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
+from functools import partial
+
 import pytest
 
 from vllm import PoolingParams
 
-from ...utils import check_embeddings_close, matryoshka_fy
+from .embed_utils import (EmbedModelInfo, check_embeddings_close,
+                          correctness_test_embed_models, matryoshka_fy)
+from .mteb_utils import mteb_test_embed_models
 
 SCORING_MODELS = [
     "jinaai/jina-reranker-v2-base-multilingual",  # Roberta
@@ -25,16 +29,10 @@
 ]
 
 EMBEDDING_MODELS = [
-    "jinaai/jina-embeddings-v3",
-]
-
-EMBEDDING_PROMPTS = [
-    "Follow the white rabbit.",  # English
-    "Sigue al conejo blanco.",  # Spanish
-    "Suis le lapin blanc.",  # French
-    "跟着白兔走。",  # Chinese
-    "اتبع الأرنب الأبيض.",  # Arabic
-    "Folge dem weißen Kaninchen.",  # German
+    EmbedModelInfo("jinaai/jina-embeddings-v3",
+                   architecture="XLMRobertaModel",
+                   is_matryoshka=True,
+                   dtype="float32")
 ]
 
 
@@ -80,73 +78,66 @@ def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str):
     assert hf_outputs[1] == pytest.approx(vllm_outputs[1], rel=0.01)
 
 
-@pytest.fixture(scope="module", params=EMBEDDING_MODELS)
-def emb_model_name(request):
-    yield request.param
+@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
+def test_embed_models_mteb(hf_runner, vllm_runner,
+                           model_info: EmbedModelInfo) -> None:
 
+    def hf_model_callback(model):
+        model.encode = partial(model.encode, task="text-matching")
 
-def test_is_matryoshka(vllm_runner, emb_model_name):
-    with vllm_runner(emb_model_name, task="embed",
-                     max_model_len=None) as vllm_model:
-        assert vllm_model.model.llm_engine.model_config.is_matryoshka
-
-
-@pytest.mark.parametrize("model", EMBEDDING_MODELS)
-@pytest.mark.parametrize("dtype", ["half"])
-def test_embeddings(
-    hf_runner,
-    vllm_runner,
-    model,
-    dtype: str,
-    monkeypatch,
-) -> None:
+    mteb_test_embed_models(hf_runner,
+                           vllm_runner,
+                           model_info,
+                           hf_model_callback=hf_model_callback)
 
-    example_prompts = EMBEDDING_PROMPTS
 
-    with hf_runner(
-            model,
-            dtype=dtype,
-            is_sentence_transformer=True,
-    ) as hf_model:
-        hf_outputs = hf_model.encode(example_prompts, task="text-matching")
+@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
+def test_embed_models_correctness(hf_runner, vllm_runner,
+                                  model_info: EmbedModelInfo,
+                                  example_prompts) -> None:
 
-    with vllm_runner(model, task="embed", dtype=dtype,
-                     max_model_len=None) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
+    def hf_model_callback(model):
+        model.encode = partial(model.encode, task="text-matching")
 
-    check_embeddings_close(
-        embeddings_0_lst=hf_outputs,
-        embeddings_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-        tol=1e-2,
-    )
+    correctness_test_embed_models(hf_runner,
+                                  vllm_runner,
+                                  model_info,
+                                  example_prompts,
+                                  hf_model_callback=hf_model_callback)
 
 
-@pytest.mark.parametrize("model", EMBEDDING_MODELS)
+@pytest.mark.parametrize("model_info", EMBEDDING_MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 @pytest.mark.parametrize("dimensions", [16, 32])
 def test_matryoshka(
     hf_runner,
     vllm_runner,
-    model,
+    model_info,
     dtype: str,
     dimensions: int,
+    example_prompts,
     monkeypatch,
 ) -> None:
+    if not model_info.is_matryoshka:
+        pytest.skip("Model is not matryoshka")
 
-    example_prompts = EMBEDDING_PROMPTS
+    # ST will strip the input texts, see test_embedding.py
+    example_prompts = [str(s).strip() for s in example_prompts]
 
     with hf_runner(
-            model,
+            model_info.name,
             dtype=dtype,
             is_sentence_transformer=True,
     ) as hf_model:
         hf_outputs = hf_model.encode(example_prompts, task="text-matching")
         hf_outputs = matryoshka_fy(hf_outputs, dimensions)
 
-    with vllm_runner(model, task="embed", dtype=dtype,
+    with vllm_runner(model_info.name,
+                     task="embed",
+                     dtype=dtype,
                      max_model_len=None) as vllm_model:
+        assert vllm_model.model.llm_engine.model_config.is_matryoshka
+
         matryoshka_dimensions = (
             vllm_model.model.llm_engine.model_config.matryoshka_dimensions)
         assert matryoshka_dimensions is not None
diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py
index 28df32e0c230..92cd7cc569d3 100644
--- a/tests/models/language/pooling/test_nomic.py
+++ b/tests/models/language/pooling/test_nomic.py
@@ -2,7 +2,8 @@
 
 import pytest
 
-from ...utils import EmbedModelInfo, run_embedding_correctness_test
+from .embed_utils import EmbedModelInfo, correctness_test_embed_models
+from .mteb_utils import mteb_test_embed_models
 
 MODELS = [
     EmbedModelInfo("nomic-ai/nomic-embed-text-v1",
@@ -13,6 +14,9 @@
                    architecture="NomicBertModel",
                    dtype="float32",
                    enable_test=False),
+    EmbedModelInfo("nomic-ai/CodeRankEmbed",
+                   architecture="NomicBertModel",
+                   enable_test=False),
     EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe",
                    architecture="NomicBertModel",
                    dtype="float32",
@@ -21,30 +25,14 @@
 
 
 @pytest.mark.parametrize("model_info", MODELS)
-def test_models_mteb(hf_runner, vllm_runner,
-                     model_info: EmbedModelInfo) -> None:
-    from .mteb_utils import mteb_test_embed_models
+def test_embed_models_mteb(hf_runner, vllm_runner,
+                           model_info: EmbedModelInfo) -> None:
     mteb_test_embed_models(hf_runner, vllm_runner, model_info)
 
 
 @pytest.mark.parametrize("model_info", MODELS)
-def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
-                            example_prompts) -> None:
-    if not model_info.enable_test:
-        pytest.skip("Skipping test.")
-
-    # ST will strip the input texts, see test_embedding.py
-    example_prompts = [str(s).strip() for s in example_prompts]
-
-    with vllm_runner(model_info.name,
-                     task="embed",
-                     dtype=model_info.dtype,
-                     max_model_len=None) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
-
-    with hf_runner(
-            model_info.name,
-            dtype=model_info.dtype,
-            is_sentence_transformer=True,
-    ) as hf_model:
-        run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
+def test_embed_models_correctness(hf_runner, vllm_runner,
+                                  model_info: EmbedModelInfo,
+                                  example_prompts) -> None:
+    correctness_test_embed_models(hf_runner, vllm_runner, model_info,
+                                  example_prompts)
diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py
index 5679e0e1ce00..c6c2d1e7a679 100644
--- a/tests/models/language/pooling/test_snowflake_arctic_embed.py
+++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py
@@ -2,7 +2,8 @@
 
 import pytest
 
-from ...utils import EmbedModelInfo, run_embedding_correctness_test
+from .embed_utils import EmbedModelInfo, correctness_test_embed_models
+from .mteb_utils import mteb_test_embed_models
 
 MODELS = [
     EmbedModelInfo("Snowflake/snowflake-arctic-embed-xs",
@@ -41,37 +42,14 @@
 
 
 @pytest.mark.parametrize("model_info", MODELS)
-def test_models_mteb(
-    hf_runner,
-    vllm_runner,
-    model_info: EmbedModelInfo,
-) -> None:
-    from .mteb_utils import mteb_test_embed_models
+def test_embed_models_mteb(hf_runner, vllm_runner,
+                           model_info: EmbedModelInfo) -> None:
     mteb_test_embed_models(hf_runner, vllm_runner, model_info)
 
 
 @pytest.mark.parametrize("model_info", MODELS)
-def test_models_correctness(
-    hf_runner,
-    vllm_runner,
-    model_info: EmbedModelInfo,
-    example_prompts,
-) -> None:
-    if not model_info.enable_test:
-        pytest.skip("Skipping test.")
-
-    # ST will strip the input texts, see test_embedding.py
-    example_prompts = [str(s).strip() for s in example_prompts]
-
-    with vllm_runner(model_info.name,
-                     task="embed",
-                     dtype=model_info.dtype,
-                     max_model_len=None) as vllm_model:
-        vllm_outputs = vllm_model.encode(example_prompts)
-
-    with hf_runner(
-            model_info.name,
-            dtype=model_info.dtype,
-            is_sentence_transformer=True,
-    ) as hf_model:
-        run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
+def test_embed_models_correctness(hf_runner, vllm_runner,
+                                  model_info: EmbedModelInfo,
+                                  example_prompts) -> None:
+    correctness_test_embed_models(hf_runner, vllm_runner, model_info,
+                                  example_prompts)
diff --git a/tests/models/registry.py b/tests/models/registry.py
index a49e3ad6b20e..18342b671e0d 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -283,7 +283,7 @@ def check_available_online(
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
     "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
                                        trust_remote_code=True),
-    "NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long",  # noqa: E501
+    "NomicBertModel": _HfExamplesInfo("nomic-ai/nomic-embed-text-v2-moe",
                                       trust_remote_code=True),
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
     "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
diff --git a/tests/models/utils.py b/tests/models/utils.py
index a43fd77c6d79..ac1fc6c8f0e2 100644
--- a/tests/models/utils.py
+++ b/tests/models/utils.py
@@ -2,7 +2,7 @@
 
 import warnings
 from collections.abc import Sequence
-from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union
+from typing import Any, NamedTuple, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -13,9 +13,6 @@
 
 from .registry import HF_EXAMPLE_MODELS
 
-if TYPE_CHECKING:
-    from ..conftest import HfRunner
-
 TokensText = tuple[list[int], str]
 
 
@@ -337,22 +334,3 @@ class EmbedModelInfo(NamedTuple):
     architecture: str = ""
     dtype: str = "auto"
     enable_test: bool = True
-
-
-def run_embedding_correctness_test(
-    hf_model: "HfRunner",
-    inputs: list[str],
-    vllm_outputs: Sequence[list[float]],
-    dimensions: Optional[int] = None,
-):
-    hf_outputs = hf_model.encode(inputs)
-    if dimensions:
-        hf_outputs = matryoshka_fy(hf_outputs, dimensions)
-
-    check_embeddings_close(
-        embeddings_0_lst=hf_outputs,
-        embeddings_1_lst=vllm_outputs,
-        name_0="hf",
-        name_1="vllm",
-        tol=1e-2,
-    )
diff --git a/vllm/config.py b/vllm/config.py
index 4afdda3cca64..295297cfbf9a 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -572,13 +572,7 @@ def __post_init__(self) -> None:
                 sliding_window = None
 
         self.original_max_model_len = self.max_model_len
-        self.max_model_len = _get_and_verify_max_len(
-            hf_config=self.hf_text_config,
-            max_model_len=self.max_model_len,
-            disable_sliding_window=self.disable_sliding_window,
-            sliding_window_len=self.get_hf_config_sliding_window(),
-            spec_target_max_model_len=self.spec_target_max_model_len,
-            encoder_config=self.encoder_config)
+        self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
         self.served_model_name = get_served_model_name(self.model,
                                                        self.served_model_name)
         self.multimodal_config = self._init_multimodal_config()
@@ -1387,6 +1381,16 @@ def is_matryoshka(self) -> bool:
     def matryoshka_dimensions(self):
         return getattr(self.hf_config, "matryoshka_dimensions", None)
 
+    def get_and_verify_max_len(self, max_model_len: int):
+        max_model_len = _get_and_verify_max_len(
+            hf_config=self.hf_text_config,
+            max_model_len=max_model_len,
+            disable_sliding_window=self.disable_sliding_window,
+            sliding_window_len=self.get_hf_config_sliding_window(),
+            spec_target_max_model_len=self.spec_target_max_model_len,
+            encoder_config=self.encoder_config)
+        return max_model_len
+
 
 BlockSize = Literal[1, 8, 16, 32, 64, 128]
 CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2"]
@@ -4474,13 +4478,7 @@ def _set_cudagraph_sizes(self):
 
     def recalculate_max_model_len(self, max_model_len: int):
         model_config = self.model_config
-        max_model_len = _get_and_verify_max_len(
-            hf_config=model_config.hf_text_config,
-            max_model_len=max_model_len,
-            disable_sliding_window=model_config.disable_sliding_window,
-            sliding_window_len=model_config.get_hf_config_sliding_window(),
-            spec_target_max_model_len=model_config.spec_target_max_model_len,
-            encoder_config=model_config.encoder_config)
+        max_model_len = model_config.get_and_verify_max_len(max_model_len)
        self.model_config.max_model_len = max_model_len
         self.scheduler_config.max_model_len = max_model_len
         self.compute_hash()
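
For context, a minimal sketch of how a new per-model-family test module would plug into the shared helpers introduced above, mirroring test_baai.py. The model name "FakeOrg/fake-embed-model" and the file name test_example.py are hypothetical placeholders; the hf_runner, vllm_runner and example_prompts fixtures are assumed to come from the repo's existing conftest.

# SPDX-License-Identifier: Apache-2.0
# Hypothetical tests/models/language/pooling/test_example.py (illustration only).
import pytest

from .embed_utils import EmbedModelInfo, correctness_test_embed_models
from .mteb_utils import mteb_test_embed_models

MODELS = [
    # Placeholder entry; only models with enable_test=True are exercised.
    EmbedModelInfo("FakeOrg/fake-embed-model",
                   architecture="BertModel",
                   enable_test=True),
]


@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_mteb(hf_runner, vllm_runner,
                           model_info: EmbedModelInfo) -> None:
    # Compares the MTEB main score of vLLM against sentence-transformers
    # within MTEB_EMBED_TOL (now an absolute tolerance).
    mteb_test_embed_models(hf_runner, vllm_runner, model_info)


@pytest.mark.parametrize("model_info", MODELS)
def test_embed_models_correctness(hf_runner, vllm_runner,
                                  model_info: EmbedModelInfo,
                                  example_prompts) -> None:
    # Compares per-prompt embeddings of vLLM against sentence-transformers.
    correctness_test_embed_models(hf_runner, vllm_runner, model_info,
                                  example_prompts)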