@@ -572,6 +572,220 @@ def _decode_token_type_ids(input_ids: torch.Tensor) -> torch.Tensor:
     return token_type_ids
 
 
+class BertMLMHead(nn.Module):
+    def __init__(
+        self, hidden_size: int, vocab_size: int, layer_norm_eps: float = 1e-12
+    ):
+        super().__init__()
+        self.dense = nn.Linear(hidden_size, hidden_size)
+        self.activation = nn.GELU()
+        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.decoder = nn.Linear(hidden_size, vocab_size, bias=True)
+
+    def tie_weights_with_embeddings(self, embeddings_weight: torch.Tensor):
+        # Expects the input-embedding nn.Parameter itself, so the decoder
+        # shares storage with the embedding table (standard BERT MLM tying).
+        self.decoder.weight = embeddings_weight
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        x = self.dense(hidden_states)
+        x = self.activation(x)
+        x = self.layer_norm(x)
+        logits = self.decoder(x)
+        return logits
+
+
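A note on the tying above: `tie_weights_with_embeddings` must receive the embedding table's `nn.Parameter` itself (assigning a plain tensor to `nn.Linear.weight` raises a TypeError), after which decoder and embeddings share one weight tensor. A minimal standalone sketch with toy sizes; `word_embeddings` is a stand-in for BERT's word-embedding table, not part of this diff:

    import torch
    import torch.nn as nn

    vocab_size, hidden_size = 8, 4
    word_embeddings = nn.Embedding(vocab_size, hidden_size)  # stand-in table

    head = BertMLMHead(hidden_size=hidden_size, vocab_size=vocab_size)
    head.tie_weights_with_embeddings(word_embeddings.weight)  # pass the Parameter
    assert head.decoder.weight.data_ptr() == word_embeddings.weight.data_ptr()

    logits = head(torch.randn(3, hidden_size))  # [seq_len, vocab_size]
    assert logits.shape == (3, vocab_size)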
+class SPLADESparsePooler(Pooler):
+    """
+    SPLADE sparse pooling:
+        logits = mlm_head(hidden_states)   # [L, V]
+        scores = log1p(relu(logits))       # [L, V]
+        pooled = max (or sum) over L       # [V]
+
+    Each prompt's tokens are sliced out of the flattened hidden states
+    (vLLM passes unpadded sequences), [CLS]/[SEP] are optionally removed,
+    and the remaining token scores are pooled into one sparse vector.
+    """
+
+    def __init__(
+        self,
+        mlm_head: nn.Module,
+        cls_token_id: Optional[int] = 101,
+        sep_token_id: Optional[int] = 102,
+        pooling: str = "max",
+        remove_cls_sep: bool = True,
+    ):
+        super().__init__()
+        assert pooling in ("max", "sum")
+        self.mlm_head = mlm_head
+        self.cls_token_id = cls_token_id
+        self.sep_token_id = sep_token_id
+        self.pooling = pooling
+        self.remove_cls_sep = remove_cls_sep
+
+    def get_supported_tasks(self) -> Set[PoolingTask]:
+        return {"embed"}
+
+    def get_pooling_updates(self, task: PoolingTask) -> PoolingParamsUpdate:
+        return PoolingParamsUpdate(requires_token_ids=True)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> torch.Tensor:
+        # hidden_states is flattened across prompts: [sum(len_i), H].
+        assert isinstance(hidden_states, torch.Tensor) and hidden_states.dim() == 2
+
+        lens_tensor: torch.Tensor = pooling_metadata.prompt_lens
+        lens: list[int] = lens_tensor.tolist()
+        B: int = len(lens)
+
+        token_ids = pooling_metadata.prompt_token_ids
+        offset = 0
+        pooled_list: list[torch.Tensor] = []
+
+        for i in range(B):
+            L = int(lens[i])
+            hs = hidden_states[offset : offset + L]
+
+            # Optionally drop [CLS]/[SEP] before pooling.
+            start_idx = 0
+            end_idx = L
+            if self.remove_cls_sep and token_ids is not None:
+                if (
+                    self.cls_token_id is not None
+                    and token_ids[i, 0].item() == self.cls_token_id
+                ):
+                    start_idx = 1
+                if (
+                    self.sep_token_id is not None
+                    and token_ids[i, L - 1].item() == self.sep_token_id
+                ):
+                    end_idx = max(start_idx, L - 1)
+
+            if end_idx <= start_idx:
+                # Degenerate case (only special tokens): emit an all-zero vector.
+                V = int(self.mlm_head.decoder.out_features)
+                pooled_list.append(hs.new_zeros((V,)))
+                offset += L
+                continue
+
+            logits_i = self.mlm_head(hs[start_idx:end_idx])
+            scores_i = torch.log1p(torch.relu(logits_i))
+
+            if self.pooling == "sum":
+                pooled_i = scores_i.sum(dim=0)
+            else:  # "max"
+                pooled_i = scores_i.max(dim=0).values
+
+            pooled_list.append(pooled_i.contiguous())
+            offset += L
+
+        return torch.stack(pooled_list, dim=0).contiguous()
+
+
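The docstring's transform is easy to check in isolation. A minimal numeric sketch of the SPLADE activation and both pooling variants (toy values, not part of the diff):

    import torch

    logits = torch.tensor([  # one sequence: L = 3 tokens, V = 5 vocab entries
        [ 1.0, -2.0,  0.5, 0.0, 3.0],
        [ 0.2,  4.0, -1.0, 0.0, 1.0],
        [-0.5,  0.0,  2.0, 0.1, 0.0],
    ])

    scores = torch.log1p(torch.relu(logits))  # negatives -> exactly 0 (sparsity)
    pooled_max = scores.max(dim=0).values     # [V], the default "max" pooling
    pooled_sum = scores.sum(dim=0)            # [V], the "sum" variant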
+@default_pooling_type("CLS")
+class BertSpladeSparseEmbeddingModel(BertEmbeddingModel):
+    """
+    BertEmbeddingModel with SPLADE sparse embedding:
+    - computes MLM logits with self.mlm_head, and
+    - pools them with SPLADESparsePooler for the "embed" task.
+    """
+
+    def __init__(
+        self, *, vllm_config: VllmConfig, prefix: str = "", splade_pooling: str = "max"
+    ):
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        cfg = vllm_config.model_config.hf_config
+
+        # MLM head
+        self.mlm_head = BertMLMHead(
+            hidden_size=cfg.hidden_size,
+            vocab_size=cfg.vocab_size,
+            layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
+        )
+
+        # Rebuild the pooler now that the real MLM head exists (the base
+        # __init__ above already built one via _build_pooler).
+        self._splade_pooling = splade_pooling
+        pooler_config = vllm_config.model_config.pooler_config
+        assert pooler_config is not None
+        self.pooler = self._build_pooler(pooler_config)
+
+    def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler:
+        cfg = self.model.config
+
+        # This can run from the base-class __init__ before self.mlm_head is
+        # assigned, so create the head lazily here as well.
+        if not hasattr(self, "mlm_head"):
+            self.mlm_head = BertMLMHead(
+                hidden_size=cfg.hidden_size,
+                vocab_size=cfg.vocab_size,
+                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
+            )
+
+        pooling_mode = getattr(self, "_splade_pooling", "max")
+
+        cls_id = getattr(cfg, "cls_token_id", None)
+        sep_id = getattr(cfg, "sep_token_id", None)
+
+        return DispatchPooler(
+            {
+                "encode": Pooler.for_encode(pooler_config),
+                "embed": SPLADESparsePooler(
+                    mlm_head=self.mlm_head,
+                    cls_token_id=cls_id,
+                    sep_token_id=sep_id,
+                    pooling=pooling_mode,  # "max" or "sum"
+                    remove_cls_sep=True,
+                ),
+            }
+        )
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        # Weights may arrive before __init__ finishes; create the MLM head
+        # lazily if needed.
+        if not hasattr(self, "mlm_head"):
+            cfg = self.model.config
+            self.mlm_head = BertMLMHead(
+                hidden_size=cfg.hidden_size,
+                vocab_size=cfg.vocab_size,
+                layer_norm_eps=getattr(cfg, "layer_norm_eps", 1e-12),
+            )
+
+        def _strip(name: str) -> str:
+            for p in ("model.", "bert."):
+                if name.startswith(p):
+                    name = name[len(p):]
+            return name
+
+        weights_list = list(weights)
+        model_side: list[tuple[str, torch.Tensor]] = []
+        mlm_side: list[tuple[str, torch.Tensor]] = []
+
+        # Route MLM-head weights ("cls.predictions.*") separately from the
+        # encoder weights.
+        for k, w in weights_list:
+            name = _strip(k)
+            if name.startswith("cls.predictions."):
+                mlm_side.append((name, w))
+            else:
+                model_side.append((name, w))
+
+        loaded: set[str] = set()
+        loaded_model = self.model.load_weights(model_side)
+        loaded.update({"model." + n for n in loaded_model})
+
+        if mlm_side:
+            # Remap HF BERT MLM-head names onto this module's attributes.
+            name_map = {
+                "cls.predictions.transform.dense.weight": "mlm_head.dense.weight",
+                "cls.predictions.transform.dense.bias": "mlm_head.dense.bias",
+                "cls.predictions.transform.LayerNorm.weight": "mlm_head.layer_norm.weight",
+                "cls.predictions.transform.LayerNorm.bias": "mlm_head.layer_norm.bias",
+                "cls.predictions.decoder.weight": "mlm_head.decoder.weight",
+                "cls.predictions.decoder.bias": "mlm_head.decoder.bias",
+            }
+            remapped = [(name_map[n], w) for n, w in mlm_side if n in name_map]
+            if remapped:
+                loaded_mlm = AutoWeightsLoader(self).load_weights(remapped)
+                loaded.update(loaded_mlm)
+
+        return loaded
+
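To see what the routing in load_weights does, here is a small runnable sketch; the checkpoint keys are illustrative examples of HF BERT naming, not an exhaustive list:

    demo_keys = [
        "bert.encoder.layer.0.attention.self.query.weight",
        "cls.predictions.transform.dense.weight",
        "cls.predictions.decoder.weight",
    ]
    for k in demo_keys:
        name = k.removeprefix("model.").removeprefix("bert.")
        side = "mlm_head" if name.startswith("cls.predictions.") else "encoder"
        print(f"{k} -> {side}")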
788+ 
 @default_pooling_type("CLS")
 class BertForSequenceClassification(nn.Module, SupportsCrossEncoding, SupportsQuant):
     """A model that uses Bert to provide embedding functionalities.
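Finally, a hedged end-to-end sketch. It assumes this architecture is registered with vLLM's model registry so that a SPLADE checkpoint resolves to BertSpladeSparseEmbeddingModel; naver/splade-cocondenser-ensembledistil is a public SPLADE model, but whether it maps to this class depends on the registry wiring:

    from vllm import LLM

    llm = LLM(model="naver/splade-cocondenser-ensembledistil", task="embed")
    outputs = llm.embed(["what causes ocean tides?"])

    vec = outputs[0].outputs.embedding  # length-V vector, mostly zeros
    active = [(i, v) for i, v in enumerate(vec) if v > 0.0]
    print(f"{len(active)} active vocabulary terms out of {len(vec)}")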