
Potential error in num_patches calculation in src/transformers/models/vit_mae/modeling_vit_mae.py #32410

Closed
ziyiss opened this issue Aug 4, 2024 · 7 comments · Fixed by #33330



ziyiss commented Aug 4, 2024

System Info

I've been examining the interpolate_pos_encoding method in the ViTMAE implementation, and I've found what appears to be an error in the calculation of num_patches. This error seems to have unintended consequences on the method's behavior, especially when the input image has the same size as the images used for pretraining. Here are my findings:

  1. Current implementation:
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
    num_patches = embeddings.shape[1] - 1
    num_positions = self.position_embeddings.shape[1] - 1
    
    if num_patches == num_positions and height == width:
        return self.position_embeddings
    # ... rest of the method
  2. Issue with the num_patches calculation:
  • The embeddings passed to this method are created in ViTMAEPatchEmbeddings.forward():
    x = self.projection(pixel_values).flatten(2).transpose(1, 2)
  • For a 224x224 input image (same as pretrained size), x.shape is torch.Size([1, 196, 768]). This 196 represents the number of patches (14*14) without including a CLS token.
  • For a higher-resolution input image of 480x640:
    embeddings.shape = [1, 1200, 768].
    Currently num_patches = 1200 - 1 = 1199, which is wrong, since the number of patches should be 30*40 = 1200.
  3. Consequences of the error:
  • I suppose the -1 in the current num_patches = embeddings.shape[1] - 1 is meant to remove the CLS token?
  • With num_patches = embeddings.shape[1] - 1, we get num_patches = 196 - 1 = 195.
  • However, num_positions = self.position_embeddings.shape[1] - 1 = 197 - 1 = 196.
  • This causes num_patches != num_positions, even when using the pre-trained image size.
  4. Unintended behavior:
  • Due to this discrepancy, the condition num_patches == num_positions and height == width is never true.
  • As a result, the method never returns self.position_embeddings directly, even when it should (i.e., when the input image size matches the pretrained size).
  • This forces unnecessary interpolation whenever interpolate_pos_encoding=True, even when the image size matches the pretrained size.
  5. Proposed fix:
    Change num_patches = embeddings.shape[1] - 1 to num_patches = embeddings.shape[1]
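The shape arithmetic above can be checked without loading the model. Below is a minimal pure-Python sketch, assuming a patch size of 16 and a pretrained position-embedding table of 1 + 14*14 = 197 rows (one CLS position plus 196 patch positions). The helper names num_patch_tokens and early_return_taken are hypothetical and only mimic the shortcut condition; they are not part of transformers.

```python
PATCH_SIZE = 16  # assumption: ViT-MAE base uses 16x16 patches
PRETRAINED_POSITIONS = 1 + 14 * 14  # assumption: CLS position + 14x14 patch grid


def num_patch_tokens(height: int, width: int, patch_size: int = PATCH_SIZE) -> int:
    """Number of tokens produced by the patch embedding (no CLS token)."""
    return (height // patch_size) * (width // patch_size)


def early_return_taken(height: int, width: int, subtract_cls: bool) -> bool:
    """Mimic the shortcut condition in interpolate_pos_encoding (hypothetical helper).

    subtract_cls=True reproduces the current code (embeddings.shape[1] - 1);
    subtract_cls=False reproduces the proposed fix (embeddings.shape[1]).
    """
    seq_len = num_patch_tokens(height, width)
    num_patches = seq_len - 1 if subtract_cls else seq_len
    num_positions = PRETRAINED_POSITIONS - 1
    return num_patches == num_positions and height == width


print(num_patch_tokens(224, 224))  # 196
print(num_patch_tokens(480, 640))  # 1200
print(early_return_taken(224, 224, subtract_cls=True))   # False: current code interpolates anyway
print(early_return_taken(224, 224, subtract_cls=False))  # True: proposed fix skips interpolation
```

With the -1, a square input would need 197 patch tokens to take the shortcut, and 197 is not a perfect square, so the shortcut can never fire.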

Can you confirm that this is indeed an error in the implementation? Are there any considerations I might be missing? Thank you for your time and for maintaining this amazing project!

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
    num_patches = embeddings.shape[1] - 1  # subtracts a CLS token that embeddings does not contain
    num_positions = self.position_embeddings.shape[1] - 1
    
    if num_patches == num_positions and height == width:
        return self.position_embeddings
    # ... rest of the method

Expected behavior

def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
    num_patches = embeddings.shape[1]  # embeddings holds only patch tokens, no CLS token
    num_positions = self.position_embeddings.shape[1] - 1
    
    if num_patches == num_positions and height == width:
        return self.position_embeddings
    # ... rest of the method
ziyiss added the bug label Aug 4, 2024
Member

qubvel commented Aug 4, 2024

Hi @ziyiss thanks for reporting the issue and such a detailed description!

That indeed looks like a bug in the implementation. I can confirm that the following example produces different outputs for the default shape (224, 224) depending on the value of the interpolate_pos_encoding flag:

import torch
import requests
from PIL import Image

from transformers import AutoImageProcessor, ViTMAEForPreTraining

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

processor = AutoImageProcessor.from_pretrained('facebook/vit-mae-base')
model = ViTMAEForPreTraining.from_pretrained('facebook/vit-mae-base')

inputs = processor(images=image, return_tensors="pt")
noise = torch.rand(size=(1, 196), device=inputs.pixel_values.device)

with torch.no_grad():
    outputs_no_interpolate = model(**inputs, noise=noise, interpolate_pos_encoding=False)
    outputs_with_interpolate = model(**inputs, noise=noise, interpolate_pos_encoding=True)

max_diff = torch.max(torch.abs(outputs_no_interpolate.logits - outputs_with_interpolate.logits)).item()
assert max_diff < 1e-4, f"Max diff is {max_diff}"

Output:

Traceback (most recent call last):
  File "/home/ubuntu/projects/transformers/test_vit_mae_pos_enc.py", line 22, in <module>
    assert max_diff < 1e-4, f"Max diff is {max_diff}"
AssertionError: Max diff is 0.013753533363342285

> Change num_patches = embeddings.shape[1] - 1 to num_patches = embeddings.shape[1]

The fix you suggested works and the outputs match. This will be a breaking change for anyone who uses the model with interpolate_pos_encoding=True by default; despite this, I think it needs to be fixed. @amyeroberts what do you think?

Author

ziyiss commented Aug 4, 2024

Hi @qubvel,
Thank you for your quick reply and for confirming the bug! I appreciate you providing a more intuitive example that clearly demonstrates the error caused by the bug!

I agree with your assessment that this will be a breaking change for users who have been using the model with interpolate_pos_encoding=True by default, since their outputs currently include the unnecessary interpolation.

Regarding the test results you shared, I have a friendly reminder that might help with further testing (correct me if I am wrong! :)). To confirm that the inconsistency (max_diff) is caused by the bug and the interpolate_pos_encoding flag, it's advisable to set a random seed before each forward pass of the model. This ensures that each forward pass produces deterministic results.

Without setting a random seed before each run, the results may be inconsistent even when nothing else changes, and the differences in model output can be significant. This effect can be verified with the following test run:

with torch.no_grad():
    outputs_no_interpolate_1 = model(**inputs, interpolate_pos_encoding=False)
    outputs_no_interpolate_2 = model(**inputs, interpolate_pos_encoding=False)

max_diff = torch.max(torch.abs(outputs_no_interpolate_1.logits - outputs_no_interpolate_2.logits)).item()

try:
    assert max_diff < 1e-4, f"Max diff is {max_diff}"
except AssertionError as e:
    print(f"Error: {e}")
else:
    print(f'No AssertionError is raised, and max_diff: {max_diff}')

Output:
Error: Max diff is 2.819373369216919

After setting a random seed before both runs, I get consistent results for the same inputs, and no AssertionError is raised:

with torch.no_grad():
    torch.manual_seed(2)
    outputs_no_interpolate_1 = model(**inputs, interpolate_pos_encoding=False)

    torch.manual_seed(2)
    outputs_no_interpolate_2 = model(**inputs, interpolate_pos_encoding=False)

max_diff = torch.max(torch.abs(outputs_no_interpolate_1.logits - outputs_no_interpolate_2.logits)).item()

try:
    assert max_diff < 1e-4, f"Max diff is {max_diff}"
except AssertionError as e:
    print(f"Error: {e}")
else:
    print(f'No AssertionError is raised, and max_diff: {max_diff}')

Output:
No AssertionError is raised, and max_diff: 0.0

Hope this example helps!

Thank you and all the developers for continuously developing and maintaining this repo! It has been incredibly helpful for learning and for my ML work!

qubvel added the Vision label Aug 5, 2024
Member

qubvel commented Aug 5, 2024

Hi @ziyiss, thanks for the update.
If you manually provide a noise vector, the outputs should be consistent even without fixing the random seed:

inputs = processor(images=image, return_tensors="pt")
noise = torch.rand(size=(1, 196), device=inputs.pixel_values.device)

with torch.no_grad():
    outputs_no_interpolate = model(**inputs, noise=noise, interpolate_pos_encoding=False)
    outputs_with_interpolate = model(**inputs, noise=noise, interpolate_pos_encoding=True)
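To see why a fixed noise tensor is enough: in MAE-style random masking, the noise only decides which patches are kept, by sorting the noise values of each sample. The pure-Python sketch below is a simplified, hypothetical single-sample version of that idea, not the transformers code; it assumes mask_ratio=0.75, which is my understanding of the ViTMAEConfig default. The same noise always keeps the same patches, which is why passing noise=noise removes the run-to-run variation.

```python
import random


def random_masking(noise: list[float], mask_ratio: float = 0.75):
    """Return (ids_keep, mask) for one sample, MAE-style (hypothetical helper).

    Patches with the smallest noise values are kept; mask[i] == 1 means
    patch i is masked (removed), 0 means it is kept.
    """
    seq_length = len(noise)
    len_keep = int(seq_length * (1 - mask_ratio))  # assumption: default ratio 0.75
    # Argsort of the noise gives a random permutation of patch indices.
    ids_shuffle = sorted(range(seq_length), key=lambda i: noise[i])
    ids_keep = ids_shuffle[:len_keep]
    mask = [1] * seq_length
    for i in ids_keep:
        mask[i] = 0
    return ids_keep, mask


noise = [random.random() for _ in range(196)]  # analogous to torch.rand(size=(1, 196))
keep_a, mask_a = random_masking(noise)
keep_b, mask_b = random_masking(noise)
print(len(keep_a))       # 49 patches kept out of 196
print(keep_a == keep_b)  # True: same noise, same kept patches
```

Without an explicit noise argument, fresh noise is drawn on every call, so each forward pass can mask a different subset of patches.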

Member

qubvel commented Aug 5, 2024

There are even more issues with the current interpolate_pos_encoding implementation, as stated in

> current implementation may lead to an error for certain image sizes:

Example to reproduce:

import torch
import requests
from PIL import Image

from transformers import AutoImageProcessor, ViTMAEForPreTraining

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)

model = ViTMAEForPreTraining.from_pretrained('facebook/vit-mae-base')

for i in range(14, 30):
    for j in range(14, 30):
        processor = AutoImageProcessor.from_pretrained(
            'facebook/vit-mae-base',
            size={"height": i * 16, "width": j * 16},
            use_fast=True,
        )

        inputs = processor(images=image, return_tensors="pt")

        with torch.no_grad():
            try:
                outputs_with_interpolate = model(**inputs, interpolate_pos_encoding=True)
            except Exception as e:
                print(f"Failed with interpolate_pos_encoding=True for shape {inputs.pixel_values.shape}")

Output:

Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 224, 272])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 256, 256])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 256, 432])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 272, 224])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 272, 448])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 288, 352])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 288, 368])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 288, 384])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 304, 416])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 320, 352])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 352, 288])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 352, 320])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 368, 288])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 384, 288])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 416, 304])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 432, 256])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 448, 272])
Failed with interpolate_pos_encoding=True for shape torch.Size([1, 3, 464, 464])

This is caused by using scale_factor instead of size in the interpolation. Moreover, the method relies on the model having been pretrained with square images; it would be more robust to use the actual image size to reshape the embeddings before interpolation.
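A size-based variant can be sketched as follows. This is a standalone function written for illustration only, not the code merged in #33330. It assumes position embeddings of shape (1, 1 + N, dim) with the CLS position first, a square pretrained grid of side sqrt(N), and a patch size of 16. Passing size= to torch.nn.functional.interpolate fixes the output grid exactly, whereas with scale_factor the output size is derived from a floating-point product and can be off by one after rounding down.

```python
import math

import torch
import torch.nn.functional as F


def interpolate_pos_encoding(
    position_embeddings: torch.Tensor, height: int, width: int, patch_size: int = 16
) -> torch.Tensor:
    """Resize (1, 1 + N, dim) position embeddings to the patch grid of a height x width image."""
    num_positions = position_embeddings.shape[1] - 1  # exclude the CLS position
    side = int(math.sqrt(num_positions))  # assumes a square pretrained grid
    new_h, new_w = height // patch_size, width // patch_size

    # Same grid as pretraining: nothing to interpolate.
    if new_h == side and new_w == side:
        return position_embeddings

    cls_pos = position_embeddings[:, :1]
    patch_pos = position_embeddings[:, 1:]
    dim = patch_pos.shape[-1]

    # (1, N, dim) -> (1, dim, side, side) so interpolate treats dim as channels
    patch_pos = patch_pos.reshape(1, side, side, dim).permute(0, 3, 1, 2)
    # Ask for the exact target grid instead of a scale factor.
    patch_pos = F.interpolate(patch_pos, size=(new_h, new_w), mode="bicubic", align_corners=False)
    # (1, dim, new_h, new_w) -> (1, new_h * new_w, dim)
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_h * new_w, dim)
    return torch.cat((cls_pos, patch_pos), dim=1)


pos = torch.randn(1, 1 + 14 * 14, 8)
print(interpolate_pos_encoding(pos, 224, 272).shape)  # torch.Size([1, 239, 8])
```

The early return compares patch grids rather than token counts, so it does not depend on whether the CLS token is counted.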

Author

ziyiss commented Aug 6, 2024

Thank you @qubvel for the clarification and additional insights!
I greatly appreciate you pointing out the detail about manually providing a noise vector. This is a valuable piece of information I had overlooked!

I'm also very grateful to you for sharing information about the other related bugs in the interpolate_pos_encoding implementation! Thanks again!

@amyeroberts
Collaborator

> The fix you suggested works and the outputs match. This will be a breaking change for anyone who uses the model with interpolate_pos_encoding=True by default; despite this, I think it needs to be fixed. @amyeroberts what do you think?

Yes, we should fix it! If I've understood correctly, this issue derives from the original implementation, but we can address it by switching the logic to rely on size rather than scale_factor when interpolating?

Contributor

xenova commented Sep 5, 2024

#33330 should fix this 👍
