
🚨 Fix torch.jit.trace for interpolate_pos_encoding in all vision models #33226

Merged · 12 commits · Sep 5, 2024
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -1321,6 +1321,7 @@
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
_import_structure["modeling_utils"] = ["PreTrainedModel"]

_import_structure["modeling_vision_utils"] = []
# PyTorch models structure

_import_structure["models.albert"].extend(
84 changes: 84 additions & 0 deletions src/transformers/modeling_vision_utils.py
@@ -0,0 +1,84 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union

import torch.nn as nn

from .utils import is_torch_available, torch_int
from .utils.logging import get_logger


if is_torch_available():
import torch

logger = get_logger(__name__)


def interpolate_pos_encoding(
embeddings: torch.Tensor,
position_embeddings: torch.Tensor,
height: int,
width: int,
patch_size: Union[int, List[int]],
num_class_embeds: int = 1,
interpolate_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support models that do not have class embeddings (e.g., SigLIP or Hiera) and
to enable torch.jit tracing.

Adapted from:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194 and
https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - num_class_embeds
num_positions = position_embeddings.shape[1] - num_class_embeds

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return position_embeddings

class_pos_embed = position_embeddings[:, :num_class_embeds]
patch_pos_embed = position_embeddings[:, num_class_embeds:]

dim = embeddings.shape[-1]

ph, pw = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
new_height = height // ph
new_width = width // pw

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

target_dtype = patch_pos_embed.dtype
if interpolate_dtype is not None:
patch_pos_embed = patch_pos_embed.to(interpolate_dtype)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
if interpolate_dtype is not None:
patch_pos_embed = patch_pos_embed.to(dtype=target_dtype)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
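For reference, here is a minimal eager-mode sketch of how the shared helper is meant to be called. All shapes below are hypothetical (a ViT-style model pre-trained at 224x224 with 16x16 patches, evaluated on a 320x240 image); the import path is the new module added in this PR.

import torch

from transformers.modeling_vision_utils import interpolate_pos_encoding

dim = 32
# Pre-training grid: (224 // 16) ** 2 = 196 patch positions plus one class token.
position_embeddings = torch.randn(1, 1 + 196, dim)
# Patch embeddings for a 320x240 input: (320 // 16) * (240 // 16) = 300 patches plus the class token.
embeddings = torch.randn(1, 1 + 300, dim)

resized = interpolate_pos_encoding(
    embeddings,
    position_embeddings,
    height=320,
    width=240,
    patch_size=16,
)
print(resized.shape)  # torch.Size([1, 301, 32]): one interpolated position per patch plus the class token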
39 changes: 4 additions & 35 deletions src/transformers/models/beit/modeling_beit.py
@@ -34,13 +34,15 @@
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_vision_utils import interpolate_pos_encoding as _interpolate_pos_encoding
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_beit import BeitConfig
@@ -151,40 +153,7 @@ def __init__(self, config: BeitConfig) -> None:
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings

class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h = height // self.patch_size
w = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h, w = h + 0.1, w + 0.1

patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
return _interpolate_pos_encoding(embeddings, self.position_embeddings, height, width, self.patch_size)

def forward(
self,
@@ -566,7 +535,7 @@ def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=

old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
new_sub_table = nn.functional.interpolate(
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
)
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

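The int(...) to torch_int(...) swap above is what keeps the relative position bias resizable while tracing: a plain int() call collapses a traced size value into a constant baked into the exported graph. Below is a rough, simplified stand-in for what transformers.utils.torch_int does (the sketch and its name are ours; the real implementation may differ in detail):

import torch

def torch_int_sketch(x):
    # While torch.jit.trace is recording, shape arithmetic yields tensors; keeping the
    # value as an int64 tensor preserves the dynamic-shape dependency in the graph,
    # whereas int(x) would freeze it to the example input's size.
    if torch.jit.is_tracing() and isinstance(x, torch.Tensor):
        return x.to(torch.int64)
    return int(x)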
34 changes: 2 additions & 32 deletions src/transformers/models/blip/modeling_blip.py
@@ -14,7 +14,6 @@
# limitations under the License.
"""PyTorch BLIP model."""

import math
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
@@ -27,6 +26,7 @@
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import PreTrainedModel
from ...modeling_vision_utils import interpolate_pos_encoding as _interpolate_pos_encoding
from ...utils import (
ModelOutput,
add_start_docstrings,
@@ -233,37 +233,7 @@ def __init__(self, config: BlipVisionConfig):
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embedding.shape[1] - 1

if num_patches == num_positions and height == width:
return self.position_embedding

class_pos_embed = self.position_embedding[:, 0, :]
patch_pos_embed = self.position_embedding[:, 1:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
return _interpolate_pos_encoding(embeddings, self.position_embedding, height, width, self.patch_size)

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
33 changes: 2 additions & 31 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -31,6 +31,7 @@
BaseModelOutputWithPoolingAndCrossAttentions,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_vision_utils import interpolate_pos_encoding as _interpolate_pos_encoding
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
@@ -199,37 +200,7 @@ def __init__(self, config: Blip2VisionConfig):
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embedding.shape[1] - 1

if num_patches == num_positions and height == width:
return self.position_embedding

class_pos_embed = self.position_embedding[:, 0, :]
patch_pos_embed = self.position_embedding[:, 1:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
return _interpolate_pos_encoding(embeddings, self.position_embedding, height, width, self.patch_size)

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
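To see what the tracing fix buys end to end, here is a hedged, self-contained sketch: the ToyEmbeddings module, its shapes, and the expected output are illustrative only and not part of this PR. Because the shared helper skips its early-return shortcut whenever torch.jit.trace is recording, the interpolation path is captured in the traced graph, which can then be re-run at a different resolution.

import torch
import torch.nn as nn

from transformers.modeling_vision_utils import interpolate_pos_encoding


class ToyEmbeddings(nn.Module):
    # Minimal ViT-style patch embedding, used only to exercise the shared helper.
    def __init__(self, dim=32, patch_size=16, grid=14):
        super().__init__()
        self.patch_size = patch_size
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.position_embeddings = nn.Parameter(torch.randn(1, 1 + grid * grid, dim))
        self.proj = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, pixel_values):
        batch_size, _, height, width = pixel_values.shape
        patches = self.proj(pixel_values).flatten(2).transpose(1, 2)
        embeddings = torch.cat([self.cls_token.expand(batch_size, -1, -1), patches], dim=1)
        pos = interpolate_pos_encoding(
            embeddings, self.position_embeddings, height, width, self.patch_size
        )
        return embeddings + pos


module = ToyEmbeddings().eval()
traced = torch.jit.trace(module, torch.randn(1, 3, 224, 224))

# Re-running the traced graph at a new resolution should now work instead of failing with
# a size mismatch: 256 // 16 = 16 patches per side, so we expect torch.Size([1, 1 + 16 * 16, 32]).
print(traced(torch.randn(1, 3, 256, 256)).shape)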
39 changes: 4 additions & 35 deletions src/transformers/models/data2vec/modeling_data2vec_vision.py
@@ -32,13 +32,15 @@
SemanticSegmenterOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_vision_utils import interpolate_pos_encoding as _interpolate_pos_encoding
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_data2vec_vision import Data2VecVisionConfig

@@ -150,40 +152,7 @@ def __init__(self, config: Data2VecVisionConfig) -> None:
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return self.position_embeddings

class_pos_embed = self.position_embeddings[:, 0]
patch_pos_embed = self.position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h = height // self.patch_size
w = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h, w = h + 0.1, w + 0.1

patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
return _interpolate_pos_encoding(embeddings, self.position_embeddings, height, width, self.patch_size)

def forward(
self,
@@ -575,7 +544,7 @@ def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=

old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
new_sub_table = nn.functional.interpolate(
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
)
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

36 changes: 3 additions & 33 deletions src/transformers/models/deit/modeling_deit.py
@@ -32,6 +32,7 @@
MaskedImageModelingOutput,
)
from ...modeling_utils import PreTrainedModel
from ...modeling_vision_utils import interpolate_pos_encoding as _interpolate_pos_encoding
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import (
ModelOutput,
@@ -76,40 +77,9 @@ def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:
self.patch_size = config.patch_size

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""

# return self.position_embeddings
num_patches = embeddings.shape[1] - 2
num_positions = self.position_embeddings.shape[1] - 2

if num_patches == num_positions and height == width:
return self.position_embeddings

class_pos_embed = self.position_embeddings[:, 0, :]
dist_pos_embed = self.position_embeddings[:, 1, :]
patch_pos_embed = self.position_embeddings[:, 2:, :]
dim = embeddings.shape[-1]
h0 = height // self.patch_size
w0 = width // self.patch_size
# # we add a small number to avoid floating point error in the interpolation
# # see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
return _interpolate_pos_encoding(
embeddings, self.position_embeddings, height, width, self.patch_size, num_class_embeds=2
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(
self,
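DeiT keeps both a class token and a distillation token in front of the patch sequence, hence num_class_embeds=2 in the call above. A small hypothetical call (shapes are ours) showing that both leading positions pass through untouched while the patch grid is resized:

import torch

from transformers.modeling_vision_utils import interpolate_pos_encoding

dim = 32
# 2 special tokens (class + distillation) followed by a 14x14 patch grid.
position_embeddings = torch.randn(1, 2 + 14 * 14, dim)
# Embeddings for a 320x320 input: the same 2 special tokens plus (320 // 16) ** 2 = 400 patches.
embeddings = torch.randn(1, 2 + 400, dim)

resized = interpolate_pos_encoding(
    embeddings, position_embeddings, height=320, width=320, patch_size=16, num_class_embeds=2
)
print(resized.shape)  # torch.Size([1, 402, 32])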