Skip to content

Commit 03e9fba

Browse files
committed
fix pretrain weights loading
1 parent ed0116e commit 03e9fba

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

networks/swin_transformer.py

+2
Original file line numberDiff line numberDiff line change
@@ -1319,6 +1319,8 @@ def init_weights(self, pretrained=None):
13191319
#ckpt = _load_checkpoint(self,
13201320
# pretrained, logger=logger, map_location='cpu')
13211321
ckpt = torch.load(pretrained,map_location='cpu')
1322+
if 'teacher' in ckpt:
1323+
ckpt = ckpt['teacher']
13221324
if 'state_dict' in ckpt:
13231325
_state_dict = ckpt['state_dict']
13241326
elif 'model' in ckpt:

0 commit comments

Comments
 (0)