Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🚨 Fix torch.jit.trace for interpolate_pos_encoding in all vision models #33226

Merged
merged 12 commits into from
Sep 5, 2024
40 changes: 23 additions & 17 deletions src/transformers/models/beit/modeling_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from ...utils.backbone_utils import BackboneMixin
from .configuration_beit import BeitConfig
Expand Down Expand Up @@ -150,41 +151,46 @@ def __init__(self, config: BeitConfig) -> None:
self.position_embeddings = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)

# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

class_pos_embed = self.position_embeddings[:, 0]
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

dim = embeddings.shape[-1]
h = height // self.patch_size
w = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h, w = h + 0.1, w + 0.1

patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def forward(
self,
Expand Down Expand Up @@ -566,7 +572,7 @@ def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=

old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
new_sub_table = nn.functional.interpolate(
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
)
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

Expand Down
44 changes: 26 additions & 18 deletions src/transformers/models/blip/modeling_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# limitations under the License.
"""PyTorch BLIP model."""

import math
import warnings
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
Expand All @@ -33,6 +32,7 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_blip import BlipConfig, BlipTextConfig, BlipVisionConfig
from .modeling_blip_text import BlipTextLMHeadModel, BlipTextModel
Expand Down Expand Up @@ -232,38 +232,46 @@ def __init__(self, config: BlipVisionConfig):

self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embedding.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

if num_patches == num_positions and height == width:
return self.position_embedding
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

class_pos_embed = self.position_embedding[:, 0, :]
patch_pos_embed = self.position_embedding[:, 1:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
Expand Down
43 changes: 26 additions & 17 deletions src/transformers/models/blip_2/modeling_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from ..auto import AutoModelForCausalLM, AutoModelForSeq2SeqLM
from .configuration_blip_2 import Blip2Config, Blip2QFormerConfig, Blip2VisionConfig
Expand Down Expand Up @@ -198,38 +199,46 @@ def __init__(self, config: Blip2VisionConfig):

self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embedding.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

if num_patches == num_positions and height == width:
return self.position_embedding
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

class_pos_embed = self.position_embedding[:, 0, :]
patch_pos_embed = self.position_embedding[:, 1:, :]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
Expand Down
40 changes: 23 additions & 17 deletions src/transformers/models/data2vec/modeling_data2vec_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_data2vec_vision import Data2VecVisionConfig

Expand Down Expand Up @@ -149,41 +150,46 @@ def __init__(self, config: Data2VecVisionConfig) -> None:
self.position_embeddings = None
self.dropout = nn.Dropout(config.hidden_dropout_prob)

# Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows the model to interpolate the pre-trained position encodings so that it can be used on
higher resolution images.
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

num_patches = embeddings.shape[1] - 1
num_positions = self.position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:

# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

class_pos_embed = self.position_embeddings[:, 0]
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]

dim = embeddings.shape[-1]
h = height // self.patch_size
w = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h, w = h + 0.1, w + 0.1

patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h / math.sqrt(num_positions), w / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)
if int(h) != patch_pos_embed.shape[-2] or int(w) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

return torch.cat((class_pos_embed, patch_pos_embed), dim=1)

def forward(
self,
Expand Down Expand Up @@ -575,7 +581,7 @@ def forward(self, window_size, interpolate_pos_encoding: bool = False, dim_size=

old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
new_sub_table = nn.functional.interpolate(
old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear"
old_sub_table, size=(torch_int(new_height), torch_int(new_width)), mode="bilinear"
)
new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

Expand Down
39 changes: 22 additions & 17 deletions src/transformers/models/deit/modeling_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
add_start_docstrings_to_model_forward,
logging,
replace_return_docstrings,
torch_int,
)
from .configuration_deit import DeiTConfig

Expand Down Expand Up @@ -77,39 +78,43 @@ def __init__(self, config: DeiTConfig, use_mask_token: bool = False) -> None:

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
images. This method is also adapted to support torch.jit tracing and 2 class embeddings.

Adapted from:
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194, and
- https://github.com/facebookresearch/dinov2/blob/e1277af2ba9496fbadf7aec6eba56e8d882d1e35/dinov2/models/vision_transformer.py#L179-L211
"""

# return self.position_embeddings
num_patches = embeddings.shape[1] - 2
num_positions = self.position_embeddings.shape[1] - 2

if num_patches == num_positions and height == width:
# always interpolate when tracing to ensure the exported model works for dynamic input shapes
if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
return self.position_embeddings

class_pos_embed = self.position_embeddings[:, 0, :]
dist_pos_embed = self.position_embeddings[:, 1, :]
patch_pos_embed = self.position_embeddings[:, 2:, :]
class_and_dist_pos_embed = self.position_embeddings[:, :2]
patch_pos_embed = self.position_embeddings[:, 2:]

dim = embeddings.shape[-1]
h0 = height // self.patch_size
w0 = width // self.patch_size
# # we add a small number to avoid floating point error in the interpolation
# # see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)

new_height = height // self.patch_size
new_width = width // self.patch_size

sqrt_num_positions = torch_int(num_positions**0.5)
patch_pos_embed = patch_pos_embed.reshape(1, sqrt_num_positions, sqrt_num_positions, dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)

patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
size=(new_height, new_width),
mode="bicubic",
align_corners=False,
)

patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

return torch.cat((class_pos_embed.unsqueeze(0), dist_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
return torch.cat((class_and_dist_pos_embed, patch_pos_embed), dim=1)

def forward(
self,
Expand Down
Loading
Loading