@@ -81,6 +81,7 @@ def get_dinov3_config(model_name: str) -> Optional[DINOv3ViTConfig]:
8181 num_attention_heads = 12 ,
8282 mask_k_bias = True ,
8383 qkv_bias = True ,
84+ proj_bias = True ,
8485 num_register_tokens = 4 ,
8586 layerscale_value = 1.0 ,
8687 mlp_ratio = 4 ,
@@ -213,109 +214,109 @@ def make_transform(resize_size: int = 224):
213214 return transforms .Compose ([to_tensor , resize , normalize ])
214215
215216
217+ def set_deterministic (seed = 42 ):  # Seed every RNG source so the checkpoint-conversion outputs are reproducible.
218+ random .seed (seed )  # Python stdlib RNG
219+ np .random .seed (seed )  # NumPy global RNG
220+ torch .manual_seed (seed )  # Torch CPU RNG
221+ torch .cuda .manual_seed (seed )  # current CUDA device RNG
222+ torch .cuda .manual_seed_all (seed )  # all CUDA devices (multi-GPU)
223+ torch .backends .cudnn .deterministic = True  # force deterministic cuDNN kernels
224+ torch .backends .cudnn .benchmark = False  # autotuner picks kernels nondeterministically; disable it
225+ torch .backends .cudnn .enabled = False  # NOTE(review): disables cuDNN entirely, not just nondeterminism — confirm intended
226+
227+
228+ seed = 42 # any number
229+ set_deterministic (seed = seed )  # module-level call: seeds RNGs at import time, before conversion runs
230+
231+
216232@torch .no_grad ()
217233def convert_and_test_dinov3_checkpoint (model_name ):
218-
219234 expected_outputs = {
220235 "vits_cls" : [
221- 0.47379571199417114 ,
222- - 0.41561394929885864 ,
223- 0.41169291734695435 ,
224- - 0.12478338927030563 ,
225- - 0.2959742844104767 ,
236+ 0.4635618329048157 ,
237+ - 0.41560935974121094 ,
238+ 0.40823689103126526 ,
239+ - 0.12661336362361908 ,
240+ - 0.28663691878318787 ,
226241 ],
227242 "vits_patch" : [
228- - 0.03959187492728233 ,
229- - 0.25311151146888733 ,
230- - 0.015847790986299515 ,
231- - 0.45699289441108704 ,
232- 0.5675609707832336 ,
243+ - 0.03875422105193138 ,
244+ - 0.2508954405784607 ,
245+ - 0.01639290526509285 ,
246+ - 0.4554736316204071 ,
247+ 0.5715821981430054 ,
233248 ],
234249 "vitsplus_cls" : [
235- - 0.4748912751674652 ,
236- - 1.3652222156524658 ,
237- - 0.32735151052474976 ,
238- 0.3742392957210541 ,
239- - 0.7740300893783569 ,
250+ - 0.47134941816329956 ,
251+ - 1.365778923034668 ,
252+ - 0.3179832398891449 ,
253+ 0.37721940875053406 ,
254+ - 0.769085705280304 ,
240255 ],
241256 "vitsplus_patch" : [
242- 0.14932650327682495 ,
243- - 0.3805270791053772 ,
244- - 0.4004722833633423 ,
245- - 0.15717053413391113 ,
246- - 0.5877845287322998 ,
257+ 0.14455188810825348 ,
258+ - 0.3881174623966217 ,
259+ - 0.39343395829200745 ,
260+ - 0.1576954871416092 ,
261+ - 0.6003801226615906 ,
247262 ],
248263 "vitb_cls" : [
249- 1.048130750656128 ,
250- - 0.16398264467716217 ,
251- - 0.3483588695526123 ,
252- - 0.07031229883432388 ,
253- - 0.018643084913492203 ,
264+ 1.0346431732177734 ,
265+ - 0.18060928583145142 ,
266+ - 0.3410182595252991 ,
267+ - 0.0663769543170929 ,
268+ - 0.011383970268070698 ,
254269 ],
255270 "vitb_patch" : [
256- - 0.0795423611998558 ,
257- - 0.45527052879333496 ,
258- - 0.7357183694839478 ,
259- - 0.4356740117073059 ,
260- - 0.14763328433036804 ,
271+ - 0.08252374082803726 ,
272+ - 0.45627278089523315 ,
273+ - 0.7280299663543701 ,
274+ - 0.4306802451610565 ,
275+ - 0.15288019180297852 ,
261276 ],
262277 "vitl_cls" : [
263- 0.4834900200366974 ,
264- - 0.587904155254364 ,
265- 0.476875901222229 ,
266- 0.5853531360626221 ,
267- 0.9454823136329651 ,
278+ 0.4845271110534668 ,
279+ - 0.5822147130966187 ,
280+ 0.4806361198425293 ,
281+ 0.5920403599739075 ,
282+ 0.9451664686203003 ,
268283 ],
269284 "vitl_patch" : [
270- - 0.21309036016464233 ,
271- - 0.49482738971710205 ,
272- - 0.2584819495677948 ,
273- 0.1072424128651619 ,
274- 0.14616338908672333 ,
285+ - 0.2113673835992813 ,
286+ - 0.490863561630249 ,
287+ - 0.2571314871311188 ,
288+ 0.10176393389701843 ,
289+ 0.1545112431049347 ,
275290 ],
276291 "vithplus_cls" : [
277- - 0.06420943140983582 ,
278- - 0.1494205743074417 ,
279- - 0.618586540222168 ,
280- 0.6363415122032166 ,
281- 0.15246111154556274 ,
292+ - 0.0645759105682373 ,
293+ - 0.14886680245399475 ,
294+ - 0.6215243935585022 ,
295+ 0.6348787546157837 ,
296+ 0.1526956558227539 ,
282297 ],
283298 "vithplus_patch" : [
284- - 0.09335622191429138 ,
285- 0.28375640511512756 ,
286- - 0.049649134278297424 ,
287- 0.4244541823863983 ,
288- 0.0950070321559906 ,
299+ - 0.09381738305091858 ,
300+ 0.287407249212265 ,
301+ - 0.05003691464662552 ,
302+ 0.4280431866645813 ,
303+ 0.09456184506416321 ,
289304 ],
290305 "vit7b_cls" : [
291- 0.27555006742477417 ,
292- - 0.2604803442955017 ,
293- 0.06795521825551987 ,
294- 0.05062410980463028 ,
295- - 0.15915830433368683 ,
306+ 0.2754395306110382 ,
307+ - 0.261353999376297 ,
308+ 0.0677720308303833 ,
309+ 0.049936190247535706 ,
310+ - 0.15874707698822021 ,
296311 ],
297312 "vit7b_patch" : [
298- 0.04416150599718094 ,
299- - 0.05306466668844223 ,
300- 0.0719609260559082 ,
301- - 0.06456729769706726 ,
302- - 0.026268284767866135 ,
313+ 0.04444204643368721 ,
314+ - 0.05254213139414787 ,
315+ 0.07077747583389282 ,
316+ - 0.0651116818189621 ,
317+ - 0.026546532288193703 ,
303318 ],
304319 }
305-
306- def set_deterministic (seed = 42 ):
307- random .seed (seed )
308- np .random .seed (seed )
309- torch .manual_seed (seed )
310- torch .cuda .manual_seed (seed )
311- torch .cuda .manual_seed_all (seed )
312- torch .backends .cudnn .deterministic = True
313- torch .backends .cudnn .benchmark = False
314- torch .backends .cudnn .enabled = False
315-
316- seed = 42 # any number
317- set_deterministic (seed = seed )
318-
319320 config = get_dinov3_config (model_name )
320321 print (config )
321322
@@ -334,7 +335,7 @@ def set_deterministic(seed=42):
334335 images = [image_preprocessor (prepare_img ())]
335336 image_tensor = torch .stack (images , dim = 0 )
336337 with torch .inference_mode ():
337- with torch .autocast ("cuda" , dtype = torch .bfloat16 ):
338+ with torch .autocast ("cuda" , dtype = torch .float ):
338339 model_output = model (image_tensor )
339340
340341 last_layer_class_token = model_output .pooler_output
@@ -348,14 +349,14 @@ def set_deterministic(seed=42):
348349 torch .testing .assert_close (
349350 torch .Tensor (actual_outputs [f"{ model_name } _cls" ]),
350351 torch .Tensor (expected_outputs [f"{ model_name } _cls" ]),
351- atol = 1e-2 ,
352- rtol = 1e-2 ,
352+ atol = 1e-3 ,
353+ rtol = 1e-3 ,
353354 )
354355 torch .testing .assert_close (
355356 torch .Tensor (actual_outputs [f"{ model_name } _patch" ]),
356357 torch .Tensor (expected_outputs [f"{ model_name } _patch" ]),
357- atol = 1e-2 ,
358- rtol = 1e-2 ,
358+ atol = 1e-3 ,
359+ rtol = 1e-3 ,
359360 )
360361 print ("Looks ok!" )
361362
0 commit comments