Skip to content

Commit 00af9a7

Browse files
Revert "Adding 2D pooling for image embeddings"
This reverts commit 65350cf.
1 parent 65350cf commit 00af9a7

File tree

3 files changed

+23
-55
lines changed

3 files changed

+23
-55
lines changed

src/transformers/models/gemma3/configuration_gemma3.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,6 @@ def __init__(
256256
layer_norm_eps: float = 0.000001,
257257
vision_use_head: bool = False,
258258
torch_dtype: str = "bfloat16",
259-
pooled_seq_len: int = 256,
260259
**kwargs,
261260
):
262261
super().__init__(
@@ -274,7 +273,6 @@ def __init__(
274273
**kwargs,
275274
)
276275

277-
self.pooled_seq_len = pooled_seq_len
278276
self.vision_use_head = vision_use_head
279277

280278

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import torch
2727
import torch.nn as nn
28-
import torch.nn.functional as F
2928

3029
from ...activations import ACT2FN
3130
from ...cache_utils import Cache, HybridCache, StaticCache
@@ -45,7 +44,7 @@
4544
from ...utils.deprecation import deprecate_kwarg
4645
from ..gemma import GemmaPreTrainedModel
4746
from ..siglip import SiglipVisionModel
48-
from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig, Gemma3VisionConfig
47+
from .configuration_gemma3 import Gemma3Config, Gemma3RotaryEmbeddingConfig, Gemma3TextConfig
4948

5049

5150
logger = logging.get_logger(__name__)
@@ -72,28 +71,6 @@ def extra_repr(self):
7271
return f"{tuple(self.weight.shape)}, eps={self.eps}"
7372

7473

75-
class Gemma3VisionAvgPool2D(nn.Module):
76-
def __init__(self, config: Gemma3VisionConfig):
77-
super().__init__()
78-
self.config = config
79-
80-
def forward(self, x):
81-
"""
82-
Applies average pooling on (B, width, width)
83-
to make it (B, final_width, final_width).
84-
"""
85-
batch_size, seq_len, channels = x.shape
86-
width = int(seq_len**0.5)
87-
if width * width != seq_len:
88-
raise ValueError(f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image.")
89-
final_width = int(self.config.pooled_seq_len**0.5)
90-
kernel_size = width // final_width
91-
x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
92-
x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
93-
x = x.flatten(2).transpose(1, 2)
94-
return x
95-
96-
9774
class Gemma3MultimodalInputProjection(nn.Module):
9875
def __init__(self, vision_dim: int, text_dim: int):
9976
super().__init__()
@@ -1035,6 +1012,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
10351012

10361013
def __init__(self, config: Gemma3Config):
10371014
super().__init__(config)
1015+
10381016
self.config = config
10391017
text_config = self.config.text_config
10401018
vision_config = self.config.vision_config
@@ -1050,7 +1028,10 @@ def __init__(self, config: Gemma3Config):
10501028
vision_dim=vision_config.hidden_size, text_dim=text_config.hidden_size
10511029
)
10521030
self.mm_soft_emb_norm = Gemma3RMSNorm(vision_config.hidden_size, eps=vision_config.layer_norm_eps)
1053-
self.avg_pool = Gemma3VisionAvgPool2D(config.vision_config)
1031+
1032+
patches_per_image = vision_config.image_size // vision_config.patch_size
1033+
avg_pool_k = patches_per_image**2 // text_config.mm_tokens_per_image
1034+
self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k)
10541035
self.vocab_size = text_config.vocab_size
10551036
self.pad_token_id = pad_token_id if (pad_token_id := text_config.pad_token_id) is not None else -1
10561037
self.post_init()
@@ -1095,7 +1076,12 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
10951076
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
10961077
"""
10971078
vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state
1098-
pooled_vision_outputs = self.avg_pool(vision_outputs)
1079+
b, n, l = vision_outputs.shape
1080+
reshaped_vision_outputs = vision_outputs.permute(0, 2, 1)
1081+
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
1082+
reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n)
1083+
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
1084+
pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1)
10991085
image_features = self.encode_vision(pooled_vision_outputs)
11001086
return image_features
11011087

src/transformers/models/gemma3/modular_gemma3.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import torch
2727
import torch.nn as nn
2828
import torch.utils.checkpoint
29-
import torch.nn.functional as F
3029

3130
from ...activations import ACT2FN
3231
from ...cache_utils import Cache, HybridCache, StaticCache
@@ -333,7 +332,6 @@ def __init__(
333332
layer_norm_eps: float = 0.000001,
334333
vision_use_head: bool = False,
335334
torch_dtype: str = "bfloat16",
336-
pooled_seq_len: int = 256,
337335
**kwargs,
338336
):
339337
super().__init__(
@@ -351,7 +349,6 @@ def __init__(
351349
**kwargs,
352350
)
353351

354-
self.pooled_seq_len = pooled_seq_len
355352
self.vision_use_head = vision_use_head
356353

357354

@@ -713,28 +710,6 @@ def model_input_names(self):
713710
class Gemma3RMSNorm(GemmaRMSNorm):
714711
pass
715712

716-
class Gemma3VisionAvgPool2D(nn.Module):
717-
def __init__(self, config: Gemma3VisionConfig):
718-
super().__init__()
719-
self.config = config
720-
721-
def forward(self, x):
722-
"""
723-
Applies average pooling on (B, width, width)
724-
to make it (B, final_width, final_width).
725-
"""
726-
batch_size, seq_len, channels = x.shape
727-
width = int(seq_len**0.5)
728-
if width * width != seq_len:
729-
raise ValueError(
730-
f"Sequence length {seq_len} is not a perfect square. Cannot reshape to a square image."
731-
)
732-
final_width = int(self.config.pooled_seq_len**0.5)
733-
kernel_size = width//final_width
734-
x = x.transpose(1, 2).reshape(batch_size, channels, width, width)
735-
x = F.avg_pool2d(x, kernel_size=kernel_size, stride=kernel_size)
736-
x = x.flatten(2).transpose(1, 2)
737-
return x
738713

739714
class Gemma3MultimodalInputProjection(nn.Module):
740715

@@ -1715,6 +1690,7 @@ class Gemma3ForConditionalGeneration(PreTrainedModel, GenerationMixin):
17151690

17161691
def __init__(self, config: Gemma3Config):
17171692
super().__init__(config)
1693+
17181694
self.config = config
17191695
text_config = self.config.text_config
17201696
vision_config = self.config.vision_config
@@ -1732,7 +1708,10 @@ def __init__(self, config: Gemma3Config):
17321708
self.mm_soft_emb_norm = Gemma3RMSNorm(
17331709
vision_config.hidden_size, eps=vision_config.layer_norm_eps
17341710
)
1735-
self.avg_pool = Gemma3VisionAvgPool2D(config.vision_config)
1711+
1712+
patches_per_image = vision_config.image_size // vision_config.patch_size
1713+
avg_pool_k = patches_per_image ** 2 // text_config.mm_tokens_per_image
1714+
self.avg_pool = nn.AvgPool1d(kernel_size=avg_pool_k, stride=avg_pool_k)
17361715
self.vocab_size = text_config.vocab_size
17371716
self.pad_token_id = (
17381717
pad_token_id
@@ -1781,7 +1760,12 @@ def get_image_features(self, pixel_values: torch.Tensor) -> torch.Tensor:
17811760
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
17821761
"""
17831762
vision_outputs = self.vision_model(pixel_values=pixel_values).last_hidden_state
1784-
pooled_vision_outputs = self.avg_pool(vision_outputs)
1763+
b, n, l = vision_outputs.shape
1764+
reshaped_vision_outputs = vision_outputs.permute(0, 2, 1)
1765+
reshaped_vision_outputs = reshaped_vision_outputs.contiguous()
1766+
reshaped_vision_outputs = reshaped_vision_outputs.view(b, l, n)
1767+
pooled_vision_outputs = self.avg_pool(reshaped_vision_outputs)
1768+
pooled_vision_outputs = pooled_vision_outputs.permute(0, 2, 1)
17851769
image_features = self.encode_vision(pooled_vision_outputs)
17861770
return image_features
17871771

0 commit comments

Comments
 (0)