diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index 95175b7bb6..b89f178072 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -22,9 +22,9 @@ def test_cpcv2(tmpdir, datadir): datamodule.train_transforms = CPCTrainTransformsCIFAR10() datamodule.val_transforms = CPCEvalTransformsCIFAR10() - model = CPCV2(encoder='resnet18', data_dir=datadir, batch_size=2, online_ft=True, datamodule=datamodule) + model = CPCV2(encoder='resnet18', online_ft=True, num_classes=datamodule.num_classes) trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir) - trainer.fit(model) + trainer.fit(model, datamodule) loss = trainer.progress_bar_dict['val_nce'] assert float(loss) > 0