diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 1fb4c4f1022e..449eda3ee821 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -449,17 +449,17 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, ViTEmbeddings): - nn.init.trunc_normal_( - module.position_embeddings, + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, - ) + ).to(module.position_embeddings.dtype) - nn.init.trunc_normal_( - module.cls_token, + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, - ) + ).to(module.cls_token.dtype) def _set_gradient_checkpointing(self, module: ViTEncoder, value: bool = False) -> None: if isinstance(module, ViTEncoder): diff --git a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py index df35c13f689d..7ad3183421bc 100644 --- a/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py +++ b/src/transformers/models/vit_hybrid/modeling_vit_hybrid.py @@ -474,17 +474,17 @@ def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> No module.bias.data.zero_() module.weight.data.fill_(1.0) elif isinstance(module, ViTHybridEmbeddings): - nn.init.trunc_normal_( - module.position_embeddings, + module.position_embeddings.data = nn.init.trunc_normal_( + module.position_embeddings.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, - ) + ).to(module.position_embeddings.dtype) - nn.init.trunc_normal_( - module.cls_token, + module.cls_token.data = nn.init.trunc_normal_( + module.cls_token.data.to(torch.float32), mean=0.0, std=self.config.initializer_range, - ) + ).to(module.cls_token.dtype) def _set_gradient_checkpointing(self, module: ViTHybridEncoder, value: bool = False) -> None: if isinstance(module, ViTHybridEncoder):