SimCLR: add new trainer #1195

Merged: 10 commits merged into main on Mar 29, 2023

Conversation

adamjstewart (Collaborator) commented Mar 25, 2023

@nilsleh this is the new LightningModule template I would like to use. If this looks good to people, I'll update our older LightningModules too. Summary of differences:

  • No *args or **kwargs (they hide argument typos and preclude default values and type hints)
  • No typing.Any (disables type checking)
  • No typing.cast or # type: ignore (not necessary in latest Lightning version)
  • __init__ first (should be first thing in docs)
  • No custom functions (stick with LightningModule methods)
  • Simpler configure_optimizers (no need for a fancy dictionary; see the sketch below)
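
As a rough illustration of the last point, here is a minimal sketch of configure_optimizers without the nested dictionary. The task name, backbone stand-in, and scheduler settings are assumptions for illustration only, not the final API:

from lightning.pytorch import LightningModule
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR


class ExampleTask(LightningModule):  # hypothetical task, for illustration only
    def __init__(self, lr: float = 1e-3, weight_decay: float = 1e-2) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.model = nn.Linear(8, 2)  # stand-in for a real backbone

    def configure_optimizers(self):
        # Return optimizers and schedulers directly; no nested dictionary needed
        optimizer = AdamW(
            self.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams["weight_decay"],
        )
        scheduler = CosineAnnealingLR(optimizer, T_max=10)  # T_max is illustrative
        return [optimizer], [scheduler]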

adamjstewart added this to the 0.5.0 milestone on Mar 25, 2023
github-actions bot added the "testing" (Continuous integration testing) and "trainers" (PyTorch Lightning trainers) labels on Mar 25, 2023
github-actions bot added the "documentation" (Improvements or additions to documentation) label on Mar 25, 2023
Args:
model: Name of the timm model to use.
in_channels: Number of input channels to model.
version: Version of SimCLR, 1--2.
adamjstewart (Collaborator, Author):

This isn't really used at the moment since layers and weight_decay are also parameters, but it could be used to control other things in the future (see TODOs).
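
For illustration only, the version argument could later gate v1/v2-specific behavior along these lines, with v2-only features toggled behind an if version == 2 check in __init__. This is a hypothetical helper, not code from this PR:

def validate_version(version: int) -> None:
    """Hypothetical helper: ensure the requested SimCLR version is supported."""
    if version not in (1, 2):
        raise ValueError(f"version must be 1 or 2, got {version}")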


# TODO
# v1+: add global batch norm
# v2: add selective kernels, channel-wise attention mechanism, memory bank
adamjstewart (Collaborator, Author):

I'm not sure exactly how to make these changes, and I don't want to change the architecture too much, to ensure that our pre-trained weights can still be loaded into a vanilla model. The memory bank only adds about 1% performance, so I don't think it's worth the complexity.
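
If global batch norm is ever added, one common approximation is to synchronize batch-norm statistics across GPUs with PyTorch's built-in utility. A minimal sketch, assuming a multi-GPU DDP setup (not part of this PR):

import torch.nn as nn
import torchvision

# Any model containing BatchNorm layers; ResNet-18 is just a stand-in here
model = torchvision.models.resnet18()

# Replace every BatchNorm layer with SyncBatchNorm so statistics are computed
# over the global batch across all DDP processes
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)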

Collaborator:

I would say that when most papers use SimCLR, they use v1 without all the tricks that get the ~1% improvement. I think it would be better to keep it simple.

adamjstewart (Collaborator, Author):

The majority of the performance bump in v2 is thanks to the deeper projection head, which we have, so we should be good.

adamjstewart (Collaborator, Author):

Actually, in terms of performance bump:

  • Bigger ResNets, SK, channel-wise attention: +29%
  • Deeper projection head: +14%
  • Memory bank: +1%

So the memory bank isn't high on my priority list, but adding SK and channel-wise attention may be worth it.

# Find positive example -> batch_size // 2 away from the original example
pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)

# NT-Xent loss (aka InfoNCE loss)
adamjstewart (Collaborator, Author):

Both SimCLR and MoCo use InfoNCE loss, but there is no implementation in PyTorch. There are many libraries that implement it, but I'd rather not add yet another dependency.
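
For reference, here is a self-contained sketch of the NT-Xent/InfoNCE computation in the same style as the snippet above. It assumes the two augmented views are stacked along the batch dimension, so each example's positive sits batch_size // 2 away; the temperature default is illustrative:

import torch
import torch.nn.functional as F


def nt_xent_loss(z: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """NT-Xent (InfoNCE) loss for projections stacked as [view1; view2]."""
    # Cosine similarity between every pair of projections
    cos_sim = F.cosine_similarity(z[:, None, :], z[None, :, :], dim=-1)
    # Mask out each example's similarity with itself
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
    cos_sim.masked_fill_(self_mask, -9e15)
    # The positive example is batch_size // 2 away from the original example
    pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)
    cos_sim = cos_sim / temperature
    # Negative log-softmax of the positive pair against all other pairs
    nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    return nll.mean()

Factoring it out like this (as a function or an nn.Module) would also make it testable and reusable by other SSL trainers, as discussed further down.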

def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
"""No-op, does nothing."""
# TODO
# v2: add distillation step
adamjstewart (Collaborator, Author):

This would actually be very useful to add someday. Both using a large model to better train a small model, and self-distillation, have been found to greatly improve performance. I didn't do this because I'm not super familiar with teacher-student distillation methods.
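
Just to make the idea concrete, a generic teacher-student distillation loss looks roughly like this. This is a hedged sketch of the general technique, not the SimCLR v2 distillation step and not part of this PR:

import torch
import torch.nn.functional as F


def distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """Hypothetical helper: KL divergence between softened teacher and student outputs."""
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    # Scale by T^2 so gradient magnitudes stay comparable across temperatures
    return F.kl_div(student_log_probs, teacher_probs, reduction="batchmean") * temperature**2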

isaaccorley previously approved these changes Mar 27, 2023

isaaccorley (Collaborator) left a comment:
Overall LGTM

# Data augmentation
# https://github.com/google-research/simclr/blob/master/data_util.py
self.aug = K.AugmentationSequential(
K.RandomResizedCrop(size=(96, 96)),
Collaborator:

Should we be hardcoding this?

adamjstewart (Collaborator, Author):

It's hardcoded in BYOL (it should actually be 224, not 96; let me fix this). We can make it a parameter if you want, but at the moment I don't know whether we need it to be.

Collaborator:

I'm okay with fixing it for now, but it's only set to 224 because that's what the ImageNet experiments use. It's probably better not to restrict it to 224 in case we use higher-resolution imagery.
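
A small sketch of how the crop size could be exposed instead of hardcoded. The helper name and default are assumptions, not the API of this PR:

import kornia.augmentation as K


def build_augmentations(size: int = 224) -> K.AugmentationSequential:
    """Hypothetical helper: SimCLR augmentations for a configurable crop size."""
    return K.AugmentationSequential(
        K.RandomResizedCrop(size=(size, size)),
        data_keys=["input"],
    )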

cos_sim = F.cosine_similarity(x[:, None, :], x[None, :, :], dim=-1)

# Mask out cosine similarity to itself
self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
Collaborator:

We could make the NT-Xent loss its own nn.Module so we can test it and reuse it. Maybe in a future PR.

adamjstewart (Collaborator, Author):

I found several repos with their own InfoNCE loss implementations, but they all implement it differently, and I don't know the math well enough to decide which is best. The implementation here assumes that there is exactly one positive pair and everything else is a negative pair. A more general implementation, or a faster one, is a lot more work to get right.

Collaborator:

I think the implementation is fine. I was suggesting that we make it a separate module since other SSL methods use it as well. But until we have another SSL method that uses it, I think it's fine to leave as is.

adamjstewart (Collaborator, Author):

Well, I'm about to add MoCo, which also uses it, although their implementation is completely different and I have no idea what the difference is.

calebrob6 (Member):
Can you test this on a real dataset (maybe eurosat100?) before merging?

Optimizer and learning rate scheduler.
"""
# Original paper uses LARS optimizer, but this is not defined in PyTorch
optimizer = AdamW(
Collaborator:

Should the optimizer choice also be user-definable, since different model architectures work better with certain optimizers? Or would you expect/want a user to override the configure_optimizers method in their inherited trainer class?

adamjstewart (Collaborator, Author):

Our BYOL trainer supports specifying an optimizer, but none of the other trainers do. For now, I'm just using the optimizer from the original paper. The only difficulty with making it user-configurable is that each optimizer has different arguments. We could add a **kwargs that is passed to the optimizer to handle this, but then we can't use it anywhere else (without a bit of hacking like we did in NAIPChesapeakeDataModule).
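
One possible pattern for a user-configurable optimizer, sketched with hypothetical argument names (optimizer_class and optimizer_kwargs); this is not the approach taken in this PR:

from typing import Any, Dict, Optional, Type

import torch.nn as nn
from lightning.pytorch import LightningModule
from torch.optim import AdamW, Optimizer


class ConfigurableOptimizerTask(LightningModule):  # hypothetical, for illustration
    def __init__(
        self,
        optimizer_class: Type[Optimizer] = AdamW,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
    ) -> None:
        super().__init__()
        self.model = nn.Linear(8, 2)  # stand-in for the real backbone
        self.optimizer_class = optimizer_class
        self.optimizer_kwargs = optimizer_kwargs or {"lr": 1e-3}

    def configure_optimizers(self) -> Optimizer:
        # Forward arbitrary keyword arguments, since each optimizer differs
        return self.optimizer_class(self.parameters(), **self.optimizer_kwargs)

Note that this reintroduces exactly the typing.Any / **kwargs-style plumbing the template above tries to avoid, which is the trade-off being discussed.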

# For the middle layers, use bias and ReLU
self.model.fc = nn.Sequential(
self.model.fc,
nn.ReLU(inplace=True),
Collaborator:

Is ReLU the desired/required activation function choice here or should that be more flexible?

adamjstewart (Collaborator, Author):

It's just what the original paper used. Depends on how much customization we want to support.
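
For illustration, a projection-head builder where the activation is an argument. This is a hypothetical helper; the PR keeps the ReLU from the original paper:

from typing import Optional

import torch.nn as nn


def build_projection_head(
    in_features: int,
    hidden_features: int,
    out_features: int,
    activation: Optional[nn.Module] = None,
) -> nn.Sequential:
    """Hypothetical helper: two-layer projection head with a pluggable activation."""
    # Default to the ReLU used in the SimCLR paper, but accept any nn.Module
    activation = activation if activation is not None else nn.ReLU(inplace=True)
    return nn.Sequential(
        nn.Linear(in_features, hidden_features),
        activation,
        nn.Linear(hidden_features, out_features),
    )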

adamjstewart dismissed stale reviews from isaaccorley and ghost via 41a10d2 on March 27, 2023 at 15:24
adamjstewart (Collaborator, Author):

Can you test this on a real dataset (maybe eurosat100?) before merging?

The following script runs without crashing:

from lightning.pytorch import Trainer

from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.trainers import SimCLRTask


datamodule = EuroSAT100DataModule(
    root="data/eurosat",
    batch_size=2,
    download=True,
)

model = SimCLRTask(
    model="resnet18",
    in_channels=13,
    max_epochs=1
)

trainer = Trainer(
    accelerator="cpu",
    max_epochs=1,
)

trainer.fit(model=model, datamodule=datamodule)

I'm kind of trusting our tests to make sure things "work". Once I add all of these trainers and @isaaccorley finishes the pretrain+train pipeline, I'm planning on testing all of them on SSL4EO-S12 to make sure they actually work.

adamjstewart dismissed stale reviews from ghost via 41a10d2 on March 27, 2023 at 21:22
calebrob6 (Member):
To be a little more specific, I would expect that if you used this trainer with the default settings, you would at least observe the loss decreasing. Tests will check whether the code executes, but not whether it is doing what you would expect in an ML training sense -- I'm interested in whether this actually does self-supervised learning! Regardless, I think you'll figure that out in short order if you're running experiments.

calebrob6 merged commit 39d6941 into main on Mar 29, 2023
calebrob6 deleted the trainers/simclr branch on March 29, 2023 at 10:06
adamjstewart added a commit that referenced this pull request on Mar 29, 2023
adamjstewart removed this from the 0.5.0 milestone on Mar 29, 2023
isaaccorley pushed a commit that referenced this pull request on Mar 29, 2023
adamjstewart mentioned this pull request on Apr 16, 2023
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request on Apr 29, 2023
* SimCLR: add new trainer

* Add tests

* Support custom number of MLP layers

* Change default params, add TODOs

* Fix mypy

* Fix docs and most of tests

* Fix all tests

* Fix support for older Kornia versions

* Fix support for older Kornia versions

* Crop should be 224, not 96
yichiac pushed a commit to yichiac/torchgeo that referenced this pull request Apr 29, 2023
adamjstewart mentioned this pull request on Sep 1, 2023