diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index e20521df027a..4bb831749287 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -701,12 +701,22 @@ Specified using `--task embed`.
   * ✅︎
   * ✅︎
 - * `GteModel`
-  * GteModel
+  * Arctic-Embed-2.0-M
   * `Snowflake/snowflake-arctic-embed-m-v2.0`.
   *
   * ︎
+- * `GteNewModel`
+  * mGTE-TRM (see note)
+  * `Alibaba-NLP/gte-multilingual-base`, etc.
+  * ︎
+  * ︎
+- * `ModernBertModel`
+  * ModernBERT-based
+  * `Alibaba-NLP/gte-modernbert-base`, etc.
+  * ︎
+  * ︎
 - * `NomicBertModel`
-  * NomicBertModel
+  * Nomic BERT
   * `nomic-ai/nomic-embed-text-v1`, `nomic-ai/nomic-embed-text-v2-moe`, `Snowflake/snowflake-arctic-embed-m-long`, etc.
   * ︎
   * ︎
@@ -749,6 +759,10 @@ See [relevant issue on HF Transformers](https://github.com/huggingface/transform
 `jinaai/jina-embeddings-v3` supports multiple tasks through lora, while vllm temporarily only supports text-matching tasks by merging lora weights.
 :::
 
+:::{note}
+The second-generation GTE model (mGTE-TRM) is named `NewModel`. Since this name is too generic, you should set `--hf-overrides '{"architectures": ["GteNewModel"]}'` to specify the use of the `GteNewModel` architecture.
+:::
+
 If your model is not in the above list, we will try to automatically convert the model using
 {func}`~vllm.model_executor.models.adapters.as_embedding_model`. By default, the embeddings
 of the whole prompt are extracted from the normalized hidden state corresponding to the last token.
diff --git a/tests/models/language/pooling/mteb_utils.py b/tests/models/language/pooling/mteb_utils.py
index eedf310d034a..7de2a9af2f2e 100644
--- a/tests/models/language/pooling/mteb_utils.py
+++ b/tests/models/language/pooling/mteb_utils.py
@@ -7,6 +7,7 @@
 import pytest
 
 from tests.models.utils import EmbedModelInfo
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 
 # Most models on the STS12 task (See #17175):
 # - Model implementation and minor changes in tensor dtype
@@ -77,16 +78,22 @@ def run_mteb_embed_task_st(model_name, tasks):
     return run_mteb_embed_task(model, tasks)
 
 
-def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
+def mteb_test_embed_models(hf_runner,
+                           vllm_runner,
+                           model_info: EmbedModelInfo,
+                           vllm_extra_kwargs=None):
     if not model_info.enable_test:
         # A model family has many models with the same architecture,
         # and we don't need to test each one.
         pytest.skip("Skipping test.")
 
+    vllm_extra_kwargs = vllm_extra_kwargs or {}
+
     with vllm_runner(model_info.name,
                      task="embed",
                      max_model_len=None,
-                     dtype=model_info.dtype) as vllm_model:
+                     dtype=model_info.dtype,
+                     **vllm_extra_kwargs) as vllm_model:
 
         if model_info.architecture:
             assert (model_info.architecture
@@ -99,9 +106,9 @@ def mteb_test_embed_models(hf_runner, vllm_runner, model_info: EmbedModelInfo):
             vllm_model.model.llm_engine.model_config.hf_config, "torch_dtype",
             vllm_dtype)
 
-    with hf_runner(model_info.name,
-                   is_sentence_transformer=True,
-                   dtype=model_dtype) as hf_model:
+    with set_default_torch_dtype(model_dtype), hf_runner(
+            model_info.name, is_sentence_transformer=True,
+            dtype=model_dtype) as hf_model:
         st_main_score = run_mteb_embed_task(hf_model, MTEB_EMBED_TASKS)
 
     print("VLLM:", vllm_dtype, vllm_main_score)
diff --git a/tests/models/language/pooling/test_gte.py b/tests/models/language/pooling/test_gte.py
new file mode 100644
index 000000000000..3ccf2999664c
--- /dev/null
+++ b/tests/models/language/pooling/test_gte.py
@@ -0,0 +1,104 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Any
+
+import pytest
+
+from ...utils import EmbedModelInfo, run_embedding_correctness_test
+
+MODELS = [
+    ########## BertModel
+    EmbedModelInfo("thenlper/gte-large",
+                   architecture="BertModel",
+                   dtype="float32",
+                   enable_test=True),
+    EmbedModelInfo("thenlper/gte-base",
+                   architecture="BertModel",
+                   dtype="float32",
+                   enable_test=False),
+    EmbedModelInfo("thenlper/gte-small",
+                   architecture="BertModel",
+                   dtype="float32",
+                   enable_test=False),
+    EmbedModelInfo("thenlper/gte-large-zh",
+                   architecture="BertModel",
+                   dtype="float32",
+                   enable_test=False),
+    EmbedModelInfo("thenlper/gte-base-zh",
+                   architecture="BertModel",
+                   dtype="float32",
+                   enable_test=False),
+    EmbedModelInfo("thenlper/gte-small-zh",
+                   architecture="BertModel",
+                   dtype="float32",
+                   enable_test=False),
+    ########### NewModel
+    EmbedModelInfo("Alibaba-NLP/gte-multilingual-base",
+                   architecture="GteNewModel",
+                   enable_test=True),
+    EmbedModelInfo("Alibaba-NLP/gte-base-en-v1.5",
+                   architecture="GteNewModel",
+                   enable_test=True),
+    EmbedModelInfo("Alibaba-NLP/gte-large-en-v1.5",
+                   architecture="GteNewModel",
+                   enable_test=True),
+    ########### Qwen2ForCausalLM
+    EmbedModelInfo("Alibaba-NLP/gte-Qwen2-1.5B-instruct",
+                   architecture="Qwen2ForCausalLM",
+                   enable_test=True),
+    EmbedModelInfo("Alibaba-NLP/gte-Qwen2-7B-instruct",
+                   architecture="Qwen2ForCausalLM",
+                   enable_test=False),
+    ########## ModernBertModel
+    EmbedModelInfo("Alibaba-NLP/gte-modernbert-base",
+                   architecture="ModernBertModel",
+                   enable_test=True),
+]
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_models_mteb(hf_runner, vllm_runner,
+                     model_info: EmbedModelInfo) -> None:
+    pytest.skip("Skipping mteb test.")
+
+    from .mteb_utils import mteb_test_embed_models
+
+    vllm_extra_kwargs: dict[str, Any] = {}
+    if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
+        vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
+
+    if model_info.architecture == "GteNewModel":
+        vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
+
+    mteb_test_embed_models(hf_runner, vllm_runner, model_info,
+                           vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", MODELS)
+def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
+                            example_prompts) -> None:
+    if not model_info.enable_test:
+        pytest.skip("Skipping test.")
+
+    # ST will strip the input texts, see test_embedding.py
+    example_prompts = [str(s).strip() for s in example_prompts]
+
+    vllm_extra_kwargs: dict[str, Any] = {}
+    if model_info.name == "Alibaba-NLP/gte-Qwen2-1.5B-instruct":
+        vllm_extra_kwargs["hf_overrides"] = {"is_causal": True}
+
+    if model_info.architecture == "GteNewModel":
+        vllm_extra_kwargs["hf_overrides"] = {"architectures": ["GteNewModel"]}
+
+    with vllm_runner(model_info.name,
+                     task="embed",
+                     dtype=model_info.dtype,
+                     max_model_len=None,
+                     **vllm_extra_kwargs) as vllm_model:
+        vllm_outputs = vllm_model.encode(example_prompts)
+
+    with hf_runner(
+            model_info.name,
+            dtype=model_info.dtype,
+            is_sentence_transformer=True,
+    ) as hf_model:
+        run_embedding_correctness_test(hf_model, example_prompts, vllm_outputs)
diff --git a/tests/models/language/pooling/test_nomic.py b/tests/models/language/pooling/test_nomic.py
index f1ed0d49498b..6e9de30f977d 100644
--- a/tests/models/language/pooling/test_nomic.py
+++ b/tests/models/language/pooling/test_nomic.py
@@ -23,6 +23,7 @@
 @pytest.mark.parametrize("model_info", MODELS)
 def test_models_mteb(hf_runner, vllm_runner,
                      model_info: EmbedModelInfo) -> None:
+    pytest.skip("Skipping mteb test.")
     from .mteb_utils import mteb_test_embed_models
     mteb_test_embed_models(hf_runner, vllm_runner, model_info)
 
@@ -33,6 +34,9 @@ def test_models_correctness(hf_runner, vllm_runner, model_info: EmbedModelInfo,
     if not model_info.enable_test:
         pytest.skip("Skipping test.")
 
+    # ST will strip the input texts, see test_embedding.py
+    example_prompts = [str(s).strip() for s in example_prompts]
+
     with vllm_runner(model_info.name,
                      task="embed",
                      dtype=model_info.dtype,
diff --git a/tests/models/language/pooling/test_snowflake_arctic_embed.py b/tests/models/language/pooling/test_snowflake_arctic_embed.py
index c68aa008e854..7d9c3c73d852 100644
--- a/tests/models/language/pooling/test_snowflake_arctic_embed.py
+++ b/tests/models/language/pooling/test_snowflake_arctic_embed.py
@@ -46,6 +46,7 @@ def test_models_mteb(
     vllm_runner,
     model_info: EmbedModelInfo,
 ) -> None:
+    pytest.skip("Skipping mteb test.")
     from .mteb_utils import mteb_test_embed_models
     mteb_test_embed_models(hf_runner, vllm_runner, model_info)
 
@@ -60,6 +61,9 @@ def test_models_correctness(
     if not model_info.enable_test:
         pytest.skip("Skipping test.")
 
+    # ST will strip the input texts, see test_embedding.py
+    example_prompts = [str(s).strip() for s in example_prompts]
+
     with vllm_runner(model_info.name,
                      task="embed",
                      dtype=model_info.dtype,
diff --git a/tests/models/registry.py b/tests/models/registry.py
index 8e6422ae1a88..39b9795e7e16 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -256,11 +256,17 @@ def check_available_online(
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
     "GteModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-v2.0",
                                 trust_remote_code=True),
+    "GteNewModel": _HfExamplesInfo("Alibaba-NLP/gte-base-en-v1.5",
+                                   trust_remote_code=True,
+                                   hf_overrides={"architectures":
+                                                 ["GteNewModel"]}),
     "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
                                                trust_remote_code=True),
     "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),  # noqa: E501
     "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
+    "ModernBertModel": _HfExamplesInfo("Alibaba-NLP/gte-modernbert-base",
+                                       trust_remote_code=True),
     "NomicBertModel": _HfExamplesInfo("Snowflake/snowflake-arctic-embed-m-long",  # noqa: E501
                                       trust_remote_code=True),
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index f082afb7e9c0..a32c26317a88 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -354,7 +354,7 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
 _ACTIVATION_AND_MUL_REGISTRY = LazyDict({
     "gelu": lambda: GeluAndMul(),
     "silu": lambda: SiluAndMul(),
-    "gelu_and_mul": lambda: GeluAndMul(),
+    "geglu": lambda: GeluAndMul(),
 })
 
 
diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index 7e0d65684229..70463ecd90ae 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -456,6 +456,40 @@ def scaling_factor_to_offset(self) -> dict[float, int]:
         return self._scaling_factor_to_offset
 
 
+class NTKScalingRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with fixed and mixed NTK scaling.
+    https://kexue.fm/archives/9706 """
+
+    def __init__(self,
+                 head_size: int,
+                 rotary_dim: int,
+                 max_position_embeddings: int,
+                 base: int,
+                 is_neox_style: bool,
+                 scaling_factor: float,
+                 dtype: torch.dtype,
+                 mixed_b: Optional[float] = None) -> None:
+        self.scaling_factor = scaling_factor
+        self.mixed_b = mixed_b
+        super().__init__(head_size, rotary_dim, max_position_embeddings, base,
+                         is_neox_style, dtype)
+
+    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
+        base = self.base * (self.scaling_factor if self.mixed_b is None else 1)
+        inv_freq = super()._compute_inv_freq(base)
+
+        if self.mixed_b is None:
+            inv_freq = inv_freq / self.scaling_factor**(2 / self.rotary_dim)
+        else:
+            a = torch.tensor(self.scaling_factor).log() / (self.rotary_dim /
+                                                           2)**self.mixed_b
+            lambda_1_m = (a * torch.arange(
+                1, self.rotary_dim // 2 + 1).float()**self.mixed_b).exp()
+            inv_freq = inv_freq / lambda_1_m
+
+        return inv_freq
+
+
 class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding):
     """RotaryEmbedding extended with Dynamic NTK scaling.
 
@@ -1765,6 +1799,14 @@ def get_rope(
                                                       max_position, base,
                                                       is_neox_style,
                                                       scaling_factor, dtype)
+        elif scaling_type == "ntk":
+            scaling_factor = rope_scaling["factor"]
+            mixed_b = rope_scaling.get('mixed_b', None)
+            rotary_emb = NTKScalingRotaryEmbedding(head_size, rotary_dim,
+                                                   max_position, base,
+                                                   is_neox_style,
+                                                   scaling_factor, dtype,
+                                                   mixed_b)
         elif scaling_type == "dynamic":
             scaling_factor = rope_scaling["factor"]
             rotary_emb = DynamicNTKScalingRotaryEmbedding(
diff --git a/vllm/model_executor/models/bert_with_rope.py b/vllm/model_executor/models/bert_with_rope.py
index 05cd84748fb4..002949abff52 100644
--- a/vllm/model_executor/models/bert_with_rope.py
+++ b/vllm/model_executor/models/bert_with_rope.py
@@ -32,11 +32,18 @@ class BertWithRopeEmbedding(nn.Module):
 
     def __init__(self, config: PretrainedConfig):
         super().__init__()
-        assert config.type_vocab_size > 0
+        if config.position_embedding_type not in ["rope", "rotary"]:
+            raise ValueError("Only 'rotary'('rope') position_embedding_type" +
+                             " is supported")
+
         self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
                                                       config.hidden_size)
-        self.token_type_embeddings = VocabParallelEmbedding(
-            config.type_vocab_size, config.hidden_size)
+        if config.type_vocab_size > 0:
+            self.token_type_embeddings = VocabParallelEmbedding(
+                config.type_vocab_size, config.hidden_size)
+        else:
+            self.token_type_embeddings = None
+
         self.LayerNorm = nn.LayerNorm(config.hidden_size,
                                       eps=config.layer_norm_eps)
 
@@ -47,13 +54,17 @@ def forward(
     ) -> torch.Tensor:
         input_shape = input_ids.size()
         inputs_embeds = self.word_embeddings(input_ids)
-        if token_type_ids is None:
-            token_type_ids = torch.zeros(input_shape,
-                                         dtype=torch.long,
-                                         device=inputs_embeds.device)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
-        embeddings = inputs_embeds + token_type_embeddings
+        embeddings = inputs_embeds
+        if self.token_type_embeddings is not None:
+            if token_type_ids is None:
+                token_type_ids = torch.zeros(input_shape,
+                                             dtype=torch.long,
+                                             device=inputs_embeds.device)
+
+            token_type_embeddings = self.token_type_embeddings(token_type_ids)
+            embeddings += token_type_embeddings
+
         embeddings = self.LayerNorm(embeddings)
         return embeddings
 
 
@@ -321,7 +332,7 @@ def __init__(self,
         if moe:
             self.mlp = NomicMoELayer(config=config, )
         else:
-            if config.hidden_act in ["silu", "gelu_and_mul"]:
+            if config.hidden_act in ["silu", "geglu"]:
                 self.mlp = BertWithRopeGatedMLP(
                     hidden_size=config.hidden_size,
                     intermediate_size=config.intermediate_size,
@@ -390,6 +401,7 @@ class BertWithRope(nn.Module, SupportsV0Only, SupportsQuant):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
+        self.vllm_config = vllm_config
         self.config = self.config_verify(vllm_config)
         self.embeddings = BertWithRopeEmbedding(self.config)
         self.encoder = BertWithRopeEncoder(
@@ -420,7 +432,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         weights = self.hf_to_vllm_mapper.apply(weights)
 
-        if self.config.hidden_act in ["silu", "gelu_and_mul"]:
+        if self.config.hidden_act in ["silu", "geglu"]:
             stacked_params_mapping = [
                 # (param_name, shard_name, shard_id)
                 ("gate_up_proj", "gate_proj", 0),
@@ -458,6 +470,8 @@ def load_weights(self, weights: Iterable[Tuple[str,
 
 
 class NomicBertModel(BertWithRope):
+    # for https://huggingface.co/nomic-ai/nomic-bert-2048
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             "emb_ln": "embeddings.LayerNorm",
@@ -475,6 +489,9 @@ def config_verify(self, vllm_config):
 
         assert config.__class__.__name__ == "NomicBertConfig"
         assert config.activation_function in ["swiglu", "gelu"]
+        config.position_embedding_type = getattr(config,
+                                                  "position_embedding_type",
+                                                  "rope")
 
         if config.activation_function == "swiglu":
             config.hidden_act = "silu"
@@ -512,10 +529,13 @@ def config_verify(self, vllm_config):
         return config
 
 
-class GteModel(BertWithRope):
+class GteNewModel(BertWithRope):
+    # for https://huggingface.co/Alibaba-NLP/new-impl
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
-            "layer": 'layers',
+            "new.": "",
+            "layer": "layers",
             "attention.qkv_proj": "attn.qkv_proj",
             "attention.o_proj": "attn.out_proj",
         })
@@ -523,7 +543,7 @@ class GteModel(BertWithRope):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
-        # GteModel only gate_up_proj does not have bias.
+        # In GteNewModel, only gate_up_proj does not have a bias.
         # Hack method learned from vllm/model_executor/models/glm.py
         for layer in self.encoder.layers:
             layer.mlp.gate_up_proj.bias = None
@@ -532,12 +552,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     def config_verify(self, vllm_config):
         config = vllm_config.model_config.hf_config
 
-        assert config.__class__.__name__ == "GteConfig"
-        assert config.position_embedding_type == "rope"
+        assert config.__class__.__name__ == "NewConfig"
         assert config.hidden_act == "gelu"
 
-        config.position_embedding_type = "rotary"
-        config.hidden_act = "gelu_and_mul"
+        config.hidden_act = "geglu"
 
         head_dim = config.hidden_size // config.num_attention_heads
         config.rotary_kwargs = {
@@ -559,13 +577,52 @@ def split_up_gate_proj(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             else:
                 yield name, weight
 
+    def ignore_unnecessary_layers(self,
+                                  weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, weight in weights:
+            if name.startswith("classifier"):
+                continue
+            yield name, weight
+
     def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
+        weights = self.ignore_unnecessary_layers(weights)
         weights = self.split_up_gate_proj(weights)
         return super().load_weights(weights)
 
 
+class SnowflakeGteNewModel(GteNewModel):
+    # for Snowflake/snowflake-arctic-embed-m-v2.0
+
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_substr={
+            "layer": "layers",
+            "attention.qkv_proj": "attn.qkv_proj",
+            "attention.o_proj": "attn.out_proj",
+        })
+
+    def config_verify(self, vllm_config):
+        config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "GteConfig"
+        assert config.hidden_act == "gelu"
+
+        config.hidden_act = "geglu"
+
+        head_dim = config.hidden_size // config.num_attention_heads
+        config.rotary_kwargs = {
+            "head_size": head_dim,
+            "rotary_dim": getattr(config, "rotary_emb_dim", head_dim),
+            "max_position": config.max_position_embeddings,
+            "base": config.rope_theta,
+            "rope_scaling": getattr(config, "rope_scaling", None)
+        }
+        return config
+
+
 class JinaRobertaModel(BertWithRope):
+    # for https://huggingface.co/jinaai/jina-embeddings-v3
+
     hf_to_vllm_mapper = WeightsMapper(
         orig_to_new_substr={
             "emb_ln": "embeddings.LayerNorm",
@@ -579,6 +636,9 @@ class JinaRobertaModel(BertWithRope):
 
     def config_verify(self, vllm_config):
         config = vllm_config.model_config.hf_config
+
+        assert config.__class__.__name__ == "XLMRobertaFlashConfig"
+
         head_dim = config.hidden_size // config.num_attention_heads
         config.rotary_kwargs = {
             "head_size": head_dim,
@@ -611,6 +671,7 @@ def jina_merge_lora_weights(self, weights: Iterable[Tuple[str,
 
         # This is a temporary solution until we have a better way to handle
         scaling = self.config.lora_alpha / self.config.lora_rank
+        device = self.vllm_config.device_config.device
 
         weights = {name: weight for name, weight in weights}
 
@@ -628,13 +689,13 @@ def jina_merge_lora_weights(self, weights: Iterable[Tuple[str,
                 weight_name = name[:-len(o)]
 
                 if "embeddings" in weight_name:
-                    B = weights[weight_name + a][i].cuda().float()
-                    A = weights[weight_name + b][i].cuda().float()
+                    B = weights[weight_name + a][i].to(device).float()
+                    A = weights[weight_name + b][i].to(device).float()
                 else:
-                    B = weights[weight_name + b][i].cuda().float()
-                    A = weights[weight_name + a][i].cuda().float()
+                    B = weights[weight_name + b][i].to(device).float()
+                    A = weights[weight_name + a][i].to(device).float()
 
-                weight = (weights[weight_name + o].cuda() +
+                weight = (weights[weight_name + o].to(device) +
                           torch.matmul(B, A).view(shape) * scaling)
                 weight = weight.cpu().to(dtype)
 
diff --git a/vllm/model_executor/models/modernbert.py b/vllm/model_executor/models/modernbert.py
index 2190241f0ba3..73effb207bce 100644
--- a/vllm/model_executor/models/modernbert.py
+++ b/vllm/model_executor/models/modernbert.py
@@ -230,9 +230,12 @@ def load_weights(self, weights: Iterable[Tuple[str,
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
+        positions: Optional[torch.Tensor] = None,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
     ) -> torch.Tensor:
+        position_ids = positions if positions is not None else position_ids
         if inputs_embeds is not None:
             hidden_states = inputs_embeds
         else:
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index ebbbb3938fa1..06a0e6574630 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -127,7 +127,8 @@
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "GritLM": ("gritlm", "GritLM"),
-    "GteModel": ("bert_with_rope", "GteModel"),
+    "GteModel": ("bert_with_rope", "SnowflakeGteNewModel"),
+    "GteNewModel": ("bert_with_rope", "GteNewModel"),
     "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
     "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
     "LlamaModel": ("llama", "LlamaForCausalLM"),
@@ -137,6 +138,7 @@
         if arch == "LlamaForCausalLM"
     },
     "MistralModel": ("llama", "LlamaForCausalLM"),
+    "ModernBertModel": ("modernbert", "ModernBertModel"),
     "NomicBertModel": ("bert_with_rope", "NomicBertModel"),
     "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
     "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
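Not part of the patch itself: a minimal sketch of the workflow the new docs note and `test_gte.py` describe, using the offline `vllm.LLM` API. It assumes the standard `LLM(..., task="embed", hf_overrides=..., trust_remote_code=...)` constructor and the `LLM.embed()` helper from a current vLLM build; the exact output attribute name may differ between versions.

```python
# Minimal sketch (assumes the standard vLLM offline API): load mGTE via the
# GteNewModel override documented above and embed two prompts.
from vllm import LLM

llm = LLM(
    model="Alibaba-NLP/gte-multilingual-base",
    task="embed",
    # The checkpoint advertises the generic "NewModel" architecture, so point
    # vLLM at the GteNewModel implementation added by this patch.
    hf_overrides={"architectures": ["GteNewModel"]},
    trust_remote_code=True,
)

outputs = llm.embed(["The quick brown fox", "jumps over the lazy dog"])
for out in outputs:
    # Each result carries one embedding vector per prompt.
    print(len(out.outputs.embedding))
```

The equivalent server-side invocation should be `vllm serve Alibaba-NLP/gte-multilingual-base --task embed --hf-overrides '{"architectures": ["GteNewModel"]}' --trust-remote-code`.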