Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

fix swav to run on imagenet #348

Merged
merged 15 commits into from
Nov 10, 2020
8 changes: 4 additions & 4 deletions pl_bolts/callbacks/ssl_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data

# log metrics
train_acc = accuracy(mlp_preds, y)
pl_module.log('train_acc', train_acc, on_step=True, on_epoch=False)
pl_module.log('train_mlp_loss', mlp_loss, on_step=True, on_epoch=False)
pl_module.log('online_train_acc', train_acc, on_step=True, on_epoch=False)
pl_module.log('online_train_loss', mlp_loss, on_step=True, on_epoch=False)

def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
x, y = self.to_device(batch, pl_module.device)
Expand All @@ -119,5 +119,5 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx,

# log metrics
val_acc = accuracy(mlp_preds, y)
pl_module.log('val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True)
pl_module.log('val_mlp_loss', mlp_loss, on_step=False, on_epoch=True, sync_dist=True)
pl_module.log('online_val_acc', val_acc, on_step=False, on_epoch=True, sync_dist=True)
pl_module.log('online_val_loss', mlp_loss, on_step=False, on_epoch=True, sync_dist=True)
2 changes: 2 additions & 0 deletions pl_bolts/datamodules/imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
self.meta_dir = meta_dir
self.num_imgs_per_val_class = num_imgs_per_val_class
self.batch_size = batch_size
self.num_samples = 1281167 - self.num_imgs_per_val_class * self.num_classes

@property
def num_classes(self):
Expand Down Expand Up @@ -144,6 +145,7 @@ def train_dataloader(self):

dataset = UnlabeledImagenet(self.data_dir,
num_imgs_per_class=-1,
num_imgs_per_class_val_split=self.num_imgs_per_val_class,
meta_dir=self.meta_dir,
split='train',
transform=transforms)
Expand Down
5 changes: 3 additions & 2 deletions pl_bolts/models/self_supervised/swav/swav_finetuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,12 @@ def cli_main(): # pragma: no-cover

backbone = SwAV(
gpus=args.gpus,
nodes=1,
num_samples=args.num_samples,
batch_size=args.batch_size,
datamodule=dm,
maxpool1=args.maxpool1,
first_conv=args.first_conv,
dataset='imagenet',
dataset=args.dataset,
).load_from_checkpoint(args.ckpt_path, strict=False)

tuner = SSLFineTuner(
Expand All @@ -117,6 +117,7 @@ def cli_main(): # pragma: no-cover

trainer = pl.Trainer(
gpus=args.gpus,
num_nodes=1,
precision=16,
max_epochs=args.num_epochs,
distributed_backend='ddp',
Expand Down
58 changes: 50 additions & 8 deletions pl_bolts/models/self_supervised/swav/swav_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@

from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization, stl10_normalization
from pl_bolts.transforms.dataset_normalizations import (
stl10_normalization,
cifar10_normalization,
imagenet_normalization
)


class SwAV(pl.LightningModule):
def __init__(
self,
gpus: int,
nodes: int,
num_samples: int,
batch_size: int,
dataset: str,
Expand Down Expand Up @@ -54,8 +59,9 @@ def __init__(
):
"""
Args:
gpus: number of gpus used in training, passed to SwAV module
gpus: number of gpus per node used in training, passed to SwAV module
to manage the queue and select distributed sinkhorn
nodes: number of nodes to train on
num_samples: number of image samples used for training
batch_size: batch size per GPU in ddp
dataset: dataset being used for train/val
Expand Down Expand Up @@ -94,6 +100,7 @@ def __init__(
self.save_hyperparameters()

self.gpus = gpus
self.nodes = nodes
self.arch = arch
self.dataset = dataset
self.num_samples = num_samples
Expand Down Expand Up @@ -127,15 +134,15 @@ def __init__(
self.warmup_epochs = warmup_epochs
self.max_epochs = max_epochs

if self.gpus > 1:
if self.gpus * self.nodes > 1:
self.get_assignments = self.distributed_sinkhorn
else:
self.get_assignments = self.sinkhorn

self.model = self.init_model()

# compute iters per epoch
global_batch_size = self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
global_batch_size = self.nodes * self.gpus * self.batch_size if self.gpus > 0 else self.batch_size
self.train_iters_per_epoch = self.num_samples // global_batch_size

# define LR schedule
Expand Down Expand Up @@ -435,6 +442,7 @@ def add_model_specific_args(parent_parser):

# training params
parser.add_argument("--fast_dev_run", action='store_true')
parser.add_argument("--nodes", default=1, type=int, help="number of nodes for training")
parser.add_argument("--gpus", default=1, type=int, help="number of gpus to train on")
parser.add_argument("--num_workers", default=16, type=int, help="num of workers per GPU")
parser.add_argument("--optimizer", default="adam", type=str, help="choose between adam/sgd")
Expand Down Expand Up @@ -471,8 +479,8 @@ def add_model_specific_args(parent_parser):

def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform
from pl_bolts.models.self_supervised.swav.transforms import SwAVTrainDataTransform, SwAVEvalDataTransform
from pl_bolts.datamodules import STL10DataModule, CIFAR10DataModule, ImagenetDataModule

parser = ArgumentParser()

Expand Down Expand Up @@ -515,11 +523,44 @@ def cli_main():
args.size_crops = [32, 16]
args.nmb_crops = [2, 1]
args.gaussian_blur = False
elif args.dataset == 'imagenet':
args.maxpool1 = True
args.first_conv = True
normalization = imagenet_normalization()

args.size_crops = [224, 96]
args.min_scale_crops = [0.14, 0.05]
args.max_scale_crops = [1., 0.14]
args.gaussian_blur = True
args.jitter_strength = 1.

args.batch_size = 64
args.nodes = 8
args.gpus = 8 # per-node
args.max_epochs = 800

args.optimizer = 'sgd'
args.lars_wrapper = True
args.learning_rate = 4.8
args.final_lr = 0.0048
args.start_lr = 0.3

args.nmb_prototypes = 3000
args.online_ft = True

dm = ImagenetDataModule(
data_dir=args.data_path,
batch_size=args.batch_size,
num_workers=args.num_workers
)

args.num_samples = dm.num_samples
args.input_height = dm.size()[-1]
else:
raise NotImplementedError("other datasets have not been implemented till now")

dm.train_transforms = SwAVTrainDataTransform(
normalize=stl10_normalization(),
normalize=normalization,
size_crops=args.size_crops,
nmb_crops=args.nmb_crops,
min_scale_crops=args.min_scale_crops,
Expand All @@ -529,7 +570,7 @@ def cli_main():
)

dm.val_transforms = SwAVEvalDataTransform(
normalize=stl10_normalization(),
normalize=normalization,
size_crops=args.size_crops,
nmb_crops=args.nmb_crops,
min_scale_crops=args.min_scale_crops,
Expand All @@ -556,6 +597,7 @@ def cli_main():
max_epochs=args.max_epochs,
max_steps=None if args.max_steps == -1 else args.max_steps,
gpus=args.gpus,
num_nodes=args.nodes,
distributed_backend='ddp' if args.gpus > 1 else None,
sync_batchnorm=True if args.gpus > 1 else False,
precision=32 if args.fp32 else 16,
Expand Down
6 changes: 5 additions & 1 deletion pl_bolts/models/self_supervised/swav/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,12 @@ def __init__(
]

if self.gaussian_blur:
kernel_size = int(0.1 * self.size_crops[0])
if kernel_size % 2 == 0:
kernel_size += 1

color_transform.append(
GaussianBlur(kernel_size=int(0.1 * self.size_crops[0]), p=0.5)
GaussianBlur(kernel_size=kernel_size, p=0.5)
)

self.color_transform = transforms.Compose(color_transform)
Expand Down
1 change: 1 addition & 0 deletions tests/models/self_supervised/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def test_swav(tmpdir, datadir):
arch='resnet18',
hidden_mlp=512,
gpus=0,
nodes=1,
num_samples=datamodule.num_samples,
batch_size=batch_size,
nmb_crops=[2, 1],
Expand Down