Skip to content

Commit 38481b4

Browse files
committed
bf16 -> fp32
1 parent c6daf9f commit 38481b4

File tree

1 file changed: +81 additions, -80 deletions

src/transformers/models/dinov3_vit/convert_dinov3_vit_to_hf.py

Lines changed: 81 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def get_dinov3_config(model_name: str) -> Optional[DINOv3ViTConfig]:
8181
num_attention_heads=12,
8282
mask_k_bias=True,
8383
qkv_bias=True,
84+
proj_bias=True,
8485
num_register_tokens=4,
8586
layerscale_value=1.0,
8687
mlp_ratio=4,
@@ -213,109 +214,109 @@ def make_transform(resize_size: int = 224):
213214
return transforms.Compose([to_tensor, resize, normalize])
214215

215216

217+
def set_deterministic(seed=42):
218+
random.seed(seed)
219+
np.random.seed(seed)
220+
torch.manual_seed(seed)
221+
torch.cuda.manual_seed(seed)
222+
torch.cuda.manual_seed_all(seed)
223+
torch.backends.cudnn.deterministic = True
224+
torch.backends.cudnn.benchmark = False
225+
torch.backends.cudnn.enabled = False
226+
227+
228+
seed = 42 # any number
229+
set_deterministic(seed=seed)
230+
231+
216232
@torch.no_grad()
217233
def convert_and_test_dinov3_checkpoint(model_name):
218-
219234
expected_outputs = {
220235
"vits_cls": [
221-
0.47379571199417114,
222-
-0.41561394929885864,
223-
0.41169291734695435,
224-
-0.12478338927030563,
225-
-0.2959742844104767,
236+
0.4635618329048157,
237+
-0.41560935974121094,
238+
0.40823689103126526,
239+
-0.12661336362361908,
240+
-0.28663691878318787,
226241
],
227242
"vits_patch": [
228-
-0.03959187492728233,
229-
-0.25311151146888733,
230-
-0.015847790986299515,
231-
-0.45699289441108704,
232-
0.5675609707832336,
243+
-0.03875422105193138,
244+
-0.2508954405784607,
245+
-0.01639290526509285,
246+
-0.4554736316204071,
247+
0.5715821981430054,
233248
],
234249
"vitsplus_cls": [
235-
-0.4748912751674652,
236-
-1.3652222156524658,
237-
-0.32735151052474976,
238-
0.3742392957210541,
239-
-0.7740300893783569,
250+
-0.47134941816329956,
251+
-1.365778923034668,
252+
-0.3179832398891449,
253+
0.37721940875053406,
254+
-0.769085705280304,
240255
],
241256
"vitsplus_patch": [
242-
0.14932650327682495,
243-
-0.3805270791053772,
244-
-0.4004722833633423,
245-
-0.15717053413391113,
246-
-0.5877845287322998,
257+
0.14455188810825348,
258+
-0.3881174623966217,
259+
-0.39343395829200745,
260+
-0.1576954871416092,
261+
-0.6003801226615906,
247262
],
248263
"vitb_cls": [
249-
1.048130750656128,
250-
-0.16398264467716217,
251-
-0.3483588695526123,
252-
-0.07031229883432388,
253-
-0.018643084913492203,
264+
1.0346431732177734,
265+
-0.18060928583145142,
266+
-0.3410182595252991,
267+
-0.0663769543170929,
268+
-0.011383970268070698,
254269
],
255270
"vitb_patch": [
256-
-0.0795423611998558,
257-
-0.45527052879333496,
258-
-0.7357183694839478,
259-
-0.4356740117073059,
260-
-0.14763328433036804,
271+
-0.08252374082803726,
272+
-0.45627278089523315,
273+
-0.7280299663543701,
274+
-0.4306802451610565,
275+
-0.15288019180297852,
261276
],
262277
"vitl_cls": [
263-
0.4834900200366974,
264-
-0.587904155254364,
265-
0.476875901222229,
266-
0.5853531360626221,
267-
0.9454823136329651,
278+
0.4845271110534668,
279+
-0.5822147130966187,
280+
0.4806361198425293,
281+
0.5920403599739075,
282+
0.9451664686203003,
268283
],
269284
"vitl_patch": [
270-
-0.21309036016464233,
271-
-0.49482738971710205,
272-
-0.2584819495677948,
273-
0.1072424128651619,
274-
0.14616338908672333,
285+
-0.2113673835992813,
286+
-0.490863561630249,
287+
-0.2571314871311188,
288+
0.10176393389701843,
289+
0.1545112431049347,
275290
],
276291
"vithplus_cls": [
277-
-0.06420943140983582,
278-
-0.1494205743074417,
279-
-0.618586540222168,
280-
0.6363415122032166,
281-
0.15246111154556274,
292+
-0.0645759105682373,
293+
-0.14886680245399475,
294+
-0.6215243935585022,
295+
0.6348787546157837,
296+
0.1526956558227539,
282297
],
283298
"vithplus_patch": [
284-
-0.09335622191429138,
285-
0.28375640511512756,
286-
-0.049649134278297424,
287-
0.4244541823863983,
288-
0.0950070321559906,
299+
-0.09381738305091858,
300+
0.287407249212265,
301+
-0.05003691464662552,
302+
0.4280431866645813,
303+
0.09456184506416321,
289304
],
290305
"vit7b_cls": [
291-
0.27555006742477417,
292-
-0.2604803442955017,
293-
0.06795521825551987,
294-
0.05062410980463028,
295-
-0.15915830433368683,
306+
0.2754395306110382,
307+
-0.261353999376297,
308+
0.0677720308303833,
309+
0.049936190247535706,
310+
-0.15874707698822021,
296311
],
297312
"vit7b_patch": [
298-
0.04416150599718094,
299-
-0.05306466668844223,
300-
0.0719609260559082,
301-
-0.06456729769706726,
302-
-0.026268284767866135,
313+
0.04444204643368721,
314+
-0.05254213139414787,
315+
0.07077747583389282,
316+
-0.0651116818189621,
317+
-0.026546532288193703,
303318
],
304319
}
305-
306-
def set_deterministic(seed=42):
307-
random.seed(seed)
308-
np.random.seed(seed)
309-
torch.manual_seed(seed)
310-
torch.cuda.manual_seed(seed)
311-
torch.cuda.manual_seed_all(seed)
312-
torch.backends.cudnn.deterministic = True
313-
torch.backends.cudnn.benchmark = False
314-
torch.backends.cudnn.enabled = False
315-
316-
seed = 42 # any number
317-
set_deterministic(seed=seed)
318-
319320
config = get_dinov3_config(model_name)
320321
print(config)
321322

@@ -334,7 +335,7 @@ def set_deterministic(seed=42):
334335
images = [image_preprocessor(prepare_img())]
335336
image_tensor = torch.stack(images, dim=0)
336337
with torch.inference_mode():
337-
with torch.autocast("cuda", dtype=torch.bfloat16):
338+
with torch.autocast("cuda", dtype=torch.float):
338339
model_output = model(image_tensor)
339340

340341
last_layer_class_token = model_output.pooler_output
@@ -348,14 +349,14 @@ def set_deterministic(seed=42):
348349
torch.testing.assert_close(
349350
torch.Tensor(actual_outputs[f"{model_name}_cls"]),
350351
torch.Tensor(expected_outputs[f"{model_name}_cls"]),
351-
atol=1e-2,
352-
rtol=1e-2,
352+
atol=1e-3,
353+
rtol=1e-3,
353354
)
354355
torch.testing.assert_close(
355356
torch.Tensor(actual_outputs[f"{model_name}_patch"]),
356357
torch.Tensor(expected_outputs[f"{model_name}_patch"]),
357-
atol=1e-2,
358-
rtol=1e-2,
358+
atol=1e-3,
359+
rtol=1e-3,
359360
)
360361
print("Looks ok!")
361362

Comments (0)