2525
2626import torch
2727import torch .nn as nn
28+ import torch .nn .functional as F
2829
2930from ...activations import ACT2FN
3031from ...cache_utils import Cache , HybridCache , StaticCache
4445from ...utils .deprecation import deprecate_kwarg
4546from ..gemma import GemmaPreTrainedModel
4647from ..siglip import SiglipVisionModel
47- from .configuration_gemma3 import Gemma3Config , Gemma3RotaryEmbeddingConfig , Gemma3TextConfig
48+ from .configuration_gemma3 import Gemma3Config , Gemma3RotaryEmbeddingConfig , Gemma3TextConfig , Gemma3VisionConfig
4849
4950
5051logger = logging .get_logger (__name__ )
@@ -71,6 +72,28 @@ def extra_repr(self):
7172 return f"{ tuple (self .weight .shape )} , eps={ self .eps } "
7273
7374
class Gemma3VisionAvgPool2D(nn.Module):
    """
    Spatially average-pools a flattened square grid of vision tokens.

    The input token sequence is interpreted as a `width x width` grid, pooled with a square,
    non-overlapping window, and flattened back into a sequence of `config.pooled_seq_len` tokens.

    Args:
        config (`Gemma3VisionConfig`):
            Vision configuration. Only `config.pooled_seq_len` (the number of tokens to keep after
            pooling) is read by this module.
    """

    def __init__(self, config: "Gemma3VisionConfig"):
        super().__init__()
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Applies 2D average pooling to a sequence of vision tokens.

        Args:
            x (`torch.Tensor` of shape `(batch_size, seq_len, channels)`):
                Token sequence where `seq_len` is a perfect square (`width * width`).

        Returns:
            `torch.Tensor` of shape `(batch_size, pooled_seq_len, channels)`.

        Raises:
            ValueError: If `seq_len` or `config.pooled_seq_len` is not a perfect square, or if the
                input grid width is not a positive multiple of the pooled grid width.
        """
        batch_size, seq_len, channels = x.shape

        # `round` (rather than `int`) guards against a float square root landing just below the
        # true integer root; the equality check below rejects anything that is not an exact square.
        width = round(seq_len**0.5)
        if width * width != seq_len:
            raise ValueError(f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image.")

        pooled_seq_len = self.config.pooled_seq_len
        if pooled_seq_len <= 0:
            raise ValueError(f"`pooled_seq_len` must be a positive perfect square, got {pooled_seq_len}.")
        final_width = round(pooled_seq_len**0.5)
        if final_width * final_width != pooled_seq_len:
            raise ValueError(f"`pooled_seq_len` must be a positive perfect square, got {pooled_seq_len}.")

        # A floor division here would silently produce the wrong number of output tokens (or a
        # zero-sized kernel), so require the grid to split evenly into pooling windows.
        kernel_size, remainder = divmod(width, final_width)
        if kernel_size == 0 or remainder != 0:
            raise ValueError(
                f"Cannot pool a {width}x{width} grid down to {final_width}x{final_width}: "
                "the input width must be a positive multiple of the pooled width."
            )

        # (batch, seq_len, channels) -> (batch, channels, width, width)
        x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
        x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
        # (batch, channels, final_width, final_width) -> (batch, pooled_seq_len, channels)
        x = x.flatten(2).transpose(1, 2)
        return x
95+
96+
7497class Gemma3MultimodalInputProjection (nn .Module ):
7598 def __init__ (self , vision_dim : int , text_dim : int ):
7699 super ().__init__ ()
@@ -1029,9 +1052,7 @@ def __init__(self, config: Gemma3Config):
10291052 )
10301053 self .mm_soft_emb_norm = Gemma3RMSNorm (vision_config .hidden_size , eps = vision_config .layer_norm_eps )
10311054
1032- patches_per_image = vision_config .image_size // vision_config .patch_size
1033- avg_pool_k = patches_per_image ** 2 // text_config .mm_tokens_per_image
1034- self .avg_pool = nn .AvgPool1d (kernel_size = avg_pool_k , stride = avg_pool_k )
1055+ self .avg_pool = Gemma3VisionAvgPool2D (config .vision_config )
10351056 self .vocab_size = text_config .vocab_size
10361057 self .pad_token_id = pad_token_id if (pad_token_id := text_config .pad_token_id ) is not None else - 1
10371058 self .post_init ()
@@ -1076,12 +1097,7 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
10761097 image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
10771098 """
10781099 vision_outputs = self .vision_model (pixel_values = pixel_values ).last_hidden_state
1079- b , n , l = vision_outputs .shape
1080- reshaped_vision_outputs = vision_outputs .permute (0 , 2 , 1 )
1081- reshaped_vision_outputs = reshaped_vision_outputs .contiguous ()
1082- reshaped_vision_outputs = reshaped_vision_outputs .view (b , l , n )
1083- pooled_vision_outputs = self .avg_pool (reshaped_vision_outputs )
1084- pooled_vision_outputs = pooled_vision_outputs .permute (0 , 2 , 1 )
1100+ pooled_vision_outputs = self .avg_pool (vision_outputs )
10851101 image_features = self .encode_vision (pooled_vision_outputs )
10861102 return image_features
10871103
0 commit comments