
Adding Image Dataloaders and Flax Resnet18 model #1

Merged · 7 commits into main from dataloading · Jun 13, 2022

Conversation

@shreyaspadhy (Owner) commented Jun 10, 2022

Adding the first combined snippets of code to this repo! This PR contains:

  1. Image dataloaders from learning-invariances, with definitions of the default transformations.
  2. A big METADATA dict in jaxutils.data.image that holds all the necessary mean, std, size, and num_datapoints information for all datasets (see the sketch after this list).
  3. A Flax model definition of ResNet18, plus a conversion script that converts PyTorch resnet18 models from bayesian-lottery-tickets to Flax.
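
For illustration, a minimal sketch of the shape such a METADATA dict might take. The key names, the Imagenet entry, and the CIFAR statistics below are assumptions for this example; the actual definition lives in jaxutils.data.image.

METADATA = {
    # Assumed key names; the exact schema is defined in jaxutils.data.image.
    'image_shape': {
        'CIFAR10': (32, 32, 3),
        'CIFAR100': (32, 32, 3),
        'Imagenet': (224, 224, 3),
    },
    'num_train': {
        'CIFAR10': 50_000,
        'CIFAR100': 50_000,
    },
    'num_test': {
        'CIFAR10': 10_000,
        'CIFAR100': 10_000,
    },
    'mean': {
        # Commonly used per-channel CIFAR-10 statistics, shown as an example.
        'CIFAR10': (0.4914, 0.4822, 0.4465),
    },
    'std': {
        'CIFAR10': (0.2470, 0.2435, 0.2616),
    },
}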

@shreyaspadhy shreyaspadhy removed the request for review from JamesAllingham June 13, 2022 11:05
@shreyaspadhy shreyaspadhy changed the title from "Adding Image Dataloaders" to "Adding Image Dataloaders and Flax Resnet18 model" Jun 13, 2022
@JamesAllingham (Collaborator) left a comment

Mostly LGTM! Just a few minor suggestions/questions.

Also, I assume you didn't mean to upload the __pycache__ folder?

And I think img_resnets.py can just be renamed to resnets.py, since a) there are no other kinds of resnets (yet?), and b) by default I think it is assumed that resnets will be the convolutional kind. But that is also minor.

data/image.py Outdated
@@ -0,0 +1,227 @@
"""Image dataset functionality, borrowed from https://github.com/JamesAllingham/learning-invariances/blob/main/src/data/image.py."""
@JamesAllingham (Collaborator):

It isn't really necessary to put this here! As I mentioned, this is code that has been adapted by me and Javi over the course of our PhDs, so this repo isn't really the origin.

@shreyaspadhy (Owner, Author):

got it!

'CIFAR10': 10_000,
'CIFAR100': 10_000,
},
'mean': {
@JamesAllingham (Collaborator):

Nice addition!

}


TRAIN_TRANSFORMATIONS = {
@JamesAllingham (Collaborator):

This is also good!

data/image.py Outdated
if flatten_img:
common_transforms += [Flatten()]

# Important when fitting linear model and sample-then-optimise
@JamesAllingham (Collaborator):

Is this just to say that you need augmentations for this project, or is there something more to this comment? :)

@shreyaspadhy (Owner, Author):

I added some more detail. Basically, we must not use augmented train data when fitting the linear model or when sampling from the posterior. This was actually a bug I'd chased down: if we augment, the sampling problem is misspecified and we never converge.

if perform_augmentations:
train_augmentations = TRAIN_TRANSFORMATIONS[dataset_name]
else:
train_augmentations = TEST_TRANSFORMATIONS[dataset_name]
@JamesAllingham (Collaborator):

For the non-ImageNet cases this makes sense, since the test augmentations are empty, but for ImageNet does it make sense to apply transformations when the user of the function has set perform_augmentations=False?

@shreyaspadhy (Owner, Author):

The reason we need Resize(256) and CenterCrop(224) for ImageNet is that, by default, ImageNet test images come in varying sizes rather than a uniform one. So we still need some deterministic preprocessing to ensure every image ends up at 224x224x3.
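
For context, a minimal sketch of the kind of deterministic test-time preprocessing being described here, using torchvision; the normalisation statistics are the commonly used ImageNet values and are an assumption, not copied from this PR.

from torchvision import transforms

# Deterministic preprocessing so every ImageNet test image ends up at 224x224x3.
imagenet_test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Assumed statistics (standard ImageNet values), not taken from this PR.
    transforms.Normalize(mean=(0.485, 0.456, 0.406),
                         std=(0.229, 0.224, 0.225)),
])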

data/image.py Outdated
random_seed: the `int` random seed for splitting the val data and
applying random affine transformations. (Default: 42)
perform_augmentations: a `bool` indicating whether to apply random
affine transformations to the training data. (Default: `True`)
@JamesAllingham (Collaborator):

This docstring isn't exactly accurate, since the augmentations are not limited to affine transformations?

@shreyaspadhy (Owner, Author):

Good point, affine is wrong there. Made the change.

y = self.conv(self.filters, (3, 3), padding=((1, 1), (1, 1)))(y)
# ^ Not using Flax default padding since it doesn't match PyTorch

# For pretrained bayesian-lottery-tickets models, don't init with 0 here.
@JamesAllingham (Collaborator):

Maybe this should be a flag that can be set by the user (if it is specific to BLT models)?

@shreyaspadhy (Owner, Author):

Will do this when I add the torchhub model conversion fns that you had.
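
For illustration, a rough sketch of the kind of user-settable flag being suggested. The excerpt does not identify which parameter the zero-init applies to, so this assumes it is the scale of the block's final normalisation layer; the block below is a simplified residual block, not the PR's ResNet18 code.

import flax.linen as nn
from flax.linen.initializers import ones, zeros


class ResidualBlock(nn.Module):
    """Simplified residual block illustrating a user-settable zero-init flag."""
    filters: int
    # Hypothetical flag: pretrained bayesian-lottery-tickets checkpoints
    # would construct the model with zero_init_last_scale=False.
    zero_init_last_scale: bool = True

    @nn.compact
    def __call__(self, x, train: bool = True):
        y = nn.Conv(self.filters, (3, 3), padding=((1, 1), (1, 1)), use_bias=False)(x)
        y = nn.BatchNorm(use_running_average=not train)(y)
        y = nn.relu(y)
        y = nn.Conv(self.filters, (3, 3), padding=((1, 1), (1, 1)), use_bias=False)(y)
        scale_init = zeros if self.zero_init_last_scale else ones
        y = nn.BatchNorm(use_running_average=not train, scale_init=scale_init)(y)
        # Identity shortcut only; projection/stride handling is omitted for brevity.
        return nn.relu(x + y)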

models/lenets.py Outdated
return x


class LeNetSmall(nn.Module):
@JamesAllingham (Collaborator):

Since this is very similar to LeNet, with only the sizes changing, I think this is a good place to use inheritance to reduce code duplication. Specifically, I'd change LeNet to also have a separate setup and __call__, and then, when defining LeNetSmall, inherit from LeNet and simply override the setup.

@shreyaspadhy (Owner, Author):

Sounds good! One question about inheritance: since I'm redefining setup for LeNetSmall, I'll have to redefine self.dense here too, even though it is the same. Is there a way to overwrite only a subset of class attributes after inheriting within Flax?

@JamesAllingham (Collaborator):

Hmmm, maybe you can call super().setup() in your inherited class's setup function and then redefine only the attributes which need to change. However, I think this is something I've tried before without luck. You'll probably have to bite the bullet and accept some code duplication, I'm afraid.

@JamesAllingham (Collaborator):

To be clear, you can definitely call super().setup() and then add new attributes; however, I think an error is thrown when you try to overwrite attributes which have been defined in the parent's setup.

@JamesAllingham (Collaborator) commented Jun 13, 2022:

OK, I've confirmed that redefining attributes when calling super().setup() inside the child class's setup() doesn't work. E.g., trying this:

import flax.linen as nn

class MyNet(nn.Module):

    def setup(self):
        self.fc2 = nn.Dense(10)
        self.fc1 = nn.Dense(10)

    def __call__(self, x):
        return self.fc2(self.fc1(x))

class MyNet2(MyNet):

    def setup(self):
        super().setup()

        # Attempt to override the parent's fc2 with a different size.
        self.fc2 = nn.Dense(333)

    def __call__(self, x):
        return self.fc2(self.fc1(x))

will result in this error when trying to call init:

ValueError: Duplicate use of scope name: "fc2"

However, if you really want to, you can avoid the code duplication like this:

class MyNet(nn.Module):

    def _partial_setup(self):
        self.fc2 = nn.Dense(10)

    def setup(self):
        self._partial_setup()
        self.fc1 = nn.Dense(10)

    def __call__(self, x):
        return self.fc2(self.fc1(x))


class MyNet2(MyNet):

    def _partial_setup(self):
        # Only the attribute that changes is redefined in the child.
        self.fc2 = nn.Dense(333)

    def __call__(self, x):
        return self.fc2(self.fc1(x))

But I think in this case it is better to have the duplicated code, since it is only 1 line.

@shreyaspadhy (Owner, Author):

Thanks for trying these out, this is super useful to know!

models/lenets.py Outdated
self.conv2 = conv3_block(32, stride=2)
self.conv3 = conv3_block(32, stride=2)

@nn.compact
@JamesAllingham (Collaborator):

This decorator should be removed if you have a setup method.

@shreyaspadhy (Owner, Author):

Is there a way in Flax to mix and match? For example, if I define self.conv{i} in the setup, can I still compactly define an nn.Dense inside __call__ without mentioning it in the setup? I'm curious about the syntax here.

@JamesAllingham (Collaborator):

Currently, it actually doesn't break anything to mix and match setup and compact; however, as far as I understand, this is not advertised functionality and there are plans to remove it in future: google/flax#2018.

Why do you actually want to mix and match, though? It seems to me that everything could be in setup here.
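
For reference, a minimal sketch of keeping everything in setup rather than mixing in @nn.compact; the module name, layer sizes, and shapes here are illustrative assumptions, not taken from models/lenets.py.

import flax.linen as nn


class SmallConvNet(nn.Module):
    """Illustrative module with every submodule declared in setup."""

    def setup(self):
        # No @nn.compact on __call__, since all layers are declared here.
        self.conv1 = nn.Conv(32, (3, 3), strides=(2, 2))
        self.conv2 = nn.Conv(32, (3, 3), strides=(2, 2))
        self.conv3 = nn.Conv(32, (3, 3), strides=(2, 2))
        self.dense = nn.Dense(10)

    def __call__(self, x):
        for conv in (self.conv1, self.conv2, self.conv3):
            x = nn.relu(conv(x))
        x = x.reshape((x.shape[0], -1))  # flatten before the classifier head
        return self.dense(x)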

@shreyaspadhy (Owner, Author):

Ah, good to know! I wanted to be cheeky and define the Dense layer inside __call__, but I'll avoid that since it's not the preferred behaviour.

@shreyaspadhy shreyaspadhy merged commit 23fdd4e into main Jun 13, 2022
@shreyaspadhy shreyaspadhy deleted the dataloading branch October 10, 2022 19:23