SimCLR: add new trainer #1195
Conversation
```python
Args:
    model: Name of the timm model to use.
    in_channels: Number of input channels to model.
    version: Version of SimCLR, 1--2.
```
This isn't really used at the moment since layers and weight_decay are also parameters, but it could be used to control other things in the future (see TODOs).
```python
# TODO
# v1+: add global batch norm
# v2: add selective kernels, channel-wise attention mechanism, memory bank
```
Not sure exactly how to make these changes, and I don't really want to change the architecture too much to ensure that our pre-trained weights can be loaded in a vanilla model. The memory bank only adds +1% performance, so I don't really think it's worth the complexity.
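For what it's worth, one possible route for the v1+ "global batch norm" TODO (an assumption on my part, not something settled in this thread) is PyTorch's built-in SyncBatchNorm, which computes batch statistics across all processes under distributed training:

```python
import torch.nn as nn
import torchvision.models as models

# Minimal sketch: convert every BatchNorm layer in a backbone to
# SyncBatchNorm so statistics are computed over the global batch.
# This only has an effect when training with DistributedDataParallel.
backbone = models.resnet18()
backbone = nn.SyncBatchNorm.convert_sync_batchnorm(backbone)
```

Lightning also exposes this as `Trainer(sync_batchnorm=True)`, which performs the same conversion automatically.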
I would say that when most papers use SimCLR, they use v1 without all the tricks that get the ~1% improvement. I think it would be better to keep it simple.
The majority of the performance bump in v2 is thanks to the deeper projection head, which we have, so we should be good.
Actually, in terms of performance bump:
- Bigger ResNets, SK, channel-wise attention: +29%
- Deeper projection head: +14%
- Memory bank: +1%
So memory bank isn't high on my priority list, but adding SK and channel-wise attention may be worth it.
```python
# Find positive example -> batch_size // 2 away from the original example
pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)

# NT-Xent loss (aka InfoNCE loss)
```
Both SimCLR and MoCo use InfoNCE loss, but there is no implementation in PyTorch. There are many libraries that implement it, but I'd rather not add yet another dependency.
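For readers following along, here is a minimal self-contained sketch of the full NT-Xent computation the snippet above is part of. It follows the masking approach shown here and assumes exactly one positive pair per example, with the two augmented views concatenated along the batch dimension; it is not the verbatim trainer code:

```python
import torch
import torch.nn.functional as F


def nt_xent_loss(z: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """NT-Xent (InfoNCE) loss for projected features.

    Args:
        z: concatenated projections of both views, shape (2 * B, D),
            where rows i and i + B are two views of the same image.
        temperature: softmax temperature.
    """
    # Pairwise cosine similarity matrix, shape (2B, 2B)
    cos_sim = F.cosine_similarity(z[:, None, :], z[None, :, :], dim=-1)

    # Mask out each example's similarity to itself
    self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
    cos_sim.masked_fill_(self_mask, -9e15)

    # The positive example sits batch_size // 2 rows away
    pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)

    cos_sim = cos_sim / temperature
    # InfoNCE: -log(exp(sim_pos) / sum_j exp(sim_j))
    nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
    return nll.mean()
```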
```python
def test_step(self, batch: Dict[str, Tensor], batch_idx: int) -> None:
    """No-op, does nothing."""
    # TODO
    # v2: add distillation step
```
This would actually be very useful to add someday. Both using a large model to better train a small model, and self-distillation, have been found to greatly improve performance. I didn't do this because I'm not super familiar with teacher-student distillation methods.
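For a rough idea of what such a step could look like, here is a generic teacher-student distillation loss in the style of Hinton et al. (temperature-scaled KL divergence between logits). This is an illustrative sketch, not the exact SimCLRv2 distillation procedure:

```python
import torch
import torch.nn.functional as F


def distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    temperature: float = 4.0,
) -> torch.Tensor:
    """Soften both distributions with a temperature, then minimize their KL divergence."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    # The T**2 factor keeps gradient magnitudes comparable across temperatures
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2
```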
Overall LGTM
torchgeo/trainers/simclr.py
```python
# Data augmentation
# https://github.com/google-research/simclr/blob/master/data_util.py
self.aug = K.AugmentationSequential(
    K.RandomResizedCrop(size=(96, 96)),
```
Should we be hardcoding this?
It's hardcoded in BYOL (should actually be 224, not 96, let me fix this). We can make it a parameter if you want, but at the moment I don't know if we need it to be.
I'm okay with fixing it for now but it's only set to 224 because that's what imagenet experiments use. It's probably better to not restrict to 224 in case we use higher res imagery.
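If we ever do want to support higher-res imagery, exposing the crop size as a constructor argument is straightforward; a sketch (the `size` parameter and class name here are hypothetical, not part of this PR):

```python
import kornia.augmentation as K


class SimCLRAugmentations:
    """Illustrative wrapper that builds the SimCLR augmentation pipeline
    for a configurable input size instead of a hardcoded 224."""

    def __init__(self, size: int = 224) -> None:
        self.aug = K.AugmentationSequential(
            K.RandomResizedCrop(size=(size, size)),
            K.RandomHorizontalFlip(),
            data_keys=["input"],
        )
```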
```python
cos_sim = F.cosine_similarity(x[:, None, :], x[None, :, :], dim=-1)

# Mask out cosine similarity to itself
self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
```
We could make the NT-Xent loss its own nn.Module so we can test it and reuse it. Maybe in a future PR.
I found several repos with their own InfoNCE loss implementation, but they all implement it differently, and I don't know the math well enough to decide which is best. The implementation here assumes that there is exactly 1 positive pair and everything else is a negative pair. A more general implementation, or a faster one, is a lot more work to get right.
I think the implementation is fine. I was suggesting that we make it a separate module since other SSL methods use it as well. But until we have another SSL method that uses it, I think it's fine to leave as is.
Well, I'm about to add MoCo, which also uses it, although their implementation is completely different, and I have no idea what the difference is.
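For contrast, the MoCo reference implementation phrases InfoNCE as a cross-entropy over explicit positive/negative logits rather than masking a full similarity matrix. A hedged paraphrase of that formulation (variable names are illustrative):

```python
import torch
import torch.nn.functional as F


def moco_infonce(
    q: torch.Tensor,        # query features, (B, D), L2-normalized
    k: torch.Tensor,        # key features from the momentum encoder, (B, D)
    queue: torch.Tensor,    # memory queue of negative keys, (D, K)
    temperature: float = 0.07,
) -> torch.Tensor:
    # Positive logits: one per query, against its matching key
    l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)
    # Negative logits: each query against every queued key
    l_neg = torch.einsum("nc,ck->nk", q, queue)
    logits = torch.cat([l_pos, l_neg], dim=1) / temperature
    # The positive is always at index 0, so this is ordinary cross-entropy
    labels = torch.zeros(q.shape[0], dtype=torch.long, device=q.device)
    return F.cross_entropy(logits, labels)
```

Despite the different shapes, both formulations reduce to the same -log(exp(pos/T) / sum exp(all/T)) objective; the main practical difference is where the negatives come from (the rest of the batch in SimCLR vs. a momentum-encoder queue in MoCo).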
Can you test this on a real dataset (maybe eurosat100?) before merging?
```python
    Optimizer and learning rate scheduler.
    """
    # Original paper uses LARS optimizer, but this is not defined in PyTorch
    optimizer = AdamW(
```
Should the optimizer choice also be user-definable, as different model architectures work better with certain optimizers? Or would you expect/want a user to override the `configure_optimizers` method in their inherited trainer class?
Our BYOL trainer supports specifying an optimizer, but none of the other trainers do. For now, I'm just using the optimizer from the original paper. The only difficulty with making it user-configurable is that each optimizer has different arguments. We could add a `**kwargs` that is passed to the optimizer to handle this, but then we can't use it anywhere else (without a bit of hacking like we did in `NAIPChesapeakeDataModule`).
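For users who need a different optimizer today, the subclassing route should work without any trainer changes; a sketch (it assumes the task stores max_epochs in self.hparams, which is an assumption about the trainer's internals):

```python
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchgeo.trainers import SimCLRTask


class SimCLRWithSGD(SimCLRTask):
    """Illustrative subclass that swaps AdamW for SGD by overriding
    configure_optimizers; everything else is inherited unchanged."""

    def configure_optimizers(self):
        optimizer = SGD(self.parameters(), lr=0.1, momentum=0.9)
        # Assumption: max_epochs was saved to self.hparams in __init__
        scheduler = CosineAnnealingLR(optimizer, T_max=self.hparams["max_epochs"])
        return [optimizer], [scheduler]
```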
```python
# For the middle layers, use bias and ReLU
self.model.fc = nn.Sequential(
    self.model.fc,
    nn.ReLU(inplace=True),
```
Is ReLU the desired/required activation function choice here or should that be more flexible?
It's just what the original paper used. Depends on how much customization we want to support.
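If we ever want that flexibility, the projection head could be factored into a small builder; a sketch with a hypothetical `activation` parameter (the PR itself hardcodes ReLU, as in the paper):

```python
import torch.nn as nn


def projection_head(
    in_dim: int,
    hidden_dim: int,
    out_dim: int,
    layers: int = 2,
    activation: type[nn.Module] = nn.ReLU,
) -> nn.Sequential:
    """Build an MLP projection head: middle layers use bias + activation,
    the final layer is a plain linear projection."""
    modules: list[nn.Module] = []
    dim = in_dim
    for _ in range(layers - 1):
        modules += [nn.Linear(dim, hidden_dim), activation()]
        dim = hidden_dim
    modules.append(nn.Linear(dim, out_dim))
    return nn.Sequential(*modules)
```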
The following script runs without crashing:

```python
from lightning.pytorch import Trainer

from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.trainers import SimCLRTask

datamodule = EuroSAT100DataModule(
    root="data/eurosat",
    batch_size=2,
    download=True,
)
model = SimCLRTask(
    model="resnet18",
    in_channels=13,
    max_epochs=1,
)
trainer = Trainer(
    accelerator="cpu",
    max_epochs=1,
)
trainer.fit(model=model, datamodule=datamodule)
```

I'm kind of trusting our tests to make sure things "work". Once I add all of these trainers and @isaaccorley finishes the pretrain+train pipeline, I'm planning on testing all of them on SSL4EO-S12 to make sure they actually work.
To be a little more specific, I would expect that if you used this trainer with the default settings, you would at least observe the loss decreasing. Tests will check whether the code executes, but not whether it is doing what you would expect in an ML training sense -- I'm interested in whether this actually does self-supervised learning! Regardless, I think you'll figure that out in short order if you're running experiments.
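One lightweight way to check that (a sketch: it reuses the model and datamodule from the script above and assumes the task logs a metric named `train_loss`; adjust the name to whatever the trainer actually logs):

```python
import pandas as pd
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import CSVLogger

logger = CSVLogger(save_dir="logs", name="simclr_sanity")
trainer = Trainer(accelerator="cpu", max_epochs=5, logger=logger)
trainer.fit(model=model, datamodule=datamodule)

# Compare the first and last logged values of the training loss
metrics = pd.read_csv(f"{logger.log_dir}/metrics.csv")
loss = metrics["train_loss"].dropna()
assert loss.iloc[-1] < loss.iloc[0], "loss did not decrease"
```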
This reverts commit 39d6941.
* SimCLR: add new trainer
* Add tests
* Support custom number of MLP layers
* Change default params, add TODOs
* Fix mypy
* Fix docs and most of tests
* Fix all tests
* Fix support for older Kornia versions
* Fix support for older Kornia versions
* Crop should be 224, not 96
@nilsleh this is the new LightningModule template I would like to use. If this looks good to people, I'll update our older LightningModules too. Summary of differences:

- No *args or **kwargs (hides argument typos, no default values, no type hints)
- No typing.Any (disables type checking)
- No typing.cast or # type: ignore (not necessary in latest Lightning version)
- __init__ documented first (should be first thing in docs)
- Simpler configure_optimizers (no need for fancy dictionary)