@@ -22,14 +22,19 @@
                                                QKVParallelLinear,
                                                ReplicatedLinear,
                                                RowParallelLinear)
+from vllm.model_executor.layers.pooler import (ClassifierPooler,
+                                               DispatchPooler, Pooler)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-from vllm.model_executor.models.interfaces import (SupportsQuant,
+from vllm.model_executor.models.bert import BertPooler
+from vllm.model_executor.models.interfaces import (SupportsCrossEncoding,
+                                                   SupportsQuant,
                                                    default_pooling_type)
-from vllm.model_executor.models.utils import WeightsMapper
+from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper,
+                                              maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
@@ -405,16 +410,22 @@ def forward( |
 class BertWithRope(nn.Module, SupportsQuant):
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 add_pooling_layer: bool = False):
         super().__init__()
         self.vllm_config = vllm_config
+        self.add_pooling_layer = add_pooling_layer
         self.config = vllm_config.model_config.hf_config
         self.embeddings = BertWithRopeEmbedding(self.config)
         self.encoder = BertWithRopeEncoder(
             vllm_config=vllm_config,
             bias=getattr(self.config, "bias", True),
             rotary_kwargs=self.config.rotary_kwargs,
             prefix=f"{prefix}.encoder")
+        self.pooler = BertPooler(self.config) if add_pooling_layer else None
 
     def forward(
         self,
@@ -447,7 +458,7 @@ def load_weights(self, weights: Iterable[tuple[str, |
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if "pooler" in name:
+            if not self.add_pooling_layer and "pooler" in name:
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
@@ -507,8 +518,8 @@ class GteNewModel(BertWithRope): |
             "attention.o_proj": "attn.out_proj",
         })
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__(vllm_config=vllm_config, prefix=prefix)
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", **kwargs):
+        super().__init__(vllm_config=vllm_config, prefix=prefix, **kwargs)
 
         # GteNewModel only gate_up_proj does not have bias.
         # Hack method learned from vllm/model_executor/models/glm.py
@@ -613,3 +624,65 @@ def load_weights(self, weights: Iterable[tuple[str, |
                                                    torch.Tensor]]) -> set[str]:
         weights = self.jina_merge_lora_weights(weights)
         return super().load_weights(weights)
+
+
+@default_pooling_type("CLS")
+class GteNewForSequenceClassification(nn.Module, SupportsCrossEncoding):
+    is_pooling_model = True
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+
+        self.new = GteNewModel(vllm_config=vllm_config,
+                               prefix=prefix,
+                               add_pooling_layer=True)
+        self.classifier = RowParallelLinear(config.hidden_size,
+                                            config.num_labels,
+                                            input_is_parallel=False,
+                                            bias=True,
+                                            quant_config=quant_config,
+                                            prefix=maybe_prefix(
+                                                prefix, "classifier"),
+                                            return_bias=False)
+
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+
+        self.pooler = DispatchPooler({
+            "encode":
+            Pooler.for_encode(pooler_config),
+            "classify":
+            ClassifierPooler(
+                pooling=self.new.pooler,
+                classifier=self.classifier,
+                act_fn=ClassifierPooler.act_fn_for_seq_cls(
+                    vllm_config.model_config),
+            ),
+            "score":
+            ClassifierPooler(
+                pooling=self.new.pooler,
+                classifier=self.classifier,
+                act_fn=ClassifierPooler.act_fn_for_cross_encoder(
+                    vllm_config.model_config),
+            ),
+        })
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        loader = AutoWeightsLoader(self)
+        loaded_params = loader.load_weights(weights)
+        return loaded_params
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        return self.new(input_ids=input_ids,
+                        positions=positions,
+                        inputs_embeds=inputs_embeds,
+                        intermediate_tensors=intermediate_tensors)
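
Taken together, the new `GteNewForSequenceClassification` exposes the encoder as a cross-encoder: `DispatchPooler` routes "encode" requests to plain pooling, while "classify" and "score" run the CLS pooler and the `RowParallelLinear` classifier head, each with the activation selected for that task. Below is a minimal offline-inference sketch, assuming a vLLM build that accepts `task="score"` and exposes the `LLM.score()` API; the checkpoint name is only an illustration of a GteNew-architecture reranker:

from vllm import LLM

# Illustrative GteNew-architecture reranker; any checkpoint that maps to
# GteNewForSequenceClassification should follow the same pattern.
llm = LLM(model="Alibaba-NLP/gte-multilingual-reranker-base",
          task="score",
          trust_remote_code=True)

# Cross-encoder scoring: each (query, document) pair goes through the CLS
# pooler, the classifier head, and the "score" activation configured above.
outputs = llm.score("what is the capital of France?",
                    ["Paris is the capital of France.",
                     "The Nile is a river in Africa."])
for output in outputs:
    print(output.outputs.score)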