Skip to content

Commit

Permalink
fix cpc (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
williamFalcon authored Jul 29, 2020
1 parent d794e2a commit 2fd2969
Showing 1 changed file with 3 additions and 0 deletions.
3 changes: 3 additions & 0 deletions pl_bolts/models/self_supervised/cpc/cpc_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,18 +366,21 @@ def add_model_specific_args(parent_parser):
datamodule = CIFAR10DataModule.from_argparse_args(args)
datamodule.train_transforms = CPCTrainTransformsCIFAR10()
datamodule.val_transforms = CPCEvalTransformsCIFAR10()
args.patch_size = 8

elif args.dataset == 'stl10':
datamodule = STL10DataModule.from_argparse_args(args)
datamodule.train_dataloader = datamodule.train_dataloader_mixed
datamodule.val_dataloader = datamodule.val_dataloader_mixed
datamodule.train_transforms = CPCTrainTransformsSTL10()
datamodule.val_transforms = CPCEvalTransformsSTL10()
args.patch_size = 16

elif args.dataset == 'imagenet2012':
datamodule = SSLImagenetDataModule.from_argparse_args(args)
datamodule.train_transforms = CPCTrainTransformsImageNet128()
datamodule.val_transforms = CPCEvalTransformsImageNet128()
args.patch_size = 32

model = CPCV2(**vars(args), datamodule=datamodule)
trainer = pl.Trainer.from_argparse_args(args)
Expand Down

0 comments on commit 2fd2969

Please sign in to comment.