diff --git a/pl_bolts/models/self_supervised/swav/swav_module.py b/pl_bolts/models/self_supervised/swav/swav_module.py
index b604bd3d65..78e1be350a 100644
--- a/pl_bolts/models/self_supervised/swav/swav_module.py
+++ b/pl_bolts/models/self_supervised/swav/swav_module.py
@@ -11,9 +11,12 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.optim.optimizer import Optimizer
 
-from pl_bolts.models.self_supervised.swav.swav_resnet import resnet50, resnet18
+from typing import Callable, Optional
+
+from pytorch_lightning.utilities import AMPType
+from pl_bolts.models.self_supervised.swav.swav_resnet import resnet50, resnet18
 from pl_bolts.transforms.dataset_normalizations import stl10_normalization, cifar10_normalization
 from pl_bolts.optimizers.lars_scheduling import LARSWrapper
@@ -321,15 +324,15 @@ def configure_optimizers(self):
 
     def optimizer_step(
         self,
-        epoch,
-        batch_idx,
-        optimizer,
-        optimizer_idx,
-        second_order_closure=None,
-        on_tpu=False,
-        using_native_amp=False,
-        using_lbfgs=False
-    ):
+        epoch: int,
+        batch_idx: int,
+        optimizer: Optimizer,
+        optimizer_idx: int,
+        optimizer_closure: Optional[Callable] = None,
+        on_tpu: bool = False,
+        using_native_amp: bool = False,
+        using_lbfgs: bool = False,
+    ) -> None:
         # warm-up + decay schedule placed here since LARSWrapper is not optimizer class
         # adjust LR of optim contained within LARSWrapper
         if self.lars_wrapper:
@@ -340,14 +343,18 @@ def optimizer_step(
             param_group["lr"] = self.lr_schedule[self.trainer.global_step]
 
         # log LR (LearningRateLogger callback doesn't work with LARSWrapper)
-        learning_rate = {'learning_rate': self.lr_schedule[self.trainer.global_step]}
-        self.logger.log_metrics(learning_rate, step=self.trainer.global_step)
-
-        # from lightning implementation
-        if using_native_amp:
-            self.trainer.scaler.step(optimizer)
-        else:
-            optimizer.step()
+        self.log('learning_rate', self.lr_schedule[self.trainer.global_step], on_step=True, on_epoch=False)
+
+        super().optimizer_step(
+            epoch=epoch,
+            batch_idx=batch_idx,
+            optimizer=optimizer,
+            optimizer_idx=optimizer_idx,
+            optimizer_closure=optimizer_closure,
+            on_tpu=on_tpu,
+            using_native_amp=using_native_amp,
+            using_lbfgs=using_lbfgs,
+        )
 
     def sinkhorn(self, Q, nmb_iters):
         with torch.no_grad():
diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py
index 97d626ca10..0521594ef2 100644
--- a/tests/models/self_supervised/test_models.py
+++ b/tests/models/self_supervised/test_models.py
@@ -131,5 +131,7 @@ def test_swav(tmpdir):
         gpus=0, fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, max_steps=3
     )
 
-    results = trainer.fit(model, datamodule)
-    assert results == 1
+    trainer.fit(model, datamodule)
+    loss = trainer.progress_bar_dict['loss']
+
+    assert float(loss) > 0
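The sketch below illustrates the pattern this diff moves SwAV to: instead of stepping the optimizer by hand (and branching on native AMP), the module adjusts `param_group["lr"]` per step, logs the learning rate with `self.log`, and delegates the actual update to `super().optimizer_step(...)`. This is a minimal, self-contained example, not the SwAV module itself: the class name `ManualWarmupModule`, the linear warm-up schedule, and the toy data are illustrative assumptions, and the hook signature follows the one shown in the diff (it changes in newer Lightning releases).

# Minimal sketch of the optimizer_step pattern adopted in the diff.
from typing import Callable, Optional

import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim.optimizer import Optimizer


class ManualWarmupModule(pl.LightningModule):
    def __init__(self, warmup_steps: int = 10, base_lr: float = 1e-3):
        super().__init__()
        self.layer = nn.Linear(32, 2)
        self.warmup_steps = warmup_steps
        self.base_lr = base_lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.base_lr)

    def optimizer_step(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: Optimizer,
        optimizer_idx: int,
        optimizer_closure: Optional[Callable] = None,
        on_tpu: bool = False,
        using_native_amp: bool = False,
        using_lbfgs: bool = False,
    ) -> None:
        # Adjust the LR of each param group per step, mirroring how the SwAV
        # module rewrites param_group["lr"] from its precomputed schedule.
        scale = min(1.0, (self.trainer.global_step + 1) / self.warmup_steps)
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.base_lr * scale

        # self.log replaces the manual self.logger.log_metrics call.
        self.log("learning_rate", self.base_lr * scale, on_step=True, on_epoch=False)

        # Delegate the actual parameter update (AMP, TPU, closures) to Lightning.
        super().optimizer_step(
            epoch=epoch,
            batch_idx=batch_idx,
            optimizer=optimizer,
            optimizer_idx=optimizer_idx,
            optimizer_closure=optimizer_closure,
            on_tpu=on_tpu,
            using_native_amp=using_native_amp,
            using_lbfgs=using_lbfgs,
        )


if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Toy data and a fast_dev_run Trainer, analogous to the updated test's setup.
    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))
    trainer = pl.Trainer(fast_dev_run=True)
    trainer.fit(ManualWarmupModule(), DataLoader(dataset, batch_size=16))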