1- from typing import List , Optional , Union
1+ from typing import Iterable , List , Optional , Tuple , Union
22
33import torch
4+ import torch .nn as nn
45
56from vllm .attention import AttentionMetadata
6- from vllm .model_executor .models .gemma2 import Gemma2EmbeddingModel
7- from vllm .sequence import IntermediateTensors
7+ from vllm .config import VllmConfig
8+ from vllm .model_executor .layers .pooler import Pooler , PoolingType
9+ from vllm .model_executor .models .gemma2 import Gemma2Model
10+ from vllm .model_executor .models .utils import WeightsMapper , maybe_prefix
11+ from vllm .model_executor .pooling_metadata import PoolingMetadata
12+ from vllm .sequence import IntermediateTensors , PoolerOutput
813
914
class MyGemma2Embedding(nn.Module):
    """Example embedding model: a Gemma2 backbone with a pooling head.

    NOTE(review): reconstructed from a diff-mangled source dump; verify
    against the upstream vLLM example before relying on exact formatting.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        # Backbone transformer. `maybe_prefix` namespaces the submodule's
        # weight names under "<prefix>.model" when a prefix is given.
        self.model = Gemma2Model(vllm_config=vllm_config,
                                 prefix=maybe_prefix(prefix, "model"))

        # Pooling head: last-token pooling with normalization and no
        # softmax by default, overridable via the model's pooler_config.
        self._pooler = Pooler.from_config_with_defaults(
            vllm_config.model_config.pooler_config,
            pooling_type=PoolingType.LAST,
            normalize=True,
            softmax=False,
        )

        # Re-export the backbone's pipeline-parallel helper so callers can
        # build empty IntermediateTensors without reaching into self.model.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)
1132
1233 def forward (
1334 self ,
@@ -18,7 +39,7 @@ def forward(
1839 intermediate_tensors : Optional [IntermediateTensors ] = None ,
1940 inputs_embeds : Optional [torch .Tensor ] = None ,
2041 ) -> Union [torch .Tensor , IntermediateTensors ]:
21- hidden_states = super (). forward (
42+ hidden_states = self . model (
2243 input_ids ,
2344 positions ,
2445 kv_caches ,
@@ -32,3 +53,17 @@ def forward(
3253
3354 # Return all-zero embeddings
3455 return torch .zeros_like (hidden_states )
56+
57+ def pooler (
58+ self ,
59+ hidden_states : torch .Tensor ,
60+ pooling_metadata : PoolingMetadata ,
61+ ) -> Optional [PoolerOutput ]:
62+ return self ._pooler (hidden_states , pooling_metadata )
63+
64+ def load_weights (self , weights : Iterable [Tuple [str , torch .Tensor ]]):
65+ hf_to_vllm_mapper = WeightsMapper (orig_to_new_prefix = {"model." : "" })
66+ weights = hf_to_vllm_mapper .apply (weights )
67+ weights = ((name , data ) for name , data in weights
68+ if not name .startswith ("lm_head." ))
69+ return self .model .load_weights (weights )
0 commit comments