@@ -32,6 +32,7 @@ class PoolingType(IntEnum):
3232 CLS = 2
3333 STEP = 3
3434 MEAN = 4
35+ VISION = 5
3536
3637
3738@dataclass (frozen = True )
@@ -91,6 +92,8 @@ def from_config_with_defaults(
9192
9293 if pooling_type == PoolingType .STEP :
9394 return StepPooler .from_config (resolved_config )
95+ if pooling_type == PoolingType .VISION :
96+ return VisionPooler .from_config (resolved_config )
9497
9598 return SimplePooler .from_config (resolved_config )
9699
@@ -622,6 +625,86 @@ def forward(
# Type alias: a callable that takes one torch.Tensor and returns a torch.Tensor.
622625ClassifierFn = Callable [[torch .Tensor ], torch .Tensor ]
623626
624627
class VisionPooler(Pooler):
    """Pooler that mean-pools hidden states over each prompt's vision span.

    For every prompt the span runs from the last occurrence of
    ``hf_config.vision_start_token_id`` to the last occurrence of
    ``hf_config.vision_end_token_id`` (both inclusive).
    """

    @classmethod
    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
        """Build a pooler from ``model_config``."""
        return cls(model_config)

    def __init__(self, config: ModelConfig):
        super().__init__()
        # Kept so forward() can read the vision marker ids from hf_config.
        self.config = config

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Return pooling parameters for ``"embed"``; ``None`` otherwise."""
        if task == "embed":
            return PoolingParams(pooling_type="vision",
                                 logits_processing_needs_token_ids=True)
        return None

    @staticmethod
    def _last_position(token_ids: torch.Tensor, token_id: int,
                       label: str, prompt_index: int) -> int:
        """Return the index of the last ``token_id`` in ``token_ids``.

        Raises:
            ValueError: if ``token_id`` does not occur, instead of the
                opaque IndexError that ``nonzero()[-1]`` would produce.
        """
        matches = (token_ids == token_id).nonzero()
        if matches.numel() == 0:
            raise ValueError(
                f"Prompt {prompt_index} contains no {label} ({token_id}); "
                "cannot locate the vision span to pool.")
        return int(matches[-1].item())

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Mean-pool ``hidden_states`` over the vision span of every prompt.

        Args:
            hidden_states: hidden states of all prompts laid out back to
                back along dim 0 (one row per token) -- assumed from the
                offset arithmetic of the original kernel; TODO confirm.
            pooling_metadata: must be a ``V1PoolingMetadata``; its
                ``prompt_lens`` and ``prompt_token_ids`` are read.

        Returns:
            ``build_output`` applied to the stacked per-prompt means.

        Raises:
            ValueError: if a prompt lacks a vision start/end token or the
                end token precedes the start token.
        """
        assert isinstance(pooling_metadata, V1PoolingMetadata)

        hf_config = self.config.hf_config
        start_id = hf_config.vision_start_token_id
        end_id = hf_config.vision_end_token_id

        pooled_outputs = []
        # Running row offset of the current prompt inside hidden_states;
        # replaces rebuilding a cumsum tensor on every iteration.
        seq_start = 0
        for i, seq_len in enumerate(pooling_metadata.prompt_lens.tolist()):
            # Restrict the search to this prompt's own tokens so a match can
            # never point past the prompt (rows may be padded -- TODO confirm).
            token_ids = pooling_metadata.prompt_token_ids[i][:seq_len]
            start_pos = self._last_position(token_ids, start_id,
                                            "vision_start_token_id", i)
            end_pos = self._last_position(token_ids, end_id,
                                          "vision_end_token_id", i)
            if end_pos < start_pos:
                raise ValueError(
                    f"Prompt {i}: vision end token at {end_pos} precedes "
                    f"vision start token at {start_pos}.")

            # Plain torch reduction instead of launching the Triton kernel:
            # the launch omitted the required BLOCK_SIZE constexpr and the
            # kernel does not exist when HAS_TRITON is false.
            span = hidden_states[seq_start + start_pos:seq_start + end_pos + 1]
            # Accumulate in float32, then return to the input dtype.
            pooled = span.float().mean(dim=0).to(hidden_states.dtype)
            pooled_outputs.append(pooled)

            seq_start += seq_len

        return build_output(torch.stack(pooled_outputs))
676+
677+
if HAS_TRITON:

    @triton.jit
    def mean_pool_with_position_kernel(
        hidden_states_ptr,
        output_ptr,
        seq_start,
        seq_len,
        hidden_size,
        pool_start,
        pool_end,
        BLOCK_SIZE: tl.constexpr = 1,
    ):
        """Triton kernel to perform mean pooling over a specified token range.

        Each program handles one hidden dimension (``pid``) and averages
        rows ``seq_start + pool_start`` .. ``seq_start + pool_end - 1`` of a
        buffer addressed as ``row * hidden_size + pid``.

        ``seq_len`` and ``BLOCK_SIZE`` are not used by the body; they are
        kept for signature compatibility. ``BLOCK_SIZE`` now has a default
        because the visible call site launches the kernel without it.
        """
        pid = tl.program_id(0)

        if pid >= hidden_size:
            return

        accumulator = 0.0
        for i in range(pool_start, pool_end):
            hidden_val = tl.load(hidden_states_ptr +
                                 (seq_start + i) * hidden_size + pid)
            # Accumulate in float32 so the loop-carried value keeps one type
            # regardless of the input dtype.
            accumulator += hidden_val.to(tl.float32)

        # Store mean pooled result; an empty range yields 0.0 instead of
        # dividing by zero (same guard the previous kernel used).
        count = pool_end - pool_start
        result = accumulator / count if count > 0 else 0.0
        tl.store(output_ptr + pid, result)
706+
707+
625708class ClassifierPooler (nn .Module ):
626709 """A pooling layer for classification tasks.
627710
@@ -709,39 +792,81 @@ def forward(
709792 return build_output (scores )
710793
711794
# NOTE(review): VisionPooler is defined twice in this file; the later
# definition is the one bound at import time -- confirm and drop one copy.
class VisionPooler(Pooler):
    """Pooler that mean-pools hidden states over each prompt's vision span.

    For every prompt the span runs from the last occurrence of
    ``hf_config.vision_start_token_id`` to the last occurrence of
    ``hf_config.vision_end_token_id`` (both inclusive).
    """

    @classmethod
    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
        """Build a pooler from ``model_config``."""
        return cls(model_config)

    def __init__(self, config: ModelConfig):
        super().__init__()
        # Kept so forward() can read the vision marker ids from hf_config.
        self.config = config

    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
        """Return pooling parameters for ``"embed"``; ``None`` otherwise."""
        if task == "embed":
            return PoolingParams(pooling_type="vision",
                                 logits_processing_needs_token_ids=True)
        return None

    @staticmethod
    def _last_position(token_ids: torch.Tensor, token_id: int,
                       label: str, prompt_index: int) -> int:
        """Return the index of the last ``token_id`` in ``token_ids``.

        Raises:
            ValueError: if ``token_id`` does not occur, instead of the
                opaque IndexError that ``nonzero()[-1]`` would produce.
        """
        matches = (token_ids == token_id).nonzero()
        if matches.numel() == 0:
            raise ValueError(
                f"Prompt {prompt_index} contains no {label} ({token_id}); "
                "cannot locate the vision span to pool.")
        return int(matches[-1].item())

    def forward(
        self,
        hidden_states: torch.Tensor,
        pooling_metadata: PoolingMetadata,
    ) -> PoolerOutput:
        """Mean-pool ``hidden_states`` over the vision span of every prompt.

        Args:
            hidden_states: hidden states of all prompts laid out back to
                back along dim 0 (one row per token) -- assumed from the
                offset arithmetic of the original kernel; TODO confirm.
            pooling_metadata: must be a ``V1PoolingMetadata``; its
                ``prompt_lens`` and ``prompt_token_ids`` are read.

        Returns:
            ``build_output`` applied to the stacked per-prompt means.

        Raises:
            ValueError: if a prompt lacks a vision start/end token or the
                end token precedes the start token.
        """
        assert isinstance(pooling_metadata, V1PoolingMetadata)

        hf_config = self.config.hf_config
        start_id = hf_config.vision_start_token_id
        end_id = hf_config.vision_end_token_id

        pooled_outputs = []
        # Running row offset of the current prompt inside hidden_states;
        # replaces rebuilding a cumsum tensor on every iteration.
        seq_start = 0
        for i, seq_len in enumerate(pooling_metadata.prompt_lens.tolist()):
            # Restrict the search to this prompt's own tokens so a match can
            # never point past the prompt (rows may be padded -- TODO confirm).
            token_ids = pooling_metadata.prompt_token_ids[i][:seq_len]
            start_pos = self._last_position(token_ids, start_id,
                                            "vision_start_token_id", i)
            end_pos = self._last_position(token_ids, end_id,
                                          "vision_end_token_id", i)
            if end_pos < start_pos:
                raise ValueError(
                    f"Prompt {i}: vision end token at {end_pos} precedes "
                    f"vision start token at {start_pos}.")

            # Plain torch reduction instead of launching the Triton kernel:
            # the launch omitted the required BLOCK_SIZE constexpr and the
            # kernel does not exist when HAS_TRITON is false.
            span = hidden_states[seq_start + start_pos:seq_start + end_pos + 1]
            # Accumulate in float32, then return to the input dtype.
            pooled = span.float().mean(dim=0).to(hidden_states.dtype)
            pooled_outputs.append(pooled)

            seq_start += seq_len

        return build_output(torch.stack(pooled_outputs))
843+
844+
if HAS_TRITON:

    @triton.jit
    def mean_pool_with_position_kernel(
        hidden_states_ptr,
        output_ptr,
        seq_start,
        seq_len,
        hidden_size,
        pool_start,
        pool_end,
        BLOCK_SIZE: tl.constexpr = 1,
    ):
        """Triton kernel to perform mean pooling over a specified token range.

        Each program handles one hidden dimension (``pid``) and averages
        rows ``seq_start + pool_start`` .. ``seq_start + pool_end - 1`` of a
        buffer addressed as ``row * hidden_size + pid``.

        ``seq_len`` and ``BLOCK_SIZE`` are not used by the body; they are
        kept for signature compatibility. ``BLOCK_SIZE`` now has a default
        because the visible call site launches the kernel without it.
        """
        pid = tl.program_id(0)

        if pid >= hidden_size:
            return

        accumulator = 0.0
        for i in range(pool_start, pool_end):
            hidden_val = tl.load(hidden_states_ptr +
                                 (seq_start + i) * hidden_size + pid)
            # Accumulate in float32 so the loop-carried value keeps one type
            # regardless of the input dtype.
            accumulator += hidden_val.to(tl.float32)

        # Store mean pooled result; an empty range yields 0.0 instead of
        # dividing by zero (same guard the previous kernel used).
        count = pool_end - pool_start
        result = accumulator / count if count > 0 else 0.0
        tl.store(output_ptr + pid, result)
0 commit comments