Skip to content

Commit

Permalink
Added encoder argument to BYOL constructor (#637)
Browse files Browse the repository at this point in the history
* Added base_encoder argument to BYOL constructor

* Added encoder and MLP dimension args to BYOL

* Update changelog

* Updated docstring

* Update CHANGELOG.md

Co-authored-by: O'Donnell, Garry (DLSLtd,RAL,LSCI) <garry.o'donnell@diamond.ac.uk>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
4 people authored Jun 16, 2021
1 parent 6b7358c commit 3636142
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 8 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Replaced `load_boston` with `load_diabetes` in the docs and tests ([#629](https://github.com/PyTorchLightning/lightning-bolts/pull/629))


- Added base encoder and MLP dimension arguments to BYOL constructor ([#637](https://github.com/PyTorchLightning/lightning-bolts/pull/637))


### Deprecated


Expand Down
14 changes: 11 additions & 3 deletions pl_bolts/models/self_supervised/byol/byol_module.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from argparse import ArgumentParser
from copy import deepcopy
from typing import Any
from typing import Any, Union

import pytorch_lightning as pl
import torch
Expand Down Expand Up @@ -72,6 +72,10 @@ def __init__(
num_workers: int = 0,
warmup_epochs: int = 10,
max_epochs: int = 1000,
base_encoder: Union[str, torch.nn.Module] = 'resnet50',
encoder_out_dim: int = 2048,
projector_hidden_size: int = 4096,
projector_out_dim: int = 256,
**kwargs
):
"""
Expand All @@ -84,11 +88,15 @@ def __init__(
num_workers: number of workers
warmup_epochs: num of epochs for scheduler warm up
max_epochs: max epochs for scheduler
base_encoder: the base encoder module or resnet name
encoder_out_dim: output dimension of base_encoder
projector_hidden_size: hidden layer size of projector MLP
projector_out_dim: output size of projector MLP
"""
super().__init__()
self.save_hyperparameters()
self.save_hyperparameters(ignore='base_encoder')

self.online_network = SiameseArm()
self.online_network = SiameseArm(base_encoder, encoder_out_dim, projector_hidden_size, projector_out_dim)
self.target_network = deepcopy(self.online_network)
self.weight_callback = BYOLMAWeightUpdate()

Expand Down
10 changes: 5 additions & 5 deletions pl_bolts/models/self_supervised/byol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,17 @@ def forward(self, x):

class SiameseArm(nn.Module):

def __init__(self, encoder=None):
def __init__(self, encoder='resnet50', encoder_out_dim=2048, projector_hidden_size=4096, projector_out_dim=256):
super().__init__()

if encoder is None:
encoder = torchvision_ssl_encoder('resnet50')
if isinstance(encoder, str):
encoder = torchvision_ssl_encoder(encoder)
# Encoder
self.encoder = encoder
# Projector
self.projector = MLP()
self.projector = MLP(encoder_out_dim, projector_hidden_size, projector_out_dim)
# Predictor
self.predictor = MLP(input_dim=256)
self.predictor = MLP(projector_out_dim, projector_hidden_size, projector_out_dim)

def forward(self, x):
y = self.encoder(x)[0]
Expand Down

0 comments on commit 3636142

Please sign in to comment.