Commit 631d535

Merge pull request huggingface#2 from huggingface/add-image-processor

Add image processor

2 parents: 38481b4 + 3d0cbdb

File tree

4 files changed: +151 −23 lines

- src/transformers/models/auto/image_processing_auto.py
- src/transformers/models/dinov3_vit/__init__.py
- src/transformers/models/dinov3_vit/convert_dinov3_vit_to_hf.py
- src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py

src/transformers/models/auto/image_processing_auto.py (1 addition, 0 deletions)

@@ -87,6 +87,7 @@
         ("detr", ("DetrImageProcessor", "DetrImageProcessorFast")),
         ("dinat", ("ViTImageProcessor", "ViTImageProcessorFast")),
         ("dinov2", ("BitImageProcessor", "BitImageProcessorFast")),
+        ("dinov3_vit", (None, "DINOv3ViTImageProcessorFast")),
         ("donut-swin", ("DonutImageProcessor", "DonutImageProcessorFast")),
         ("dpt", ("DPTImageProcessor", "DPTImageProcessorFast")),
         ("efficientformer", ("EfficientFormerImageProcessor",)),

src/transformers/models/dinov3_vit/__init__.py (1 addition, 0 deletions)

@@ -20,6 +20,7 @@
 if TYPE_CHECKING:
     from .configuration_dinov3_vit import *
     from .modeling_dinov3_vit import *
+    from .image_processing_dinov3_vit_fast import *
 else:
     import sys

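Because the new module is imported only under TYPE_CHECKING, with the else branch deferring to the library's lazy-module machinery, the class becomes importable from the top-level package at runtime without an eager torchvision import. A quick check, as a sketch assuming torch and torchvision are installed:

from transformers import DINOv3ViTImageProcessorFast

# Class-level defaults declared in the new processor file below:
print(DINOv3ViTImageProcessorFast.size)  # {'height': 224, 'width': 224}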
src/transformers/models/dinov3_vit/convert_dinov3_vit_to_hf.py (47 additions, 23 deletions)

@@ -3,16 +3,16 @@
 URL: https://github.com/facebookresearch/dinov3/tree/main
 """
 
+import os
 import argparse
-from typing import Optional
 import torch
 
 import random
 import numpy as np
 from torchvision import transforms
 import requests
 from PIL import Image
-from transformers import DINOv3ViTConfig, DINOv3ViTModel
+from transformers import DINOv3ViTConfig, DINOv3ViTModel, DINOv3ViTImageProcessorFast
 from huggingface_hub import hf_hub_download
 
 HUB_MODELS = {
@@ -34,7 +34,7 @@
 }
 
 
-def get_dinov3_config(model_name: str) -> Optional[DINOv3ViTConfig]:
+def get_dinov3_config(model_name: str) -> DINOv3ViTConfig:
     # size of the architecture
     if model_name == "vits":
         return DINOv3ViTConfig(
@@ -149,7 +149,6 @@ def get_dinov3_config(model_name: str) -> Optional[DINOv3ViTConfig]:
     else:
         raise ValueError("Model not supported")
 
-
 def convert_dinov3_vit_to_hf_vit(original_dinov3_state_dict, config: DINOv3ViTConfig):
     embed_dim = config.hidden_size
     hf_dinov3_state_dict = {}
@@ -204,7 +203,7 @@ def prepare_img():
     return image
 
 
-def make_transform(resize_size: int = 224):
+def get_transform(resize_size: int = 224):
     to_tensor = transforms.ToTensor()
     resize = transforms.Resize((resize_size, resize_size), antialias=True)
     normalize = transforms.Normalize(
@@ -213,6 +212,12 @@ def make_transform(resize_size: int = 224):
     )
     return transforms.Compose([to_tensor, resize, normalize])
 
+def get_image_processor(resize_size: int = 224):
+    return DINOv3ViTImageProcessorFast(
+        do_resize=True,
+        size={"height": resize_size, "width": resize_size},
+        resample=2,  # BILINEAR
+    )
 
 def set_deterministic(seed=42):
     random.seed(seed)
@@ -230,7 +235,7 @@ def set_deterministic(seed=42):
 
 
 @torch.no_grad()
-def convert_and_test_dinov3_checkpoint(model_name):
+def convert_and_test_dinov3_checkpoint(args):
     expected_outputs = {
         "vits_cls": [
             0.4635618329048157,
@@ -317,6 +322,7 @@ def convert_and_test_dinov3_checkpoint(model_name):
             -0.026546532288193703,
         ],
     }
+    model_name = args.model_name
     config = get_dinov3_config(model_name)
     print(config)
 
@@ -330,35 +336,47 @@ def convert_and_test_dinov3_checkpoint(model_name):
     model.load_state_dict(hf_state_dict, strict=True)
     model = model.eval()
 
-    image_preprocessor = make_transform()
-    # load image
-    images = [image_preprocessor(prepare_img())]
-    image_tensor = torch.stack(images, dim=0)
-    with torch.inference_mode():
-        with torch.autocast("cuda", dtype=torch.float):
-            model_output = model(image_tensor)
+    transform = get_transform()
+    image_processor = get_image_processor()
+    image = prepare_img()
+
+    # check preprocessing
+    original_pixel_values = transform(image).unsqueeze(0)  # add batch dimension
+    inputs = image_processor(image, return_tensors="pt")
+
+    torch.testing.assert_close(original_pixel_values, inputs["pixel_values"], atol=1e-6, rtol=1e-6)
+    print("Preprocessing looks ok!")
+
+    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float):
+        model_output = model(**inputs)
 
     last_layer_class_token = model_output.pooler_output
-    last_layer_patch_tokens = model_output.last_hidden_state[
-        :, config.num_register_tokens + 1 :
-    ]
+    last_layer_patch_tokens = model_output.last_hidden_state[:, config.num_register_tokens + 1:]
+
     actual_outputs = {}
     actual_outputs[f"{model_name}_cls"] = last_layer_class_token[0, :5].tolist()
     actual_outputs[f"{model_name}_patch"] = last_layer_patch_tokens[0, 0, :5].tolist()
-    print(actual_outputs[f"{model_name}_cls"], expected_outputs[f"{model_name}_cls"])
+
+    print("Actual: ", actual_outputs[f"{model_name}_cls"])
+    print("Expected:", expected_outputs[f"{model_name}_cls"])
+
     torch.testing.assert_close(
         torch.Tensor(actual_outputs[f"{model_name}_cls"]),
         torch.Tensor(expected_outputs[f"{model_name}_cls"]),
-        atol=1e-3,
-        rtol=1e-3,
+        atol=1e-4, rtol=1e-4,
     )
     torch.testing.assert_close(
         torch.Tensor(actual_outputs[f"{model_name}_patch"]),
         torch.Tensor(expected_outputs[f"{model_name}_patch"]),
-        atol=1e-3,
-        rtol=1e-3,
+        atol=1e-4, rtol=1e-4,
    )
-    print("Looks ok!")
+    print("Forward pass looks ok!")
+
+    save_dir = os.path.join(args.save_dir, model_name)
+    os.makedirs(save_dir, exist_ok=True)
+    model.save_pretrained(save_dir)
+    image_processor.save_pretrained(save_dir)
+    print(f"Model saved to {save_dir}")
 
 
 if __name__ == "__main__":
@@ -371,5 +389,11 @@ def convert_and_test_dinov3_checkpoint(model_name):
         choices=["vits", "vitsplus", "vitb", "vitl", "vithplus", "vit7b"],
         help="Name of the model you'd like to convert.",
     )
+    parser.add_argument(
+        "--save-dir",
+        default="converted_models",
+        type=str,
+        help="Directory to save the converted model.",
+    )
     args = parser.parse_args()
-    convert_and_test_dinov3_checkpoint(args.model_name)
+    convert_and_test_dinov3_checkpoint(args)
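A note on the get_image_processor helper added above: resample=2 is PIL's integer code for bilinear interpolation, which matches the default interpolation of torchvision's transforms.Resize used in get_transform, so the two preprocessing paths can agree to the 1e-6 tolerance the script asserts. An equivalent, more self-documenting construction (a sketch using the PILImageResampling enum that the new processor module imports):

from transformers import DINOv3ViTImageProcessorFast
from transformers.image_utils import PILImageResampling

# PILImageResampling.BILINEAR has integer value 2, i.e. the same as resample=2.
processor = DINOv3ViTImageProcessorFast(
    do_resize=True,
    size={"height": 224, "width": 224},
    resample=PILImageResampling.BILINEAR,
)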
src/transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py (new file, 102 additions, 0 deletions)

@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fast Image processor class for DINOv3."""
+
+from typing import Optional, Union
+
+from transformers.image_processing_base import BatchFeature
+from transformers.image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
+from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, PILImageResampling, SizeDict
+from transformers.utils import (
+    TensorType,
+    auto_docstring,
+    is_torch_available,
+    is_torchvision_available,
+    is_torchvision_v2_available,
+    logging,
+)
+from transformers.utils.import_utils import requires
+
+logger = logging.get_logger(__name__)
+
+
+if is_torch_available():
+    import torch
+
+if is_torchvision_v2_available():
+    from torchvision.transforms.v2 import functional as F
+elif is_torchvision_available():
+    from torchvision.transforms import functional as F
+
+
+@auto_docstring
+@requires(backends=("torchvision", "torch"))
+class DINOv3ViTImageProcessorFast(BaseImageProcessorFast):
+    resample = PILImageResampling.BILINEAR
+    image_mean = IMAGENET_DEFAULT_MEAN
+    image_std = IMAGENET_DEFAULT_STD
+    size = {"height": 224, "width": 224}
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+
+    # Overridden for DINOv3 to preserve the order of transforms:
+    # rescale -> resize -> normalize
+    def _preprocess(
+        self,
+        images: list["torch.Tensor"],
+        do_resize: bool,
+        size: SizeDict,
+        interpolation: Optional["F.InterpolationMode"],
+        do_center_crop: bool,
+        crop_size: SizeDict,
+        do_rescale: bool,
+        rescale_factor: float,
+        do_normalize: bool,
+        image_mean: Optional[Union[float, list[float]]],
+        image_std: Optional[Union[float, list[float]]],
+        disable_grouping: Optional[bool],
+        return_tensors: Optional[Union[str, TensorType]],
+    ) -> BatchFeature:
+
+        # Group images by size for batched resizing
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_rescale:
+                stacked_images = self.rescale(stacked_images, rescale_factor)
+            if do_resize:
+                stacked_images = self.resize(image=stacked_images, size=size, interpolation=interpolation, antialias=True)
+            resized_images_grouped[shape] = stacked_images
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        # Group images by size for further processing
+        # Needed in case do_resize is False, or resize returns images with different sizes
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            if do_center_crop:
+                stacked_images = self.center_crop(stacked_images, crop_size)
+            if do_normalize:
+                stacked_images = self.normalize(stacked_images, image_mean, image_std)
+            processed_images_grouped[shape] = stacked_images
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+
+        return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
+
+
+__all__ = ["DINOv3ViTImageProcessorFast"]
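For reference, a minimal end-to-end use of the new processor (a sketch; the solid-gray image stands in for real input, and torch plus torchvision must be installed):

from PIL import Image
from transformers import DINOv3ViTImageProcessorFast

# Defaults: rescale to [0, 1], bilinear resize to 224x224, ImageNet mean/std normalize.
processor = DINOv3ViTImageProcessorFast()
image = Image.new("RGB", (640, 480), color="gray")  # stand-in for a real photo
inputs = processor(image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # torch.Size([1, 3, 224, 224])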
