diff --git a/CHANGELOG.md b/CHANGELOG.md
index 52a67f64af..93db4fe211 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 - Replaced `load_boston` with `load_diabetes` in the docs and tests ([#629](https://github.com/PyTorchLightning/lightning-bolts/pull/629))
 
+- Added base encoder and MLP dimension arguments to BYOL constructor ([#637](https://github.com/PyTorchLightning/lightning-bolts/pull/637))
+
+
 ### Deprecated
 
 
diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py
index 3107a1956f..514e6e2964 100644
--- a/pl_bolts/models/self_supervised/byol/byol_module.py
+++ b/pl_bolts/models/self_supervised/byol/byol_module.py
@@ -1,6 +1,6 @@
 from argparse import ArgumentParser
 from copy import deepcopy
-from typing import Any
+from typing import Any, Union
 
 import pytorch_lightning as pl
 import torch
@@ -72,6 +72,10 @@ def __init__(
         num_workers: int = 0,
         warmup_epochs: int = 10,
         max_epochs: int = 1000,
+        base_encoder: Union[str, torch.nn.Module] = 'resnet50',
+        encoder_out_dim: int = 2048,
+        projector_hidden_size: int = 4096,
+        projector_out_dim: int = 256,
         **kwargs
     ):
         """
@@ -84,11 +88,15 @@ def __init__(
             num_workers: number of workers
             warmup_epochs: num of epochs for scheduler warm up
             max_epochs: max epochs for scheduler
+            base_encoder: the base encoder module or resnet name
+            encoder_out_dim: output dimension of base_encoder
+            projector_hidden_size: hidden layer size of projector MLP
+            projector_out_dim: output size of projector MLP
         """
         super().__init__()
-        self.save_hyperparameters()
+        self.save_hyperparameters(ignore='base_encoder')
 
-        self.online_network = SiameseArm()
+        self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_size, projector_out_dim)
         self.target_network = deepcopy(self.online_network)
         self.weight_callback = BYOLMAWeightUpdate()
 
diff --git a/pl_bolts/models/self_supervised/byol/models.py b/pl_bolts/models/self_supervised/byol/models.py
index 53b90bf6ef..d7e5e87a29 100644
--- a/pl_bolts/models/self_supervised/byol/models.py
+++ b/pl_bolts/models/self_supervised/byol/models.py
@@ -23,17 +23,17 @@ def forward(self, x):
 
 
 class SiameseArm(nn.Module):
-    def __init__(self, encoder=None):
+    def __init__(self, encoder='resnet50', encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256):
         super().__init__()
 
-        if encoder is None:
-            encoder = torchvision_ssl_encoder('resnet50')
+        if isinstance(encoder, str):
+            encoder = torchvision_ssl_encoder(encoder)
         # Encoder
         self.encoder = encoder
         # Projector
-        self.projector = MLP()
+        self.projector = MLP(encoder_out_dim, projector_hidden_size, projector_out_dim)
         # Predictor
-        self.predictor = MLP(input_dim=256)
+        self.predictor = MLP(projector_out_dim, projector_hidden_size, projector_out_dim)
 
     def forward(self, x):
         y = self.encoder(x)[0]