Skip to content

Commit

Permalink
[Fix] fix config and code errors of swin/hrformer (open-mmlab#1995)
Browse files Browse the repository at this point in the history
  • Loading branch information
ly015 authored Feb 22, 2023
1 parent 62ad3c7 commit 8f17d2a
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
lr=5e-4,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={'relative_position_bias_table': dict(
decay_mult=0.)})))
),
paramwise_cfg=dict(
custom_keys={'relative_position_bias_table': dict(decay_mult=0.)}))

# learning policy
param_scheduler = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
lr=5e-4,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={'relative_position_bias_table': dict(
decay_mult=0.)})))
),
paramwise_cfg=dict(
custom_keys={'relative_position_bias_table': dict(decay_mult=0.)}))

# learning policy
param_scheduler = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
lr=5e-4,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={'relative_position_bias_table': dict(
decay_mult=0.)})))
),
paramwise_cfg=dict(
custom_keys={'relative_position_bias_table': dict(decay_mult=0.)}))

# learning policy
param_scheduler = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
lr=5e-4,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={'relative_position_bias_table': dict(
decay_mult=0.)})))
),
paramwise_cfg=dict(
custom_keys={'relative_position_bias_table': dict(decay_mult=0.)}))

# learning policy
param_scheduler = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
lr=5e-4,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
})))
),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

# learning policy
param_scheduler = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
lr=5e-4,
betas=(0.9, 0.999),
weight_decay=0.01,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
})))
),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

# learning policy
param_scheduler = [
Expand Down
9 changes: 3 additions & 6 deletions mmpose/models/backbones/swin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from copy import deepcopy

import torch
Expand Down Expand Up @@ -668,13 +667,11 @@ def init_weights(self, pretrained=None):
and self.init_cfg['type'] == 'Pretrained'):
# Suppress zero_init_residual if use pretrained model.
logger = get_root_logger()
_state_dict = get_state_dict(
state_dict = get_state_dict(
self.init_cfg['checkpoint'], map_location='cpu')
if self.convert_weights:
# supported loading weight from original repo,
_state_dict = swin_converter(_state_dict)

state_dict = OrderedDict()
# supported loading weight from original repo
state_dict = swin_converter(state_dict)

# strip prefix of state_dict
if list(state_dict.keys())[0].startswith('module.'):
Expand Down

0 comments on commit 8f17d2a

Please sign in to comment.