46 changes: 46 additions & 0 deletions examples/offline_inference/context_extension.py
@@ -0,0 +1,46 @@
# SPDX-License-Identifier: Apache-2.0

from vllm import LLM, SamplingParams

rope_theta = 1000000
original_max_position_embeddings = 32768
factor = 4.0

# Use yarn to extend context
hf_overrides = {
    "rope_theta": rope_theta,
    "rope_scaling": {
        "rope_type": "yarn",
        "factor": factor,
        "original_max_position_embeddings": original_max_position_embeddings,
    },
    "max_model_len": int(original_max_position_embeddings * factor),
}

llm = LLM(model="Qwen/Qwen3-0.6B", hf_overrides=hf_overrides)

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=128,
)

conversation = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hello! How can I assist you today?"},
]
outputs = llm.chat(conversation, sampling_params, use_tqdm=False)


def print_outputs(outputs):
    print("\nGenerated Outputs:\n" + "-" * 80)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}\n")
        print(f"Generated text: {generated_text!r}")
        print("-" * 80)


print_outputs(outputs)
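
As a quick sanity check, a sketch (assuming the LLM object exposes the llm_engine.model_config attribute path that the tests below also rely on) to confirm the extended context length took effect:

# Sketch only: verify the YaRN override was applied. Assumes the
# llm_engine.model_config attribute path used by the tests in this PR.
effective_len = llm.llm_engine.model_config.max_model_len
assert effective_len == int(original_max_position_embeddings * factor)
print(f"Effective max_model_len: {effective_len}")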
130 changes: 130 additions & 0 deletions tests/models/language/pooling/test_nomic_max_model_len.py
@@ -0,0 +1,130 @@
# SPDX-License-Identifier: Apache-2.0
# ruff: noqa: SIM117
import pytest

from ...utils import EmbedModelInfo

MODELS = [
    EmbedModelInfo("nomic-ai/nomic-embed-text-v1"),
    #EmbedModelInfo("nomic-ai/nomic-embed-text-v1.5"),
    #EmbedModelInfo("nomic-ai/CodeRankEmbed"),
    EmbedModelInfo("nomic-ai/nomic-embed-text-v2-moe"),
    #EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-long"),
]

rope_theta = 1000
factor = 4.0
original_max_position_embeddings = 2048
max_model_len = int(original_max_position_embeddings * factor)


@pytest.mark.parametrize("model_info", MODELS)
def test_default(model_info, vllm_runner):
    with vllm_runner(model_info.name, task="embed",
                     max_model_len=None) as vllm_model:
        model_config = vllm_model.model.llm_engine.model_config
        if model_info.name == "nomic-ai/nomic-embed-text-v2-moe":
            # For nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            assert model_config.max_model_len == 512
        else:
            assert (
                model_config.max_model_len == original_max_position_embeddings)


@pytest.mark.parametrize("model_info", MODELS)
def test_set_max_model_len_legal(model_info, vllm_runner):
    # set max_model_len <= 512
    with vllm_runner(model_info.name, task="embed",
                     max_model_len=256) as vllm_model:
        model_config = vllm_model.model.llm_engine.model_config
        assert model_config.max_model_len == 256

    # set 512 < max_model_len <= 2048
    if model_info.name == "nomic-ai/nomic-embed-text-v2-moe":
        # For nomic-embed-text-v2-moe the length is set to 512
        # by sentence_bert_config.json.
        with pytest.raises(ValueError):
            with vllm_runner(model_info.name, task="embed",
                             max_model_len=1024):
                pass
    else:
        with vllm_runner(model_info.name, task="embed",
                         max_model_len=1024) as vllm_model:
            model_config = vllm_model.model.llm_engine.model_config
            assert model_config.max_model_len == 1024


@pytest.mark.parametrize("model_info", MODELS)
def test_set_max_model_len_illegal(model_info, vllm_runner):
    # set max_model_len > 2048
    with pytest.raises(ValueError):
        with vllm_runner(model_info.name, task="embed", max_model_len=4096):
            pass

    # set max_model_len > 2048 by hf_overrides
    hf_overrides = {"max_model_len": 4096}
    with pytest.raises(ValueError):
        with vllm_runner(model_info.name,
                         task="embed",
                         max_model_len=None,
                         hf_overrides=hf_overrides):
            pass


@pytest.mark.parametrize("model_info", MODELS)
def test_use_rope_scaling_legal(model_info, vllm_runner):
    hf_overrides = {
        "rope_theta": rope_theta,
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings":
            original_max_position_embeddings
        },
        "max_model_len": max_model_len
    }

    with vllm_runner(model_info.name,
                     task="embed",
                     max_model_len=None,
                     hf_overrides=hf_overrides):
        pass


@pytest.mark.parametrize("model_info", MODELS)
def test_use_rope_scaling_illegal(model_info, vllm_runner):
    hf_overrides = {
        "rope_theta": rope_theta,
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings":
            original_max_position_embeddings
        }
    }
    # illegal max_model_len
    with pytest.raises(ValueError):
        with vllm_runner(model_info.name,
                         task="embed",
                         max_model_len=max_model_len + 1,
                         hf_overrides=hf_overrides):
            pass

    hf_overrides = {
        "rope_theta": rope_theta,
        "rope_scaling": {
            "rope_type": "yarn",
            "factor": factor,
            "original_max_position_embeddings":
            original_max_position_embeddings
        },
        "max_model_len": max_model_len + 1
    }
    # illegal max_model_len by hf_overrides
    with pytest.raises(ValueError):
        with vllm_runner(model_info.name,
                         task="embed",
                         max_model_len=None,
                         hf_overrides=hf_overrides):
            pass
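
For orientation, a usage sketch (not part of the test suite) showing how the same YaRN overrides would be applied outside the test harness; it assumes the task="embed" option and the LLM.embed API behave as exercised by the tests above:

from vllm import LLM

# Hypothetical usage sketch: mirror the hf_overrides exercised by
# test_use_rope_scaling_legal to extend a Nomic embedding model's context.
hf_overrides = {
    "rope_theta": 1000,
    "rope_scaling": {
        "rope_type": "yarn",
        "factor": 4.0,
        "original_max_position_embeddings": 2048,
    },
    "max_model_len": 8192,
}
llm = LLM(model="nomic-ai/nomic-embed-text-v1",
          task="embed",
          hf_overrides=hf_overrides)
(output, ) = llm.embed(["vLLM context extension for Nomic embeddings."])
# Each result is assumed to carry the embedding vector under outputs.embedding.
print(len(output.outputs.embedding))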
14 changes: 14 additions & 0 deletions vllm/config.py
@@ -571,6 +571,7 @@ def __post_init__(self) -> None:

        sliding_window = None

        self.original_max_model_len = self.max_model_len
        self.max_model_len = _get_and_verify_max_len(
            hf_config=self.hf_text_config,
            max_model_len=self.max_model_len,
@@ -4471,6 +4472,19 @@ def _set_cudagraph_sizes(self):
        self.compilation_config.init_with_cudagraph_sizes(
            batch_size_capture_list)

    def recalculate_max_model_len(self, max_model_len: int):
        model_config = self.model_config
        max_model_len = _get_and_verify_max_len(
            hf_config=model_config.hf_text_config,
            max_model_len=max_model_len,
            disable_sliding_window=model_config.disable_sliding_window,
            sliding_window_len=model_config.get_hf_config_sliding_window(),
            spec_target_max_model_len=model_config.spec_target_max_model_len,
            encoder_config=model_config.encoder_config)
        self.model_config.max_model_len = max_model_len
        self.scheduler_config.max_model_len = max_model_len
        self.compute_hash()

    def __str__(self):
        return (
            f"model={self.model_config.model!r},"
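Roughly, the new helper is meant to be called from a model's config_verify hook after the HF config has been adjusted. A sketch of the call pattern (the real call site is in the bert_with_rope.py change below; max_trained_positions here stands in for whatever cap the model derives):

# Sketch: cap the requested length and let VllmConfig re-verify and
# propagate it to model_config and scheduler_config.
capped_len = min(vllm_config.model_config.max_model_len, max_trained_positions)
vllm_config.recalculate_max_model_len(capped_len)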
55 changes: 52 additions & 3 deletions vllm/model_executor/models/bert_with_rope.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable
from copy import deepcopy
from typing import Optional

import torch
@@ -10,6 +11,7 @@
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import (get_act_and_mul_fn,
                                                    get_act_fn)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -27,6 +29,8 @@
from vllm.model_executor.models.utils import WeightsMapper
from vllm.sequence import IntermediateTensors

logger = init_logger(__name__)


class BertWithRopeEmbedding(nn.Module):

@@ -513,10 +517,11 @@ def config_verify(self, vllm_config):

        head_dim = config.hidden_size // config.num_attention_heads
        rotary_emb_dim = head_dim * config.rotary_emb_fraction
        max_trained_positions = getattr(config, "max_trained_positions", 2048)
        config.rotary_kwargs = {
            "head_size": head_dim,
            "rotary_dim": rotary_emb_dim,
            "max_position": config.max_trained_positions,
            "max_position": max_trained_positions,
            "base": getattr(config, "rope_theta", config.rotary_emb_base),
            "rope_scaling": getattr(config, "rope_scaling", None)
        }
@@ -525,8 +530,52 @@ def config_verify(self, vllm_config):
        # than max_trained_positions 2048, the results are consistent
        # with SentenceTransformer.
        # The context extension uses vllm style rope_theta and rope_scaling.
        # See #17785
        # See #17785 #18755
        if (not vllm_config.model_config.hf_overrides
                and vllm_config.model_config.original_max_model_len is None):
            # Default
            # Reset max_model_len to max_trained_positions.
            # For nomic-embed-text-v2-moe the length is set to 512
            # by sentence_bert_config.json.
            max_model_len_before = vllm_config.model_config.max_model_len
            max_model_len = min(vllm_config.model_config.max_model_len,
                                max_trained_positions)

            vllm_config.recalculate_max_model_len(max_model_len)
            logger.warning(
                "Nomic context extension is disabled. "
                "Changing max_model_len from %s to %s. "
                "To enable context extension, see: "
                "https://github.com/vllm-project/vllm/tree/main/examples/offline_inference/context_extension.html",
                max_model_len_before, vllm_config.model_config.max_model_len)
        else:
            # We need to re-verify max_model_len to avoid lengths
            # greater than position_embedding.
            model_config = vllm_config.model_config
            hf_text_config = model_config.hf_text_config

            if isinstance(model_config.hf_overrides, dict):
                # hf_overrides_kw
                max_model_len = model_config.hf_overrides.get(
                    "max_model_len", vllm_config.model_config.max_model_len)
            else:
                # hf_overrides_fn
                # This might be overridden by sentence_bert_config.json.
                max_model_len = vllm_config.model_config.max_model_len

            # reset hf_text_config for recalculate_max_model_len.
            if hasattr(hf_text_config, "max_model_len"):
                delattr(hf_text_config, "max_model_len")
            hf_text_config.max_position_embeddings = max_trained_positions
            hf_text_config.rope_scaling = config.rotary_kwargs["rope_scaling"]

            # The priority of sentence_bert_config.json is higher
            # than max_position_embeddings
            encoder_config = deepcopy(model_config.encoder_config)
            encoder_config.pop("max_seq_length", None)
            model_config.encoder_config = encoder_config

            vllm_config.recalculate_max_model_len(max_model_len)
        return config

