2525
2626import torch
2727import torch .nn as nn
28- import torch .nn .functional as F
2928
3029from ...activations import ACT2FN
3130from ...cache_utils import Cache , HybridCache , StaticCache
4544from ...utils .deprecation import deprecate_kwarg
4645from ..gemma import GemmaPreTrainedModel
4746from ..siglip import SiglipVisionModel
48- from .configuration_gemma3 import Gemma3Config , Gemma3RotaryEmbeddingConfig , Gemma3TextConfig , Gemma3VisionConfig
47+ from .configuration_gemma3 import Gemma3Config , Gemma3RotaryEmbeddingConfig , Gemma3TextConfig
4948
5049
5150logger = logging .get_logger (__name__ )
@@ -72,28 +71,6 @@ def extra_repr(self):
7271 return f"{ tuple (self .weight .shape )} , eps={ self .eps } "
7372
7473
75- class Gemma3VisionAvgPool2D (nn .Module ):
76- def __init__ (self , config : Gemma3VisionConfig ):
77- super ().__init__ ()
78- self .config = config
79-
80- def forward (self , x ):
81- """
82- Applies average pooling on (B, width, width)
83- to make it (B, final_width, final_width).
84- """
85- batch_size , seq_len , channels = x .shape
86- width = int (seq_len ** 0.5 )
87- if width * width != seq_len :
88- raise ValueError (f"Sequence length { seq_len } is not a perfect square. Cannot reshape to a square image." )
89- final_width = int (self .config .pooled_seq_len ** 0.5 )
90- kernel_size = width // final_width
91- x = x .transpose (1 , 2 ).reshape (batch_size , channels , width , width )
92- x = F .avg_pool2d (x , kernel_size = kernel_size , stride = kernel_size )
93- x = x .flatten (2 ).transpose (1 , 2 )
94- return x
95-
96-
9774class Gemma3MultimodalInputProjection (nn .Module ):
9875 def __init__ (self , vision_dim : int , text_dim : int ):
9976 super ().__init__ ()
@@ -1035,6 +1012,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
10351012
10361013 def __init__ (self , config : Gemma3Config ):
10371014 super ().__init__ (config )
1015+
10381016 self .config = config
10391017 text_config = self .config .text_config
10401018 vision_config = self .config .vision_config
@@ -1050,7 +1028,10 @@ def __init__(self, config: Gemma3Config):
10501028 vision_dim = vision_config .hidden_size , text_dim = text_config .hidden_size
10511029 )
10521030 self .mm_soft_emb_norm = Gemma3RMSNorm (vision_config .hidden_size , eps = vision_config .layer_norm_eps )
1053- self .avg_pool = Gemma3VisionAvgPool2D (config .vision_config )
1031+
1032+ patches_per_image = vision_config .image_size // vision_config .patch_size
1033+ avg_pool_k = patches_per_image ** 2 // text_config .mm_tokens_per_image
1034+ self .avg_pool = nn .AvgPool1d (kernel_size = avg_pool_k , stride = avg_pool_k )
10541035 self .vocab_size = text_config .vocab_size
10551036 self .pad_token_id = pad_token_id if (pad_token_id := text_config .pad_token_id ) is not None else - 1
10561037 self .post_init ()
@@ -1095,7 +1076,12 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
10951076 image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
10961077 """
10971078 vision_outputs = self .vision_model (pixel_values = pixel_values ).last_hidden_state
1098- pooled_vision_outputs = self .avg_pool (vision_outputs )
1079+ b , n , l = vision_outputs .shape
1080+ reshaped_vision_outputs = vision_outputs .permute (0 , 2 , 1 )
1081+ reshaped_vision_outputs = reshaped_vision_outputs .contiguous ()
1082+ reshaped_vision_outputs = reshaped_vision_outputs .view (b , l , n )
1083+ pooled_vision_outputs = self .avg_pool (reshaped_vision_outputs )
1084+ pooled_vision_outputs = pooled_vision_outputs .permute (0 , 2 , 1 )
10991085 image_features = self .encode_vision (pooled_vision_outputs )
11001086 return image_features
11011087
0 commit comments