Adding Image Dataloaders and Flax Resnet18 model #1
Conversation
Mostly LGTM! Just a few minor suggestions/questions.
Also, I assume that you didn't mean to upload the __pycache__ folder?
And I think that img_resnets.py can just be renamed to resnets.py, since a) there are no other kinds of resnets (yet?), and b) by default I think it is assumed that resnets are the convolutional kind. But that is also minor.
data/image.py (Outdated)
@@ -0,0 +1,227 @@
"""Image dataset functionality, borrowed from https://github.com/JamesAllingham/learning-invariances/blob/main/src/data/image.py."""
It isn't really necessary to put this here! As I mentioned, this is code that has been adapted by me and Javi over the course of our PhDs, so this repo isn't really the origin.
got it!
        'CIFAR10': 10_000,
        'CIFAR100': 10_000,
    },
    'mean': {
Nice addition!
}


TRAIN_TRANSFORMATIONS = {
This is also good!
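For readers without the full diff: a plausible sketch of what such a per-dataset augmentation table might contain, assuming torchvision-style transforms; the exact entries in the PR may differ.

```python
# Hypothetical sketch only -- standard CIFAR-style train augmentations.
# The real TRAIN_TRANSFORMATIONS in data/image.py may differ.
from torchvision import transforms

TRAIN_TRANSFORMATIONS = {
    'MNIST': [],  # typically left unaugmented
    'CIFAR10': [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ],
    'CIFAR100': [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ],
}
```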
data/image.py (Outdated)
if flatten_img:
    common_transforms += [Flatten()]

# Important when fitting linear model and sample-then-optimise
Is this just to say that you need augmentations for this project, or is there something more to this comment? :)
I added some more detail. Basically, we need to not use augmented train data when training the linear model and when sampling from the posterior. This was actually a bug I'd chased down: if we augment, then the sampling is misspecified and we never converge.
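A hedged usage sketch of that distinction (the loader name get_image_dataset and its return signature are assumptions for illustration; only the perform_augmentations argument appears in the diff):

```python
# Hypothetical usage -- function name and return values are illustrative only.
# Standard training: random augmentations are fine.
train_set, val_set, test_set = get_image_dataset(
    'CIFAR10', perform_augmentations=True)

# Fitting the linear model / sample-then-optimise posterior sampling:
# these assume a fixed, deterministic training set, so augmentations must be
# off, otherwise the sampler is misspecified and never converges.
train_set, val_set, test_set = get_image_dataset(
    'CIFAR10', perform_augmentations=False)
```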
if perform_augmentations:
    train_augmentations = TRAIN_TRANSFORMATIONS[dataset_name]
else:
    train_augmentations = TEST_TRANSFORMATIONS[dataset_name]
For the non-ImageNet cases this makes sense, since the test augmentations are empty, but for ImageNet does it make sense to be applying transformations when the user of the function has set perform_augmentations=False?
The reason we need Resize(256) and CenterCrop(224) for ImageNet is that, by default, the ImageNet test images come in varying sizes rather than a uniform one. So we still need some deterministic preprocessing to ensure all images are of size 224x224x3.
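Concretely, the deterministic ImageNet eval pipeline usually looks something like this (a torchvision-style sketch with the standard ImageNet statistics; the exact transform list in the PR may differ):

```python
from torchvision import transforms

# Deterministic preprocessing: raw ImageNet test images come in varying sizes,
# so resize and centre-crop them to a uniform 224x224x3 before normalising.
imagenet_eval_transforms = transforms.Compose([
    transforms.Resize(256),      # shorter side -> 256, aspect ratio preserved
    transforms.CenterCrop(224),  # deterministic centre crop, no randomness
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```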
data/image.py (Outdated)
random_seed: the `int` random seed for splitting the val data and
    applying random affine transformations. (Default: 42)
perform_augmentations: a `bool` indicating whether to apply random
    affine transformations to the training data. (Default: `True`)
This docstring isn't exactly accurate, since the augmentations are not limited to affine transformations?
Good point, affine is wrong there. Made the change.
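For reference, one possible rewording of those docstring lines (the wording actually committed may differ):

```python
random_seed: the `int` random seed for splitting the val data and
    applying random augmentations. (Default: 42)
perform_augmentations: a `bool` indicating whether to apply random
    augmentations to the training data. (Default: `True`)
```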
models/img_resnets.py (Outdated)
y = self.conv(self.filters, (3, 3), padding=((1, 1), (1, 1)))(y)
# ^ Not using Flax default padding since it doesn't match PyTorch

# For pretrained bayesian-lottery-tickets models, don't init with 0 here.
Maybe this should be a flag that can be set by the user (if it is specific to BLT models)?
Will do this when I add the torchhub model conversion fns that you had.
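A minimal sketch of what such a user-settable flag could look like on a residual block; everything here (the names, the norm layers, and the assumption that the identity shortcut already matches shapes) is illustrative rather than the PR's actual code:

```python
import flax.linen as nn

class ResNetBlock(nn.Module):
    filters: int
    # Hypothetical flag: zero-init the last norm's scale (the usual ResNet trick),
    # but allow turning it off so pretrained bayesian-lottery-tickets weights load as-is.
    zero_init_final_norm: bool = True

    @nn.compact
    def __call__(self, x):
        y = nn.Conv(self.filters, (3, 3), padding=((1, 1), (1, 1)))(x)
        y = nn.BatchNorm(use_running_average=True)(y)
        y = nn.relu(y)
        y = nn.Conv(self.filters, (3, 3), padding=((1, 1), (1, 1)))(y)
        scale_init = (nn.initializers.zeros if self.zero_init_final_norm
                      else nn.initializers.ones)
        y = nn.BatchNorm(use_running_average=True, scale_init=scale_init)(y)
        # Assumes x and y have the same shape (stride-1 block, matching channels).
        return nn.relu(x + y)
```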
models/lenets.py (Outdated)
        return x


class LeNetSmall(nn.Module):
Since this is very similar to LeNet, with only the sizes changing, I think this is a good place to use inheritance to reduce code duplication. Specifically, I'd change LeNet to also have a separate setup and __call__, and then when defining LeNetSmall, inherit from LeNet and simply override the setup.
Sounds good! I have a doubt about inheritance. Since I'm redefining setup for LeNetSmall, I'll have to redefine self.dense here, even though it is the same. Is there a way to only overwrite a subset of class properties after inheriting within Flax?
Hmmm, maybe you can call super().setup() in your inherited class's setup function and then redefine only the attributes which need to change. However, I think this is something I've tried before and not had luck with. You'll probably have to bite the bullet and accept some code duplication, I'm afraid.
To be clear, you can definitely call super().setup() and then add new attributes; however, I think an error is thrown when you try to overwrite attributes which have been defined in the parent's setup.
OK, I've confirmed that redefining attributes when calling super().setup() inside the child class's setup() doesn't work. E.g., trying this:
import flax.linen as nn

class MyNet(nn.Module):
    def setup(self):
        self.fc2 = nn.Dense(10)
        self.fc1 = nn.Dense(10)

    def __call__(self, x):
        return self.fc2(self.fc1(x))

class MyNet2(MyNet):
    def setup(self):
        super().setup()
        self.fc2 = nn.Dense(333)  # overwrite an attribute defined in the parent's setup

    def __call__(self, x):
        return self.fc2(self.fc1(x))
will result in this error when trying to call init:
ValueError: Duplicate use of scope name: "fc2"
However, if you really want to, you can avoid the code duplication like this:
class MyNet(nn.Module):
    def _partial_setup(self):
        self.fc2 = nn.Dense(10)

    def setup(self):
        self._partial_setup()
        self.fc1 = nn.Dense(10)

    def __call__(self, x):
        return self.fc2(self.fc1(x))

class MyNet2(MyNet):
    def _partial_setup(self):
        self.fc2 = nn.Dense(333)  # only the attribute that changes lives here

    def __call__(self, x):
        return self.fc2(self.fc1(x))
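For completeness, a quick sanity check of that pattern (shapes are arbitrary; it assumes flax.linen imported as nn and the two classes above):

```python
import jax
import jax.numpy as jnp

# Both modules should initialise without the "Duplicate use of scope name" error,
# because fc2 is now defined exactly once per class via _partial_setup.
x = jnp.ones((1, 10))
params_base = MyNet().init(jax.random.PRNGKey(0), x)
params_small = MyNet2().init(jax.random.PRNGKey(1), x)
```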
But I think in this case it is better to have the duplicated code, since it is only 1 line.
Thanks for trying these out, this is super useful to know!
models/lenets.py (Outdated)
        self.conv2 = conv3_block(32, stride=2)
        self.conv3 = conv3_block(32, stride=2)

    @nn.compact
This decorator should be removed if you have a setup method.
Is there a way in Flax to mix and match? For example, if I define self.conv{i} in the setup, can I compactly define an nn.Dense within the call without mentioning it in the setup? I'm curious about the syntax here.
Currently, it actually doesn't break anything to mix and match setup and compact; however, as far as I understand, this is not advertised functionality and there are plans to remove it in future: google/flax#2018.
Why do you actually want to mix and match, though? It seems to me that everything could be in setup here?
Ah, good to know! I wanted to be cheeky and define the Dense layer inside the call, but I'll avoid that since it's not the preferred behaviour.
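For reference, a minimal sketch of the all-in-setup style being suggested (layer names and sizes are placeholders, not the PR's actual LeNetSmall definition):

```python
import flax.linen as nn

class LeNetSmall(nn.Module):
    def setup(self):
        # Everything is defined in setup, so __call__ needs no @nn.compact.
        self.conv1 = nn.Conv(32, (3, 3), strides=2)
        self.conv2 = nn.Conv(32, (3, 3), strides=2)
        self.dense = nn.Dense(10)

    def __call__(self, x):
        x = nn.relu(self.conv1(x))
        x = nn.relu(self.conv2(x))
        x = x.reshape((x.shape[0], -1))  # flatten before the final Dense layer
        return self.dense(x)
```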
Adding the first combined snippets of code to this repo! This PR contains:
- jaxutils.data.image, which should contain all the necessary mean, std, size, and num_datapoints information for all datasets.
- A Flax ResNet18 model, ported from bayesian-lottery-tickets to Flax.