Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

properly support deepcopying and serialization of model weights #7107

Merged
merged 10 commits into from
Jan 19, 2023
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions test/test_extended_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import os
import pickle

import pytest
import test_models as TM
Expand Down Expand Up @@ -73,10 +74,30 @@ def test_get_model_weights(name, weight):
],
)
def test_weights_copyable(copy_fn, name):
model_weights = models.get_model_weights(name)
for weights in list(model_weights):
copied_weights = copy_fn(weights)
assert copied_weights is weights
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that (deep-)copying is an identity operation here,
# but this is the default behavior of enums.
pmeier marked this conversation as resolved.
Show resolved Hide resolved
# See https://github.com/pytorch/vision/pull/7107 for details.
assert copy_fn(weights) is weights


@pytest.mark.parametrize(
"name",
[
"resnet50",
"retinanet_resnet50_fpn_v2",
"raft_large",
"quantized_resnet50",
"lraspp_mobilenet_v3_large",
"mvit_v1_b",
],
)
def test_weights_deserializable(name):
for weights in list(models.get_model_weights(name)):
# It is somewhat surprising that deserialization is an identity operation here,
# but this is the default behavior of enums.
# See https://github.com/pytorch/vision/pull/7107 for details.
assert pickle.loads(pickle.dumps(weights)) is weights


@pytest.mark.parametrize(
Expand Down
30 changes: 27 additions & 3 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import sys
from dataclasses import dataclass, fields
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, TypeVar, Union
Expand Down Expand Up @@ -37,6 +38,32 @@ class Weights:
transforms: Callable
meta: Dict[str, Any]

def __eq__(self, other: Any) -> bool:
# We need this custom implementation for correct deep-copy and deserialization behavior.
# TL;DR: After the definition of an enum, creating a new instance, i.e. by deep-copying or deserializing it,
# involves an equality check against the defined members. Unfortunately, the `transforms` attribute is often
# defined with `functools.partial` and `fn = partial(...); assert deepcopy(fn) != fn`. Without custom handling
# for it, the check against the defined members would fail and effectively prevent the weights from being
# deep-copied or deserialized.
# See https://github.com/pytorch/vision/pull/7107 for details.
if not isinstance(other, Weights):
return NotImplemented

if self.url != other.url:
return False

if self.meta != other.meta:
return False

if isinstance(self.transforms, partial) and isinstance(other.transforms, partial):
return (
self.transforms.func == other.transforms.func
and self.transforms.args == other.transforms.args
and self.transforms.keywords == other.transforms.keywords
)
else:
return self.transforms == other.transforms


class WeightsEnum(StrEnum):
"""
Expand Down Expand Up @@ -75,9 +102,6 @@ def __getattr__(self, name):
return object.__getattribute__(self.value, name)
return super().__getattr__(name)

def __deepcopy__(self, memodict=None):
return self


def get_weight(name: str) -> WeightsEnum:
"""
Expand Down