diff --git a/docs/source/solo/utils.rst b/docs/source/solo/utils.rst index 83eb3636b..e623bdc47 100644 --- a/docs/source/solo/utils.rst +++ b/docs/source/solo/utils.rst @@ -257,4 +257,3 @@ forward ~~~~~~~ .. automethod:: solo.utils.positional_encoding.Summer.forward :noindex: - diff --git a/main_linear.py b/main_linear.py index bb0b14549..03820b9a2 100644 --- a/main_linear.py +++ b/main_linear.py @@ -203,9 +203,9 @@ def main(cfg: DictConfig): "logger": wandb_logger if cfg.wandb.enabled else None, "callbacks": callbacks, "enable_checkpointing": False, - "strategy": DDPStrategy(find_unused_parameters=False) - if cfg.strategy == "ddp" - else cfg.strategy, + "strategy": ( + DDPStrategy(find_unused_parameters=False) if cfg.strategy == "ddp" else cfg.strategy + ), } ) trainer = Trainer(**trainer_kwargs) diff --git a/main_pretrain.py b/main_pretrain.py index e83d5591e..3e06ae40b 100644 --- a/main_pretrain.py +++ b/main_pretrain.py @@ -221,9 +221,9 @@ def main(cfg: DictConfig): "logger": wandb_logger if cfg.wandb.enabled else None, "callbacks": callbacks, "enable_checkpointing": False, - "strategy": DDPStrategy(find_unused_parameters=False) - if cfg.strategy == "ddp" - else cfg.strategy, + "strategy": ( + DDPStrategy(find_unused_parameters=False) if cfg.strategy == "ddp" else cfg.strategy + ), } ) trainer = Trainer(**trainer_kwargs) diff --git a/scripts/pretrain/cifar/all4one.yaml b/scripts/pretrain/cifar/all4one.yaml index 7db7ea761..ccd7e7cd5 100644 --- a/scripts/pretrain/cifar/all4one.yaml +++ b/scripts/pretrain/cifar/all4one.yaml @@ -28,7 +28,7 @@ data: dataset: cifar100 # change here for cifar10 train_path: "./datasets/" val_path: "./datasets/" - format: "image_folder" + format: "image_folder" num_workers: 4 optimizer: name: "lars" diff --git a/solo/methods/base.py b/solo/methods/base.py index d771b4f3d..4020c7b7a 100644 --- a/solo/methods/base.py +++ b/solo/methods/base.py @@ -389,9 +389,11 @@ def configure_optimizers(self) -> Tuple[List, List]: if idxs_no_scheduler: partial_fn = partial( static_lr, - get_lr=scheduler["scheduler"].get_lr - if isinstance(scheduler, dict) - else scheduler.get_lr, + get_lr=( + scheduler["scheduler"].get_lr + if isinstance(scheduler, dict) + else scheduler.get_lr + ), param_group_indexes=idxs_no_scheduler, lrs_to_replace=[self.lr] * len(idxs_no_scheduler), ) diff --git a/solo/utils/positional_encodings.py b/solo/utils/positional_encodings.py index e65be6a49..c72483bfb 100644 --- a/solo/utils/positional_encodings.py +++ b/solo/utils/positional_encodings.py @@ -100,9 +100,7 @@ def forward(self, tensor): sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) emb_x = get_emb(sin_inp_x).unsqueeze(1) emb_y = get_emb(sin_inp_y) - emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type( - tensor.type() - ) + emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(tensor.type()) emb[:, :, : self.channels] = emb_x emb[:, :, self.channels : 2 * self.channels] = emb_y @@ -165,9 +163,7 @@ def forward(self, tensor): emb_x = get_emb(sin_inp_x).unsqueeze(1).unsqueeze(1) emb_y = get_emb(sin_inp_y).unsqueeze(1) emb_z = get_emb(sin_inp_z) - emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type( - tensor.type() - ) + emb = torch.zeros((x, y, z, self.channels * 3), device=tensor.device).type(tensor.type()) emb[:, :, :, : self.channels] = emb_x emb[:, :, :, self.channels : 2 * self.channels] = emb_y emb[:, :, :, 2 * self.channels :] = emb_z