-
Notifications
You must be signed in to change notification settings - Fork 7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
properly support deepcopying and serialization of model weights #7107
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks a lot Philip for looking into this,
The addition new TransformsFactory
class looks like a reasonable solution. I just have a few questions below
I found this a little more brittle, since we don't enforce anything about the callable and thus specializing the comparison to partial might lead to other errors
Are you saying that the __eq__
proposed in the other issue assumes that the transforms
fields are implemented as partial
, and that this assumption isn't enforced and thus may break? (I agree)
Does the metaclass solution still work with torchscript?
The proposal looks like this: if not isinstance(self.transforms, functools.partial) or not isinstance(other.transforms, functools.partial):
return self.transforms == other.transforms
return all(
getattr(self.transforms, a) == getattr(other.transforms, a)
for a in ["func", "args", "keywords", "__dict__"]
) So we have a special treatment for This ties a little bit into #7107 (comment). If we don't want a public helper, we are probably better off with the user suggestion since it is a best effort approach for arbitrary callables. |
Thanks for the details Philip. Perhaps this is something @adamjstewart and @nilsleh can give their opinion on, since they're the ones who are going to be using those Weight outside of torchvision. For some context: we're trying to address 2 different bugs with the We're thinking of 2 options:
Personally, I have a slight pref for overriding the |
Former is probably sufficient for us. Our current weights have no |
Thanks for your input @adamjstewart .
Does that mean you wouldn't need the @pmeier are you OK to go for |
I've pushed a version with a custom |
Yes this looks very clean, thank you! We may also want to add a comment in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Philip, some minor comments / nits but LGTM anyway
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still LGTM! thanks
Hey @pmeier! You merged this PR, but no labels were added. The list of valid labels is available at https://github.com/pytorch/vision/blob/main/.github/process_commit.py |
…hts (#7107) Reviewed By: YosuaMichael Differential Revision: D42706909 fbshipit-source-id: 78fe196c07687633eca4082440d53feeb5148360
Fixes #7099 and properly addresses #6871.
The whole idea behind the
TransformsFactory
class is to reimplementfunctools.partial
, but with a deepcopy / serialize behavior that better suits our needs. As OP in in #7099 describes,functools.partial
can be deepcopied and serialized, but the result is no longer equal to input:However, we need this equality to work, since it is an integral part of enums. Creating a fresh enum through either of these ops, internally invokes
Enum.__new__
. In there, a bunch of equality checks are performed against the known members of the enum, i.e. here and here. This ultimately leads us to the error the users reported.Together with
@dataclasses.dataclass
, the functionality that we want is trivial to implement. Re-using the example from above:Another option that we could do is what OP describes in #7099. Basically we need to implement a custom
__eq__
on theWeights
class that is able to handle thepartial
. I found this a little more brittle, since we don't enforce anything about the callable and thus specializing the comparison topartial
might lead to other errors. Note that in theTransformsFactory
implementation, we don't need to touch this behavior at all and it is handled by@dataclass
automatically.ToDo: