simclr fixes #329
Conversation
Codecov Report

@@            Coverage Diff             @@
##           master     #329      +/-   ##
==========================================
- Coverage   82.00%   81.18%   -0.82%
==========================================
  Files         100      100
  Lines        5639     5715      +76
==========================================
+ Hits         4624     4640      +16
- Misses       1015     1075      +60

Flags with carried forward coverage won't be shown.
Continue to review the full report at Codecov.
@ananyahjha93 how is it going here? This seems like a very important fix...
Force-pushed from d21b54a to b6ee9da.
Hello @ananyahjha93! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2020-11-17 18:09:42 UTC
pls check the argparser issues
Force-pushed from e018bc4 to 2fb1f23.
Force-pushed from 2fb1f23 to b8595eb.
Force-pushed from 47e8fdf to c1ced0e.
dm.val_transforms = SimCLREvalDataTransform(32)
dm.test_transforms = SimCLREvalDataTransform(32)
args.num_samples = dm.num_samples
dm = CIFAR10DataModule(
Can we simplify this one and wrap it as a function that does this hard overwrite? (see the sketch below)
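A minimal sketch of what such a helper could look like, reusing the transforms already imported in this file; the name apply_simclr_transforms and its exact signature are my assumptions, not code from this PR:

```python
# Hypothetical helper (not in this PR): hard-overwrite a datamodule's
# transforms with the SimCLR train/eval transforms for a given input height.
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform,
    SimCLRTrainDataTransform,
)


def apply_simclr_transforms(dm, input_height: int):
    dm.train_transforms = SimCLRTrainDataTransform(input_height)
    dm.val_transforms = SimCLREvalDataTransform(input_height)
    dm.test_transforms = SimCLREvalDataTransform(input_height)
    return dm
```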
    num_workers=args.num_workers
)
dm.train_transforms = SimCLRFinetuneTransform( |
Well, in fact this should live in the dm as def set_finetune_transforms(self) -> None (see the sketch after this thread).
maybe @akihironitta could take it in another PR...
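For reference, a hedged sketch of what that datamodule method could look like; the method name and the SimCLRFinetuneTransform call shown here are assumptions, not code from this PR:

```python
# Hypothetical method on the datamodule (not in this PR): swap its transforms
# for the SimCLR finetune transform instead of overwriting them from the script.
from pl_bolts.models.self_supervised.simclr.transforms import SimCLRFinetuneTransform


def set_finetune_transforms(self, input_height: int) -> None:
    self.train_transforms = SimCLRFinetuneTransform(input_height)
    self.val_transforms = SimCLRFinetuneTransform(input_height)
    self.test_transforms = SimCLRFinetuneTransform(input_height)
```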
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
from pl_bolts.optimizers.lars_scheduling import LARSWrapper
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR

class SyncFunction(torch.autograd.Function):
Can we add docs explaining what it does? (see the sketch below)
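For context, SyncFunction in SimCLR-style training is usually the all-gather-with-gradient trick: it gathers projections from every process so the NT-Xent loss sees negatives from all GPUs, and routes the gradient back to the local shard. A hedged sketch of that common pattern, not necessarily the exact code in this PR:

```python
import torch
import torch.distributed as dist


class SyncFunction(torch.autograd.Function):
    """Gather a tensor from all processes while keeping it differentiable."""

    @staticmethod
    def forward(ctx, tensor):
        ctx.batch_size = tensor.shape[0]
        gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, tensor)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        grad_input = grad_output.clone()
        dist.all_reduce(grad_input, op=dist.ReduceOp.SUM)
        # return only the slice of the gradient that belongs to this rank
        rank = dist.get_rank()
        return grad_input[rank * ctx.batch_size:(rank + 1) * ctx.batch_size]
```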
if self.arch == 'resnet18':
    backbone = resnet18
elif self.arch == 'resnet50':
    backbone = resnet50
Suggested change:
- if self.arch == 'resnet18':
-     backbone = resnet18
- elif self.arch == 'resnet50':
-     backbone = resnet50
+ backbones = {
+     'resnet18': resnet18,
+     'resnet50': resnet50,
+ }
+ backbone = backbones[self.arch]
def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
    """
    assume out_1 and out_2 are normalized
Suggested change:
- assume out_1 and out_2 are normalized
+ Args:
+     assume out_1 and out_2 are normalized
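A hedged sketch of a fuller docstring along the lines of this suggestion; the wording is mine, not from the PR:

```python
def nt_xent_loss(self, out_1, out_2, temperature, eps=1e-6):
    """Compute the NT-Xent (normalized temperature-scaled cross-entropy) loss from SimCLR.

    Args:
        out_1: projections of the first augmented view, assumed to be L2-normalized
        out_2: projections of the second augmented view, assumed to be L2-normalized
        temperature: temperature used to scale the pairwise similarities
        eps: small constant added for numerical stability
    """
```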
parser.add_argument('--meta_dir', default='.', type=str, help='path to meta.bin for imagenet')

# model params
parser.add_argument("--arch", default="resnet50", type=str, help="convnet architecture")
Not for PL, but this could be taken automatically from the Model init arguments (see the sketch below)...
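One hedged illustration of what "taken automatically from the Model init arguments" could mean: deriving the argparse flags from the __init__ signature. The helper below is hypothetical and not part of pl_bolts or this PR:

```python
import inspect
from argparse import ArgumentParser


def add_init_args(parser: ArgumentParser, model_cls) -> ArgumentParser:
    # add one --flag per keyword argument of model_cls.__init__ that has a default
    for name, param in inspect.signature(model_cls.__init__).parameters.items():
        if name == "self" or param.default is inspect.Parameter.empty:
            continue
        parser.add_argument(f"--{name}", default=param.default, type=type(param.default))
    return parser
```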
dm.val_transforms = SimCLREvalDataTransform(32)
args.num_samples = dm.num_samples
if args.dataset == 'stl10':
    dm = STL10DataModule(
Can we wrap these hard-coded dataset-dependent changes in a function? (see the sketch below)
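A hedged sketch of one way to wrap those dataset-dependent branches; build_datamodule, the constructor keyword arguments, and the per-dataset input heights are assumptions, not code from this PR:

```python
from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import (
    SimCLREvalDataTransform,
    SimCLRTrainDataTransform,
)


def build_datamodule(args):
    # pick the datamodule and the SimCLR input height from the dataset name
    if args.dataset == 'cifar10':
        dm = CIFAR10DataModule(num_workers=args.num_workers)
        input_height = 32
    elif args.dataset == 'stl10':
        dm = STL10DataModule(num_workers=args.num_workers)
        input_height = 96
    else:
        raise ValueError(f"unknown dataset: {args.dataset}")

    dm.train_transforms = SimCLRTrainDataTransform(input_height)
    dm.val_transforms = SimCLREvalDataTransform(input_height)
    args.num_samples = dm.num_samples
    return dm
```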
)
# Implements Gaussian blur as described in the SimCLR paper
def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
Suggested change:
- def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
+ def __init__(self, kernel_size: int, p: float = 0.5, min: float = 0.1, max: float = 2.0):
Please use a longer variable name than p, and do not shadow built-in functions like min/max (see the sketch below).
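A hedged sketch of the rename being asked for; the names blur_prob, sigma_min and sigma_max are my picks, not code from this PR:

```python
# Implements Gaussian blur as described in the SimCLR paper
def __init__(self, kernel_size: int, blur_prob: float = 0.5,
             sigma_min: float = 0.1, sigma_max: float = 2.0):
    self.kernel_size = kernel_size
    self.blur_prob = blur_prob    # probability of applying the blur
    self.sigma_min = sigma_min    # lower bound of the sampled sigma
    self.sigma_max = sigma_max    # upper bound of the sampled sigma
```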
Just for the record, @ananyahjha93 promised to fix all the comments above in another PR 🐰
What does this PR do?
One PR to fix a bunch of SimCLR issues.
Fixes #318
Fixes #227
Fixes #320
Fixes #322
Fixes #327
Fixes #179
Fixes #309
Fixes #355
Fixes #357
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃