simclr fixes #329

Merged: 46 commits, Nov 17, 2020

Commits
91e3034  initial (ananyahjha93, Nov 2, 2020)
1fe59ba  updated transform (ananyahjha93, Nov 12, 2020)
e0053be  changes (ananyahjha93, Nov 12, 2020)
0f178de  updates simclr (ananyahjha93, Nov 13, 2020)
dab4370  added to docs (ananyahjha93, Nov 13, 2020)
71504a2  fix imports (ananyahjha93, Nov 13, 2020)
5c81197  updated transforms (ananyahjha93, Nov 13, 2020)
a21ddb7  encoder fix (ananyahjha93, Nov 13, 2020)
20fdce1  encoder fix (ananyahjha93, Nov 13, 2020)
254b236  . (ananyahjha93, Nov 13, 2020)
673a56e  . (ananyahjha93, Nov 13, 2020)
f3e5769  resnet update (ananyahjha93, Nov 13, 2020)
14877fa  resnet (ananyahjha93, Nov 13, 2020)
9402ae4  . (ananyahjha93, Nov 13, 2020)
af04225  fix transforms (ananyahjha93, Nov 13, 2020)
8e3d246  fix (ananyahjha93, Nov 13, 2020)
9268416  fix (ananyahjha93, Nov 13, 2020)
bc38088  fix (ananyahjha93, Nov 13, 2020)
fc49882  fix (ananyahjha93, Nov 13, 2020)
0b1b746  sync (ananyahjha93, Nov 15, 2020)
c3bc62a  sync (ananyahjha93, Nov 15, 2020)
f599015  fix negative sync (ananyahjha93, Nov 15, 2020)
03fb75c  transform fix (ananyahjha93, Nov 15, 2020)
89f0c49  fix transforms (ananyahjha93, Nov 15, 2020)
e8aa9d1  transforms + finetuner (ananyahjha93, Nov 15, 2020)
59b036b  transforms + finetuner (ananyahjha93, Nov 15, 2020)
73d93d8  run (ananyahjha93, Nov 16, 2020)
7315f0a  fix tests (ananyahjha93, Nov 16, 2020)
58d98ec  fix val_split issue for cifar (ananyahjha93, Nov 16, 2020)
3c454b9  fix val_split issue (ananyahjha93, Nov 16, 2020)
e72ce70  fix tests (ananyahjha93, Nov 16, 2020)
c77a6cd  fixes (ananyahjha93, Nov 16, 2020)
3486b2a  fix tests (ananyahjha93, Nov 16, 2020)
0ac1ef8  fix tests (ananyahjha93, Nov 16, 2020)
77626d7  fix tests (ananyahjha93, Nov 16, 2020)
b080ca0  fix tests (ananyahjha93, Nov 16, 2020)
b1cb38a  updated swav stl10 path (ananyahjha93, Nov 16, 2020)
b8595eb  updated nt_xent_loss (ananyahjha93, Nov 16, 2020)
dfe8e7d  grad version of all_reduce (ananyahjha93, Nov 16, 2020)
3e2b8f3  grad version of all_reduce (ananyahjha93, Nov 16, 2020)
f7bc230  pep8 (ananyahjha93, Nov 16, 2020)
c1ced0e  syncfunc with all_redice (ananyahjha93, Nov 17, 2020)
786f570  fix pep8 (ananyahjha93, Nov 17, 2020)
07bdb1e  sum at sync (ananyahjha93, Nov 17, 2020)
77de560  change grad sync to sum op (ananyahjha93, Nov 17, 2020)
634a050  doc fixes (ananyahjha93, Nov 17, 2020)
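Several of the later commits (fix negative sync, grad version of all_reduce, syncfunc with all_redice, sum at sync) concern gathering negatives across GPUs without detaching them from autograd. A minimal sketch of such a gradient-preserving all-gather, assuming an initialized torch.distributed process group (illustrative, not necessarily the exact merged code)::

    import torch
    import torch.distributed as dist


    class SyncFunction(torch.autograd.Function):
        """All-gather that keeps gradients flowing across processes."""

        @staticmethod
        def forward(ctx, tensor):
            ctx.batch_size = tensor.shape[0]
            gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
            dist.all_gather(gathered, tensor)
            return torch.cat(gathered, dim=0)

        @staticmethod
        def backward(ctx, grad_output):
            # all_gather alone drops gradients; sum-reduce them across ranks,
            # then return only the slice that belongs to this rank
            grad_input = grad_output.clone()
            dist.all_reduce(grad_input, op=dist.ReduceOp.SUM)
            rank = dist.get_rank()
            return grad_input[rank * ctx.batch_size:(rank + 1) * ctx.batch_size]

With NT-Xent, the gathered tensor supplies the negatives, so the loss sees all views from every GPU instead of only the local batch.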
33 changes: 21 additions & 12 deletions docs/source/self_supervised_models.rst
@@ -274,6 +274,7 @@ Model implemented by:

- `William Falcon <https://github.com/williamFalcon>`_
- `Tullie Murrell <https://github.com/tullie>`_
- `Ananya Harsh Jha <https://github.com/ananyahjha93>`_

To Train::

@@ -289,7 +290,7 @@ To Train::
dm.val_transforms = SimCLREvalDataTransform(32)

# model
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size)
model = SimCLR(num_samples=dm.num_samples, batch_size=dm.batch_size, dataset='cifar10')

# fit
trainer = pl.Trainer()
@@ -310,21 +311,29 @@ CIFAR-10 baseline
- Hardware
- LR
* - `Original <https://github.com/google-research/simclr#finetuning-the-linear-head-linear-eval>`_
- `92.00? <https://github.com/google-research/simclr#finetuning-the-linear-head-linear-eval>`_
- `~94.00 <https://github.com/google-research/simclr#finetuning-the-linear-head-linear-eval>`_
- resnet50
- LARS
- 512
- 1000
- 1 V100 (32GB)
- 1.0
- 2048
- 800
- TPUs
- 1.0/1.5
* - Ours
- `85.68 <https://tensorboard.dev/experiment/GlS1eLXMQsqh3T5DAec6UQ/#scalars>`_
- `resnet50 <https://github.com/PyTorchLightning/PyTorch-Lightning-Bolts/blob/master/pl_bolts/models/self_supervised/resnets.py#L301-L309>`_
- `LARS <https://pytorch-lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
- 512
- 960 (12 hr)
- 1 V100 (32GB)
- 1e-6
- `LARS-SGD <https://pytorch-lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
- 2048
- 800 (~4 hours)
- 8 V100 (16GB)
- 1.5
* - Ours
- `85.68 <https://tensorboard.dev/experiment/GlS1eLXMQsqh3T5DAec6UQ/#scalars>`_
- `resnet50 <https://github.com/PyTorchLightning/PyTorch-Lightning-Bolts/blob/master/pl_bolts/models/self_supervised/resnets.py#L301-L309>`_
- `LARS-Adam <https://pytorch-lightning-bolts.readthedocs.io/en/latest/api/pl_bolts.optimizers.lars_scheduling.html#pl_bolts.optimizers.lars_scheduling.LARSWrapper>`_
- 2048
- 800 (~4 hours)
- 8 V100 (16GB)
- 1e-3

|
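The LARS-SGD and LARS-Adam rows refer to bolts' LARSWrapper, which layers LARS trust-ratio scaling on top of a base optimizer. A rough usage sketch (the model and hyperparameters here are placeholders, not the exact values behind the table)::

    import torch
    from pl_bolts.optimizers.lars_scheduling import LARSWrapper

    model = torch.nn.Linear(2048, 10)  # placeholder network
    base = torch.optim.SGD(model.parameters(), lr=1.5, momentum=0.9)  # or Adam with lr=1e-3
    optimizer = LARSWrapper(base)  # per-layer adaptive rate scaling around the base step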

@@ -422,7 +431,7 @@ To Train::
model = SwAV(
gpus=1,
num_samples=dm.num_unlabeled_samples,
datamodule=dm,
dataset='stl10',
batch_size=batch_size
)

3 changes: 2 additions & 1 deletion pl_bolts/callbacks/ssl_online.py
@@ -26,14 +26,15 @@ class SSLOnlineEvaluator(Callback): # pragma: no-cover
"""
def __init__(
self,
dataset: str,
drop_p: float = 0.2,
hidden_dim: Optional[int] = None,
z_dim: int = None,
num_classes: int = None,
dataset: str = 'stl10'
):
"""
Args:
dataset: if stl10, need to get the labeled batch
drop_p: Dropout probability
hidden_dim: Hidden dimension for the fine-tune MLP
z_dim: Representation dimension
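Making dataset a required first argument lets the callback special-case STL-10, whose mixed train loader yields an (unlabeled, labeled) pair of batches, with the extra view in each batch reserved for the eval head. A sketch of that handling (illustrative, not the verbatim merged code)::

    def to_device(self, batch, device):
        if self.dataset == 'stl10':
            _, batch = batch  # keep only the labeled half of the STL-10 batch
        inputs, y = batch
        x = inputs[-1]  # the last view is produced by the online-eval transform
        return x.to(device), y.to(device)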
11 changes: 2 additions & 9 deletions pl_bolts/models/self_supervised/byol/byol_module.py
@@ -105,7 +105,7 @@ def cosine_similarity(self, a, b):
return sim

def shared_step(self, batch, batch_idx):
(img_1, img_2), y = batch
(img_1, img_2, _), y = batch

# Image 1 to image 2 loss
y1, z1, h1 = self.online_network(img_1)
@@ -220,15 +220,8 @@ def cli_main():

model = BYOL(**args.__dict__)

def to_device(batch, device):
(x1, x2), y = batch
x1 = x1.to(device)
y = y.to(device)
return x1, y

# finetune in real-time
online_eval = SSLOnlineEvaluator(z_dim=2048, num_classes=dm.num_classes)
online_eval.to_device = to_device
online_eval = SSLOnlineEvaluator(dataset=args.dataset, z_dim=2048, num_classes=dm.num_classes)

trainer = pl.Trainer.from_argparse_args(args, max_steps=300000, callbacks=[online_eval])

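The new unpacking, (img_1, img_2, _), y = batch, matches the updated SimCLR transforms: each sample now yields two augmented views plus a third, lightly transformed view reserved for the online evaluator. Schematically (a hypothetical stand-in for bolts' SimCLRTrainDataTransform)::

    class ThreeViewTransform:
        def __init__(self, augment, online_transform):
            self.augment = augment                    # heavy SimCLR augmentation
            self.online_transform = online_transform  # light transform for the eval head

        def __call__(self, sample):
            xi = self.augment(sample)
            xj = self.augment(sample)
            return xi, xj, self.online_transform(sample)

BYOL's loss uses only the first two views, hence the discarded third element.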
4 changes: 0 additions & 4 deletions pl_bolts/models/self_supervised/byol/models.py
@@ -27,17 +27,13 @@ def __init__(self, encoder=None):
encoder = torchvision_ssl_encoder('resnet50')
# Encoder
self.encoder = encoder
# Pooler
self.pooler = nn.AdaptiveAvgPool2d((1, 1))
# Projector
self.projector = MLP()
# Predictor
self.predictor = MLP(input_dim=256)

def forward(self, x):
y = self.encoder(x)[0]
y = self.pooler(y)
y = y.view(y.size(0), -1)
z = self.projector(y)
h = self.predictor(z)
return y, z, h
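The pooler disappears because the updated ResNet forward (see resnets.py below) now average-pools and flattens internally, so the encoder already emits a flat [batch, 2048] vector. A quick shape check under that assumption::

    import torch
    from pl_bolts.models.self_supervised.byol.models import SiameseArm

    arm = SiameseArm()
    y, z, h = arm(torch.randn(2, 3, 96, 96))
    print(y.shape, z.shape, h.shape)  # expected: [2, 2048], [2, 256], [2, 256]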
99 changes: 36 additions & 63 deletions pl_bolts/models/self_supervised/resnets.py
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn

from pl_bolts.utils.warnings import warn_missing_pkg
@@ -12,7 +13,6 @@
'resnet18',
'resnet34',
'resnet50',
'resnet50_bn',
'resnet101',
'resnet152',
'resnext50_32x4d',
@@ -25,7 +25,7 @@
MODEL_URLS = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50_bn': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
@@ -92,49 +92,6 @@ class Bottleneck(nn.Module):
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
# if norm_layer is None:
# norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
# self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
# self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
# self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride

def forward(self, x):
identity = x

out = self.conv1(x)
# out = self.bn1(out)
out = self.relu(out)

out = self.conv2(out)
# out = self.bn2(out)
out = self.relu(out)

out = self.conv3(out)
# out = self.bn3(out)

if self.downsample is not None:
identity = self.downsample(x)

out += identity
out = self.relu(out)

return out


class BottleneckBN(nn.Module):
expansion = 4

def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BottleneckBN, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
@@ -174,9 +131,20 @@ def forward(self, x):

class ResNet(nn.Module):

def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None, return_all_feature_maps=False):
def __init__(
self,
block,
layers,
num_classes=1000,
zero_init_residual=False,
groups=1,
width_per_group=64,
replace_stride_with_dilation=None,
norm_layer=None,
return_all_feature_maps=False,
first_conv=True,
maxpool1=True,
):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
@@ -194,11 +162,24 @@ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)

if first_conv:
self.conv1 = nn.Conv2d(
3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
)
else:
self.conv1 = nn.Conv2d(
3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False
)

self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

if maxpool1:
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
else:
self.maxpool = nn.MaxPool2d(kernel_size=1, stride=1)

self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
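The two new flags adapt the stem to small images: first_conv=False swaps the 7x7 stride-2 convolution for a 3x3 stride-1 one, and maxpool1=False replaces the stem max-pool with a 1x1 no-op, the usual CIFAR-10 recipe. For example (a sketch using this module's constructors)::

    from pl_bolts.models.self_supervised.resnets import resnet50

    imagenet_encoder = resnet50()  # default stem: 7x7/stride-2 conv + 3x3/stride-2 max-pool
    cifar_encoder = resnet50(first_conv=False, maxpool1=False)  # keeps 32x32 resolution through the stem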
@@ -269,6 +250,9 @@ def forward(self, x):
x0 = self.layer3(x0)
x0 = self.layer4(x0)

x0 = self.avgpool(x0)
x0 = torch.flatten(x0, 1)

return [x0]
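With the pooling and flatten folded into forward, the default return value is a list holding one flat feature tensor; e.g. (assuming the CIFAR-style stem above)::

    import torch
    from pl_bolts.models.self_supervised.resnets import resnet50

    encoder = resnet50(first_conv=False, maxpool1=False)
    feats = encoder(torch.randn(2, 3, 32, 32))[0]  # forward returns [x0]
    assert feats.shape == (2, 2048)  # Bottleneck expansion: 512 * 4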


@@ -314,17 +298,6 @@ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs):
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50_bn(pretrained: bool = False, progress: bool = True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

Args:
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50_bn', BottleneckBN, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained: bool = False, progress: bool = True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
@@ -333,7 +306,7 @@ def resnet101(pretrained: bool = False, progress: bool = True, **kwargs):
pretrained: If True, returns a model pre-trained on ImageNet
progress: If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', BottleneckBN, [3, 4, 23, 3], pretrained, progress, **kwargs)
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained: bool = False, progress: bool = True, **kwargs):