Skip to content

Commit

Permalink
[fbsync] properly support deepcopying and serialization of model weig…
Browse files Browse the repository at this point in the history
…hts (#7107)

Reviewed By: YosuaMichael

Differential Revision: D42706909

fbshipit-source-id: 78fe196c07687633eca4082440d53feeb5148360
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Jan 24, 2023
1 parent 83b23fb commit 94ecbbc
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 7 deletions.
31 changes: 27 additions & 4 deletions test/test_extended_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import os
import pickle

import pytest
import test_models as TM
Expand Down Expand Up @@ -73,10 +74,32 @@ def test_get_model_weights(name, weight):
],
)
def test_weights_copyable(copy_fn, name):
model_weights = models.get_model_weights(name)
for weights in list(model_weights):
copied_weights = copy_fn(weights)
assert copied_weights is weights
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that (deep-)copying is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert copy_fn(weights) is weights


@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_deserializable(name):
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that deserialization is an identity operation here, but this is the default behavior
# of enums: https://docs.python.org/3/howto/enum.html#enum-members-aka-instances
# Checking for equality, i.e. `==`, is sufficient (and even preferable) for our use case, should we need to drop
# support for the identity operation in the future.
assert pickle.loads(pickle.dumps(weights)) is weights


@pytest.mark.parametrize(
Expand Down
30 changes: 27 additions & 3 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import sys
from dataclasses import dataclass, fields
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
Expand Down Expand Up @@ -37,6 +38,32 @@ class Weights:
transforms: Callable
meta: Dict[str, Any]

def __eq__(self, other: Any) -> bool:
# We need this custom implementation for correct deep-copy and deserialization behavior.
# TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
# involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
# defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
# for it, the check against the defined members would fail and effectively prevent the weights from being
# deep-copied or deserialized.
# See https://github.com/pytorch/vision/pull/7107 for details.
if not isinstance(other, Weights):
return NotImplemented

if self.url != other.url:
return False

if self.meta != other.meta:
return False

if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
return (
self.transforms.func == other.transforms.func
and self.transforms.args == other.transforms.args
and self.transforms.keywords == other.transforms.keywords
)
else:
return self.transforms == other.transforms


class WeightsEnum(StrEnum):
"""
Expand Down Expand Up @@ -75,9 +102,6 @@ def __getattr__(self, name):
return object.__getattribute__(self.value, name)
return super().__getattr__(name)

def __deepcopy__(self, memodict=None):
return self


def get_weight(name: str) -> WeightsEnum:
"""
Expand Down

0 comments on commit 94ecbbc

Please sign in to comment.