 URL: https://github.com/facebookresearch/dinov3/tree/main
 """
 
+import os
 import argparse
-from typing import Optional
 import torch
 
 import random
 import numpy as np
 from torchvision import transforms
 import requests
 from PIL import Image
-from transformers import DINOv3ViTConfig, DINOv3ViTModel
+from transformers import DINOv3ViTConfig, DINOv3ViTModel, DINOv3ViTImageProcessorFast
 from huggingface_hub import hf_hub_download
 
 HUB_MODELS = {
     # ... per-model checkpoint repo entries, collapsed in this diff view ...
 }
 
 
-def get_dinov3_config(model_name: str) -> Optional[DINOv3ViTConfig]:
+def get_dinov3_config(model_name: str) -> DINOv3ViTConfig:
     # size of the architecture
     if model_name == "vits":
         return DINOv3ViTConfig(
@@ -149,7 +149,6 @@ def get_dinov3_config(model_name: str) -> Optional[DINOv3ViTConfig]:
     else:
         raise ValueError("Model not supported")
 
-
 def convert_dinov3_vit_to_hf_vit(original_dinov3_state_dict, config: DINOv3ViTConfig):
     embed_dim = config.hidden_size
     hf_dinov3_state_dict = {}
@@ -204,7 +203,7 @@ def prepare_img():
     return image
 
 
-def make_transform(resize_size: int = 224):
+def get_transform(resize_size: int = 224):
     to_tensor = transforms.ToTensor()
     resize = transforms.Resize((resize_size, resize_size), antialias=True)
     normalize = transforms.Normalize(
@@ -213,6 +212,12 @@ def make_transform(resize_size: int = 224):
     )
     return transforms.Compose([to_tensor, resize, normalize])
 
+def get_image_processor(resize_size: int = 224):
+    return DINOv3ViTImageProcessorFast(
+        do_resize=True,
+        size={"height": resize_size, "width": resize_size},
+        resample=2,  # BILINEAR
+    )
 
 def set_deterministic(seed=42):
     random.seed(seed)
@@ -230,7 +235,7 @@ def set_deterministic(seed=42):
 
 
 @torch.no_grad()
-def convert_and_test_dinov3_checkpoint(model_name):
+def convert_and_test_dinov3_checkpoint(args):
     expected_outputs = {
         "vits_cls": [
             0.4635618329048157,
@@ -317,6 +322,7 @@ def convert_and_test_dinov3_checkpoint(model_name):
             -0.026546532288193703,
         ],
     }
+    model_name = args.model_name
     config = get_dinov3_config(model_name)
     print(config)
 
@@ -330,35 +336,47 @@ def convert_and_test_dinov3_checkpoint(model_name):
     model.load_state_dict(hf_state_dict, strict=True)
     model = model.eval()
 
-    image_preprocessor = make_transform()
-    # load image
-    images = [image_preprocessor(prepare_img())]
-    image_tensor = torch.stack(images, dim=0)
-    with torch.inference_mode():
-        with torch.autocast("cuda", dtype=torch.float):
-            model_output = model(image_tensor)
+    transform = get_transform()
+    image_processor = get_image_processor()
+    image = prepare_img()
+
+    # check that torchvision preprocessing and the HF image processor agree
+    original_pixel_values = transform(image).unsqueeze(0)  # add batch dimension
+    inputs = image_processor(image, return_tensors="pt")
+
+    torch.testing.assert_close(original_pixel_values, inputs["pixel_values"], atol=1e-6, rtol=1e-6)
+    print("Preprocessing looks ok!")
+
+    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float):
+        model_output = model(**inputs)
 
     last_layer_class_token = model_output.pooler_output
-    last_layer_patch_tokens = model_output.last_hidden_state[
-        :, config.num_register_tokens + 1 :
-    ]
+    last_layer_patch_tokens = model_output.last_hidden_state[:, config.num_register_tokens + 1 :]
+
     actual_outputs = {}
     actual_outputs[f"{model_name}_cls"] = last_layer_class_token[0, :5].tolist()
     actual_outputs[f"{model_name}_patch"] = last_layer_patch_tokens[0, 0, :5].tolist()
-    print(actual_outputs[f"{model_name}_cls"], expected_outputs[f"{model_name}_cls"])
+
+    print("Actual: ", actual_outputs[f"{model_name}_cls"])
+    print("Expected:", expected_outputs[f"{model_name}_cls"])
+
     torch.testing.assert_close(
         torch.Tensor(actual_outputs[f"{model_name}_cls"]),
         torch.Tensor(expected_outputs[f"{model_name}_cls"]),
-        atol=1e-3,
-        rtol=1e-3,
+        atol=1e-4, rtol=1e-4,
     )
     torch.testing.assert_close(
         torch.Tensor(actual_outputs[f"{model_name}_patch"]),
         torch.Tensor(expected_outputs[f"{model_name}_patch"]),
-        atol=1e-3,
-        rtol=1e-3,
+        atol=1e-4, rtol=1e-4,
     )
-    print("Looks ok!")
+    print("Forward pass looks ok!")
+
+    save_dir = os.path.join(args.save_dir, model_name)
+    os.makedirs(save_dir, exist_ok=True)
+    model.save_pretrained(save_dir)
+    image_processor.save_pretrained(save_dir)
+    print(f"Model saved to {save_dir}")
 
 
 if __name__ == "__main__":
@@ -371,5 +389,11 @@ def convert_and_test_dinov3_checkpoint(model_name):
         choices=["vits", "vitsplus", "vitb", "vitl", "vithplus", "vit7b"],
         help="Name of the model you'd like to convert.",
     )
+    parser.add_argument(
+        "--save-dir",
+        default="converted_models",
+        type=str,
+        help="Directory to save the converted model.",
+    )
     args = parser.parse_args()
-    convert_and_test_dinov3_checkpoint(args.model_name)
+    convert_and_test_dinov3_checkpoint(args)
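
After this diff lands, the saved directory can be loaded back through the standard from_pretrained API. Below is a minimal usage sketch, assuming the script above was run with --model-name vits; the script filename in the comment and the example image path are illustrative assumptions, while the class names and the save_dir layout come straight from the diff.

# Assumed invocation (the script filename is hypothetical):
#   python convert_dinov3_to_hf.py --model-name vits --save-dir converted_models

import torch
from PIL import Image
from transformers import DINOv3ViTImageProcessorFast, DINOv3ViTModel

# Matches os.path.join(args.save_dir, model_name) in the script above.
save_dir = "converted_models/vits"

model = DINOv3ViTModel.from_pretrained(save_dir).eval()
image_processor = DINOv3ViTImageProcessorFast.from_pretrained(save_dir)

image = Image.open("example.jpg").convert("RGB")  # placeholder image path
inputs = image_processor(image, return_tensors="pt")

with torch.inference_mode():
    outputs = model(**inputs)

# Pooled [CLS] embedding, plus patch tokens with the [CLS] and register
# tokens sliced off, mirroring the slicing in the conversion script.
cls_embedding = outputs.pooler_output
patch_tokens = outputs.last_hidden_state[:, model.config.num_register_tokens + 1 :]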