Skip to content

Commit

Permalink
Update model.py
Browse files Browse the repository at this point in the history
As per Issue gudovskiy#30 we get error when we use ResNet50 as feature extractor. When is register hook to final layer of bottleneck then proper pool dims are generated.
  • Loading branch information
shivarajkarki authored Jun 6, 2023
1 parent b2ebf9e commit 0024ccd
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,21 +92,21 @@ def load_encoder_arch(c, L):
#
if L >= 3:
encoder.layer2.register_forward_hook(get_activation(pool_layers[pool_cnt]))
if 'wide' in c.enc_arch:
if 'wide' or 'resnet50' in c.enc_arch:
pool_dims.append(encoder.layer2[-1].conv3.out_channels)
else:
pool_dims.append(encoder.layer2[-1].conv2.out_channels)
pool_cnt = pool_cnt + 1
if L >= 2:
encoder.layer3.register_forward_hook(get_activation(pool_layers[pool_cnt]))
if 'wide' in c.enc_arch:
if 'wide' or 'resnet50' in c.enc_arch:
pool_dims.append(encoder.layer3[-1].conv3.out_channels)
else:
pool_dims.append(encoder.layer3[-1].conv2.out_channels)
pool_cnt = pool_cnt + 1
if L >= 1:
encoder.layer4.register_forward_hook(get_activation(pool_layers[pool_cnt]))
if 'wide' in c.enc_arch:
if 'wide' or 'resnet50' in c.enc_arch:
pool_dims.append(encoder.layer4[-1].conv3.out_channels)
else:
pool_dims.append(encoder.layer4[-1].conv2.out_channels)
Expand Down

0 comments on commit 0024ccd

Please sign in to comment.