Commit

[CLIPSeg] Make interpolate_pos_encoding default to True (huggingface#34419)

* Remove interpolate_pos_encoding

* Make fixup

* Make interpolate_pos_encoding default to True

* Reuse existing interpolation

* Add integration test
NielsRogge authored and BernardZach committed Dec 6, 2024
1 parent 70f6321 commit cf4bd71
Showing 2 changed files with 11 additions and 15 deletions.
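Before the diff, a quick usage sketch of what the new default means in practice. This is not part of the commit: the CIDAS/clipseg-rd64-refined checkpoint, the COCO image URL, and the prompt strings are illustrative assumptions, while the expected (3, 352, 352) logits shape comes from the integration test updated below.

import torch
import requests
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Assumed checkpoint for illustration; for it, the processed input resolution
# can differ from the vision config's image_size.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
prompts = ["a cat", "a remote", "a blanket"]

inputs = processor(text=prompts, images=[image] * len(prompts), padding=True, return_tensors="pt")
with torch.no_grad():
    # With this commit, interpolate_pos_encoding defaults to True, so the
    # position encodings are interpolated to the processed resolution and no
    # explicit flag is needed (previously a size mismatch raised a ValueError
    # unless interpolate_pos_encoding=True was passed).
    outputs = model(**inputs)

print(outputs.logits.shape)  # torch.Size([3, 352, 352]), per the integration test below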
22 changes: 9 additions & 13 deletions src/transformers/models/clipseg/modeling_clipseg.py
@@ -205,7 +205,7 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

- def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
+ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=True) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
raise ValueError(
@@ -535,7 +535,7 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
- interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `True`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
@@ -574,7 +574,7 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
- interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
+ interpolate_pos_encoding (`bool`, *optional*, defaults to `True`):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
@@ -845,14 +845,13 @@ def __init__(self, config: CLIPSegVisionConfig):

@add_start_docstrings_to_model_forward(CLIPSEG_VISION_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPSegVisionConfig)
- # Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer.forward
def forward(
self,
- pixel_values: Optional[torch.FloatTensor] = None,
+ pixel_values: Optional[torch.FloatTensor],
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- interpolate_pos_encoding: Optional[bool] = False,
+ interpolate_pos_encoding: Optional[bool] = True,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -864,9 +863,6 @@ def forward(
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

- if pixel_values is None:
-     raise ValueError("You have to specify pixel_values")

hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)

@@ -912,7 +908,7 @@ def forward(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
- interpolate_pos_encoding: Optional[bool] = False,
+ interpolate_pos_encoding: Optional[bool] = True,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
@@ -1035,7 +1031,7 @@ def get_image_features(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
- interpolate_pos_encoding: bool = False,
+ interpolate_pos_encoding: bool = True,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
@@ -1091,7 +1087,7 @@ def forward(
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
- interpolate_pos_encoding: bool = False,
+ interpolate_pos_encoding: bool = True,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CLIPSegOutput]:
r"""
@@ -1397,7 +1393,7 @@ def forward(
labels: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
- interpolate_pos_encoding: bool = False,
+ interpolate_pos_encoding: bool = True,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CLIPSegOutput]:
r"""
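For context on the "Reuse existing interpolation" bullet in the commit message: the interpolate_pos_encoding method touched in the first hunk splits off the class-token position embedding, resizes the patch-position grid to the input resolution, and concatenates the two back together (the final torch.cat is visible above). The following is a rough standalone sketch of that idea only, not the modeling_clipseg implementation; the bicubic mode, the square source grid, and the explicit patch_size argument are assumptions.

import torch
import torch.nn.functional as F

def interpolate_pos_encoding(pos_embed: torch.Tensor, height: int, width: int, patch_size: int) -> torch.Tensor:
    # pos_embed: [1, 1 + num_patches, dim], with the class token first.
    class_pos_embed = pos_embed[:, :1]
    patch_pos_embed = pos_embed[:, 1:]
    dim = pos_embed.shape[-1]
    old_grid = int(patch_pos_embed.shape[1] ** 0.5)  # assume a square pretrained grid
    new_h, new_w = height // patch_size, width // patch_size
    # Reshape to a 2D grid, resize it, then flatten back into a sequence.
    patch_pos_embed = patch_pos_embed.reshape(1, old_grid, old_grid, dim).permute(0, 3, 1, 2)
    patch_pos_embed = F.interpolate(patch_pos_embed, size=(new_h, new_w), mode="bicubic", align_corners=False)
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)
    return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

# Example: a 14x14 pretrained grid (224/16) interpolated to 22x22 (352/16).
pos = torch.randn(1, 1 + 14 * 14, 768)
print(interpolate_pos_encoding(pos, 352, 352, 16).shape)  # torch.Size([1, 485, 768])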
4 changes: 2 additions & 2 deletions tests/models/clipseg/test_modeling_clipseg.py
@@ -796,15 +796,15 @@ def test_inference_image_segmentation(self):

# forward pass
with torch.no_grad():
- outputs = model(**inputs, interpolate_pos_encoding=True)
+ outputs = model(**inputs)

# verify the predicted masks
self.assertEqual(
outputs.logits.shape,
torch.Size((3, 352, 352)),
)
expected_masks_slice = torch.tensor(
- [[-7.4613, -7.4785, -7.3627], [-7.3268, -7.0898, -7.1333], [-6.9838, -6.7900, -6.8913]]
+ [[-7.4613, -7.4785, -7.3628], [-7.3268, -7.0899, -7.1333], [-6.9838, -6.7900, -6.8913]]
).to(torch_device)

self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_masks_slice, atol=1e-3))

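The same default also flows through CLIPSegModel.get_image_features and the vision forward pass (the @@ -1035 and @@ -912 hunks above), so image features can be computed at other resolutions without passing the flag. A minimal hedged sketch: the 416x416 resolution is arbitrary, the checkpoint is the same assumed one as above, and the printed projection size depends on the config.

import torch
from transformers import CLIPSegModel

model = CLIPSegModel.from_pretrained("CIDAS/clipseg-rd64-refined")

# A resolution other than the configured image_size; interpolate_pos_encoding
# now defaults to True, so the position grid is interpolated automatically.
pixel_values = torch.randn(1, 3, 416, 416)
with torch.no_grad():
    image_features = model.get_image_features(pixel_values=pixel_values)

print(image_features.shape)  # e.g. torch.Size([1, 512]) for the default projection_dim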