diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 8c5460212c..0e883711fd 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -151,7 +151,7 @@ def configure_optimizers(self): def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) parser.add_argument('--online_ft', action='store_true', help='run online finetuner') - parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, imagenet2012, stl10') + parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet2012', 'stl10']) (args, _) = parser.parse_known_args() diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index 83955acfcb..832866b1cc 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -320,7 +320,7 @@ def add_model_specific_args(parent_parser): parser.add_argument('--momentum', type=float, default=0.9) parser.add_argument('--weight_decay', type=float, default=1e-4) parser.add_argument('--data_dir', type=str, default='./') - parser.add_argument('--dataset', type=str, default='cifar10', help='cifar10, stl10, imagenet2012') + parser.add_argument('--dataset', type=str, default='cifar10', choices=['cifar10', 'imagenet2012', 'stl10']) parser.add_argument('--batch_size', type=int, default=256) parser.add_argument('--use_mlp', action='store_true') parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet') diff --git a/tests/models/self_supervised/test_ssl_scripts.py b/tests/models/self_supervised/test_ssl_scripts.py index 94796669b2..35641d9947 100644 --- a/tests/models/self_supervised/test_ssl_scripts.py +++ b/tests/models/self_supervised/test_ssl_scripts.py @@ -4,7 +4,7 @@ from tests import _MARK_REQUIRE_GPU, DATASETS_PATH -_DEFAULT_ARGS = f"--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 4 --batch_size 8 --num_workers 0" +_DEFAULT_ARGS = f"--data_dir {DATASETS_PATH} --max_epochs 1 --max_steps 2 --batch_size 8 --num_workers 0" # todo: failing for GPU as some is on CPU other on GPU