Skip to content

Commit ba73b2e

Browse files
authored
Implement deepcopy to DTypeCast module (#98)
1 parent 563bbf6 commit ba73b2e

File tree

1 file changed

+67
-21
lines changed

1 file changed

+67
-21
lines changed

autoparallel/cast_parametrization.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,68 @@
33
# This source code is licensed under the BSD license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
import copy
7+
import copyreg
68
from contextlib import contextmanager
7-
from typing import Any, Type
9+
from typing import Type
810

911
import torch
1012
from torch.distributed.fsdp import MixedPrecisionPolicy
1113
from torch.utils._pytree import tree_map
1214

1315

14-
def _unimplemented_deepcopy(*args: Any, **kwargs: Any):
16+
def make_getter(self, p_name, mp_policy):
17+
def getter(
18+
self_mod=self,
19+
_param_name=p_name,
20+
_dtype=mp_policy.param_dtype,
21+
):
22+
_param = self_mod._parameters[_param_name]
23+
if not active_param():
24+
return _param
25+
return torch.ops.autoparallel.dtype_cast(_param, _dtype)
26+
27+
return getter
28+
29+
30+
# taken from PyTorch's parametrize module from
31+
# https://github.com/pytorch/pytorch/blob/5d9653d90ee003173dd03f93e09fed236500ef06/torch/nn/utils/parametrize.py#L324-L351
32+
# with some improvements
33+
def default_deepcopy(self, memo):
34+
# Just emulate a standard deepcopy procedure when __deepcopy__ doesn't exist in the current class.
35+
obj = memo.get(id(self), None)
36+
if obj is not None:
37+
return obj
38+
replica = self.__new__(self.__class__)
39+
memo[id(self)] = replica
40+
replica.__dict__ = copy.deepcopy(self.__dict__, memo)
41+
42+
# Fix the parametrization getters to point to the replica instead of the original
43+
if hasattr(replica, "_name_to_dtype_cast_managed_attr_getter") and hasattr(
44+
replica, "_mp_policy"
45+
):
46+
# Recreate the getter functions to point to the replica
47+
param_properties = {}
48+
for p_name in list(replica._name_to_dtype_cast_managed_attr_getter.keys()):
49+
# Use a function factory to properly capture the loop variable
50+
# def make_getter(param_name):
51+
param_properties[p_name] = make_getter(replica, p_name, replica._mp_policy)
52+
replica._name_to_dtype_cast_managed_attr_getter = param_properties
53+
54+
# Also save all slots if they exist.
55+
slots_to_save = copyreg._slotnames(self.__class__) # type: ignore[attr-defined]
56+
for slot in slots_to_save:
57+
if hasattr(self, slot):
58+
setattr(replica, slot, copy.deepcopy(getattr(self, slot), memo))
59+
return replica
60+
61+
62+
def getstate(self):
1563
raise RuntimeError(
16-
"DTypeCast does not support deepcopy. Please use state dict for serialization.",
64+
"Serialization of parametrized modules is only "
65+
"supported through state_dict(). See:\n"
66+
"https://pytorch.org/tutorials/beginner/saving_loading_models.html"
67+
"#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
1768
)
1869

1970

@@ -103,27 +154,17 @@ def apply_dtype_cast(model, mp_policy: MixedPrecisionPolicy):
103154
params_dict = dict(mod.named_parameters(recurse=False))
104155

105156
# Create new class for this module with all parametrized parameters
106-
param_properties = {}
107-
for p_name, p in params_dict.items():
108-
109-
def getter(
110-
self_mod=mod,
111-
_param_name=p_name,
112-
_dtype=mp_policy.param_dtype,
113-
):
114-
_param = self_mod._parameters[_param_name]
115-
if not active_param():
116-
return _param
117-
return torch.ops.autoparallel.dtype_cast(_param, _dtype)
118-
119-
param_properties[p_name] = getter
120-
121157
cls = mod.__class__
122-
param_properties_key = "#".join(sorted(param_properties.keys()))
158+
param_properties_key = "#".join(sorted(params_dict.keys()))
123159
new_cls = cls_key_to_dtype_cast_cls.get((cls, param_properties_key), None)
124160
if not new_cls:
125-
namespace = {"__deepcopy__": _unimplemented_deepcopy}
126-
for p_name in param_properties:
161+
namespace = {"__getstate__": getstate}
162+
# We don't allow serialization of parametrized modules but should still allow deepcopying.
163+
# Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
164+
if not hasattr(cls, "__deepcopy__"):
165+
namespace["__deepcopy__"] = default_deepcopy # type: ignore[assignment]
166+
167+
for p_name in params_dict.keys():
127168
# NOTE: it's important to have this indirection, to make sure that:
128169
# Different instances of the same class can resolve their parameter access to instance-specific getters
129170
# (which contains unique objects used in that instance-specific parameter's unshard operation).
@@ -132,6 +173,11 @@ def getter(
132173
new_cls = type(f"DTypeCast{cls.__name__}", cls_t, namespace)
133174
cls_key_to_dtype_cast_cls[(cls, param_properties_key)] = new_cls
134175
mod.__class__ = new_cls
176+
177+
param_properties = {}
178+
for p_name in params_dict.keys():
179+
param_properties[p_name] = make_getter(mod, p_name, mp_policy)
180+
135181
mod._name_to_dtype_cast_managed_attr_getter = param_properties
136182
mod._mp_policy = mp_policy
137183

0 commit comments

Comments
 (0)