
properly support deepcopying and serialization of model weights #7107

Merged — 10 commits, merged Jan 19, 2023

Conversation

@pmeier (Collaborator) commented Jan 18, 2023

Fixes #7099 and properly addresses #6871.

The whole idea behind the TransformsFactory class is to reimplement functools.partial, but with deepcopy / serialization behavior that better suits our needs. As the OP in #7099 describes, functools.partial objects can be deepcopied and serialized, but the result is no longer equal to the input:

import pickle
from copy import deepcopy
from functools import partial


def foo():
    pass


bar = partial(foo)
assert bar != deepcopy(bar)
assert bar != pickle.loads(pickle.dumps(bar))

However, we need this equality to work, since it is an integral part of enums. Creating a fresh enum member through either of these operations internally invokes Enum.__new__, which performs a number of equality checks against the known members of the enum. This ultimately leads to the error the users reported.
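The failure mode described above can be reproduced without torchvision at all. The sketch below (class and function names are made up for illustration; they are not torchvision's) shows that an Enum whose member value is a functools.partial does not survive a pickle round-trip, because unpickling re-enters Enum.__new__ with a fresh partial that compares unequal to every known member:

```python
import pickle
from enum import Enum
from functools import partial


def identity(x):
    return x


class FakeWeights(Enum):
    # a partial is not a descriptor, so unlike a plain function it becomes
    # an enum *member* instead of a method
    DEFAULT = partial(identity)


raised = False
try:
    pickle.loads(pickle.dumps(FakeWeights.DEFAULT))
except Exception:
    # typically ValueError ("... is not a valid FakeWeights"), raised from
    # Enum.__new__ when the freshly unpickled partial matches no member
    raised = True
assert raised
```

This is the same mechanism behind the errors reported in #7099 and #6871, where the member value is a Weights instance whose transforms field holds a partial.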

Together with @dataclasses.dataclass, the functionality that we want is trivial to implement. Re-using the example from above:

import pickle
from copy import deepcopy

from torchvision.models._api import TransformsFactory


def foo():
    pass


bar = TransformsFactory(foo)
assert bar == deepcopy(bar)
assert bar == pickle.loads(pickle.dumps(bar))
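For readers without the PR diff at hand, a minimal sketch of what such a dataclass-based factory could look like (the class name and fields here are illustrative assumptions, not the PR's actual implementation):

```python
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Tuple


@dataclass
class TransformsFactorySketch:
    func: Callable
    args: Tuple[Any, ...] = ()
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # behaves like functools.partial(self.func, *self.args, **self.kwargs)
        return self.func(*self.args, *args, **{**self.kwargs, **kwargs})


def foo():
    pass


bar = TransformsFactorySketch(foo)
# the @dataclass-generated __eq__ compares fields by value, so copies
# round-trip cleanly (the same holds for pickle when `foo` is importable)
assert bar == deepcopy(bar)
```

Because @dataclass generates a field-by-value __eq__, the copied object compares equal to the original, which is exactly the property Enum.__new__ relies on.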

Another option would be what the OP describes in #7099: implement a custom __eq__ on the Weights class that is able to handle the partial. I found this a little more brittle: since we don't enforce anything about the callable, specializing the comparison to partial might lead to other errors. Note that with the TransformsFactory implementation we don't need to touch this behavior at all; it is handled by @dataclass automatically.

ToDo:

  • expand the new architecture to all weights
  • write tests that make sure both deepcopy and serialization work and give the expected results

@NicolasHug (Member) left a comment

Thanks a lot Philip for looking into this.

The new TransformsFactory class looks like a reasonable solution. I just have a few questions below.

I found this a little more brittle, since we don't enforce anything about the callable and thus specializing the comparison to partial might lead to other errors

Are you saying that the __eq__ proposed in the other issue assumes that the transforms fields are implemented as partial, and that this assumption isn't enforced and thus may break? (I agree)

Does the metaclass solution still work with torchscript?

(3 review threads on torchvision/models/_api.py, resolved)
@pmeier (Collaborator, Author) commented Jan 18, 2023

Are you saying that the __eq__ proposed in the other issue assumes that the transforms fields are implemented as partial, and that this assumption isn't enforced and thus may break? (I agree)

The proposal looks like this:

if not isinstance(self.transforms, functools.partial) or not isinstance(other.transforms, functools.partial):
    return self.transforms == other.transforms
return all(
    getattr(self.transforms, a) == getattr(other.transforms, a)
    for a in ["func", "args", "keywords", "__dict__"]
)

So we have special treatment for functools.partial, and everything else is handled by a regular ==. So no, it doesn't assume that the attribute is a partial (I think we already have cases where it is not), but special-cases it. The reason I somewhat dislike this is that there are potentially more callable types that need special handling. Off the top of my head: what if someone used lambda: ImageClassification(crop_size=[224])?
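The lambda concern is concrete: the partial-only special case above leaves any other ad-hoc callable to plain ==, and lambdas fail both comparison and pickling. A quick sketch (the `transforms` name stands in for a field like the one quoted above):

```python
import pickle

# stand-in for something like `lambda: ImageClassification(crop_size=[224])`
transforms = lambda: None

# two structurally identical lambdas still compare unequal under plain ==
assert (lambda: None) != (lambda: None)

# ...and a lambda cannot even be pickled in the first place
raised = False
try:
    pickle.dumps(transforms)
except Exception:
    raised = True
assert raised
```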

This ties in a little with #7107 (comment). If we don't want a public helper, we are probably better off with the user's suggestion, since it is a best-effort approach for arbitrary callables.

@NicolasHug (Member) commented

Thanks for the details Philip.

Perhaps this is something @adamjstewart and @nilsleh can give their opinion on, since they're the ones who are going to be using those Weights outside of torchvision.

For some context: we're trying to address 2 different bugs with the WeightsEnum: pickle-ability (#7099) and support for deepcopy (and Lightning, apparently; #6871 (comment)).

We're thinking of 2 options:

  • We override __eq__ of Weights in torchvision: you have nothing to do in torchgeo, but we can only guarantee that the implementation of __eq__ supports torchvision's needs, i.e. right now our transforms field is implemented as a partial, so that's what we support. Should you rely on something different (e.g. a lambda or something more esoteric), it may or may not break. If it breaks, we can address it in torchvision when you let us know, or you can fix it in torchgeo by replacing __eq__, but that's of course a bit brittle.
  • We provide a public TransformsFactory helper (name TBD) which you would have to use when defining each of the transforms fields (https://github.com/pytorch/vision/pull/7107/files#diff-6aff9cc95a81d1947b2e5fed0363017759e6e54a6a4c047394a3ab61ffbc35d4R359)

Personally, I have a slight preference for overriding __eq__, because it avoids updating all the transforms fields and it keeps the user-facing API simpler. But I'm interested in what @adamjstewart and @nilsleh might think.

@adamjstewart (Contributor) commented

The former is probably sufficient for us. Our current weights have no transforms field; we just use nn.Identity. Unfortunately, different datasets use different normalizations (0–1, 0–2^8, 0–2^16, etc.), so our users will have to figure out the appropriate normalization themselves.

@NicolasHug (Member) commented

Thanks for your input @adamjstewart .

Our current weights have no transforms field

Does that mean you wouldn't need the Weights class to be public, only the WeightsEnum one?

@pmeier are you OK to go for __eq__ or do you still have some concerns?

@pmeier (Collaborator, Author) commented Jan 19, 2023

I've pushed a version with a custom Weights.__eq__. @NicolasHug PTAL. If you are ok with the design, I'm going to write and fix tests.
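As a rough illustration of the chosen direction, a custom __eq__ on a Weights-like dataclass might look like the following. This is a sketch only: the class name, fields, and comparison details are assumptions based on the proposal quoted earlier in this thread, not the code that was merged.

```python
import functools
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Callable, Dict


@dataclass
class WeightsSketch:
    url: str
    transforms: Callable
    meta: Dict[str, Any] = field(default_factory=dict)

    def __eq__(self, other):
        # Custom comparison is needed because two functools.partial objects
        # never compare equal with `==`, which breaks the Enum member lookup
        # that happens after deepcopy or unpickling.
        if not isinstance(other, WeightsSketch):
            return NotImplemented
        if self.url != other.url or self.meta != other.meta:
            return False
        if isinstance(self.transforms, functools.partial) and isinstance(
            other.transforms, functools.partial
        ):
            return (
                self.transforms.func == other.transforms.func
                and self.transforms.args == other.transforms.args
                and self.transforms.keywords == other.transforms.keywords
            )
        return self.transforms == other.transforms


def classification_transform(x, size=224):
    return x


w = WeightsSketch(
    url="https://example.com/weights.pth",
    transforms=functools.partial(classification_transform, size=256),
)
assert w == deepcopy(w)
```

Note that a user-defined __eq__ in the class body takes precedence over the one @dataclass would otherwise generate, so only the transforms comparison needs special handling.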

@NicolasHug (Member) commented

@NicolasHug PTAL. If you are ok with the design, I'm going to write and fix tests.

Yes, this looks very clean, thank you! We may also want to add a comment in __eq__ about why this is all needed. We don't need to go too deep into the details, but we should at least explain that it is needed to handle deepcopy and pickle.

@pmeier pmeier marked this pull request as ready for review January 19, 2023 12:40
@pmeier pmeier requested a review from NicolasHug January 19, 2023 12:40
@pmeier pmeier changed the title from "[PoC] properly support deepcopying and serialization of model weights" to "properly support deepcopying and serialization of model weights" Jan 19, 2023
@NicolasHug (Member) left a comment

Thanks Philip, some minor comments / nits, but LGTM anyway.

(2 review threads, resolved: torchvision/models/_api.py, test/test_extended_models.py)
@pmeier pmeier requested a review from NicolasHug January 19, 2023 13:40
@NicolasHug (Member) left a comment

Still LGTM! Thanks.

@pmeier pmeier merged commit c06d52b into pytorch:main Jan 19, 2023
@pmeier pmeier deleted the weights branch January 19, 2023 13:56
@github-actions commented

Hey @pmeier!

You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py

facebook-github-bot pushed a commit that referenced this pull request Jan 24, 2023
…hts (#7107)

Reviewed By: YosuaMichael

Differential Revision: D42706909

fbshipit-source-id: 78fe196c07687633eca4082440d53feeb5148360

Successfully merging this pull request may close these issues.

Weights enums cannot be pickled