Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix cyclic imports in self-supervised #350

Merged
merged 1 commit into from
Nov 9, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions pl_bolts/models/self_supervised/resnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,22 @@
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet50_bn', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']


model_urls = {
__all__ = [
'ResNet',
'resnet18',
'resnet34',
'resnet50',
'resnet50_bn',
'resnet101',
'resnet152',
'resnext50_32x4d',
'resnext101_32x8d',
'wide_resnet50_2',
'wide_resnet101_2',
]


MODEL_URLS = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50_bn': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
Expand Down Expand Up @@ -265,7 +275,7 @@ def forward(self, x):
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
state_dict = load_state_dict_from_url(MODEL_URLS[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
Expand All @@ -279,8 +289,7 @@ def resnet18(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
Expand All @@ -291,8 +300,7 @@ def resnet34(pretrained=False, progress=True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained: bool = False, progress: bool = True, **kwargs):
Expand All @@ -303,8 +311,7 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50_bn(pretrained: bool = False, progress: bool = True, **kwargs):
Expand All @@ -315,8 +322,7 @@ def resnet50_bn(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50_bn', BottleneckBN, [3, 4, 6, 3], pretrained, progress,
**kwargs)
return _resnet('resnet50_bn', BottleneckBN, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs):
Expand All @@ -327,8 +333,7 @@ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', BottleneckBN, [3, 4, 23, 3], pretrained, progress,
**kwargs)
return _resnet('resnet101', BottleneckBN, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs):
Expand All @@ -339,8 +344,7 @@ def resnet152(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs):
Expand Down
3 changes: 2 additions & 1 deletion pl_bolts/utils/self_supervised.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pl_bolts.models.self_supervised import resnets
from pl_bolts.utils.semi_supervised import Identity


def torchvision_ssl_encoder(name, pretrained=False, return_all_feature_maps=False):
from pl_bolts.models.self_supervised import resnets

pretrained_model = getattr(resnets, name)(pretrained=pretrained, return_all_feature_maps=return_all_feature_maps)

pretrained_model.fc = Identity()
Expand Down