33# This source code is licensed under the BSD license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ import copy
7+ import copyreg
68from contextlib import contextmanager
7- from typing import Any , Type
9+ from typing import Type
810
911import torch
1012from torch .distributed .fsdp import MixedPrecisionPolicy
1113from torch .utils ._pytree import tree_map
1214
1315
14- def _unimplemented_deepcopy (* args : Any , ** kwargs : Any ):
def make_getter(self, p_name, mp_policy):
    """Build the accessor used to read parameter ``p_name`` of module ``self``.

    The module, the parameter name and ``mp_policy.param_dtype`` are bound as
    default arguments when this factory runs, so each returned accessor stays
    tied to the module instance it was created for.

    Args:
        self: Module whose ``_parameters`` mapping holds the parameter.
        p_name: Key of the parameter inside ``self._parameters``.
        mp_policy: Policy object; only its ``param_dtype`` is read.

    Returns:
        A callable that returns the stored parameter unchanged while
        ``active_param()`` is falsy, and otherwise the result of
        ``torch.ops.autoparallel.dtype_cast`` applied to it.
    """

    def getter(
        self_mod=self,
        _param_name=p_name,
        _dtype=mp_policy.param_dtype,
    ):
        # Look the parameter up on every call so later re-assignments of
        # ``_parameters[name]`` are observed.
        stored = self_mod._parameters[_param_name]
        if active_param():
            return torch.ops.autoparallel.dtype_cast(stored, _dtype)
        return stored

    return getter
28+
29+
# Adapted from PyTorch's parametrize module:
# https://github.com/pytorch/pytorch/blob/5d9653d90ee003173dd03f93e09fed236500ef06/torch/nn/utils/parametrize.py#L324-L351
# with some improvements.
def default_deepcopy(self, memo):
    """Emulate the standard deepcopy procedure for a class lacking ``__deepcopy__``.

    Besides copying ``__dict__`` and any slots, this rebuilds the per-instance
    dtype-cast getters so that they refer to the replica rather than to the
    object being copied.

    Args:
        self: The object to copy.
        memo: The ``copy.deepcopy`` memo dictionary (``id(obj) -> copy``).

    Returns:
        The deep copy of ``self``; the memoized copy is returned if ``self``
        was already copied during this deepcopy pass.
    """
    obj = memo.get(id(self), None)
    if obj is not None:
        return obj

    replica = self.__new__(self.__class__)
    # Register the replica before recursing so that reference cycles leading
    # back to ``self`` resolve to it instead of triggering another copy.
    memo[id(self)] = replica
    replica.__dict__ = copy.deepcopy(self.__dict__, memo)

    # Fix the parametrization getters to point to the replica instead of the
    # original: each getter is recreated with the replica bound to it.
    if hasattr(replica, "_name_to_dtype_cast_managed_attr_getter") and hasattr(
        replica, "_mp_policy"
    ):
        replica._name_to_dtype_cast_managed_attr_getter = {
            p_name: make_getter(replica, p_name, replica._mp_policy)
            for p_name in replica._name_to_dtype_cast_managed_attr_getter
        }

    # Also save all slots if they exist.
    for slot in copyreg._slotnames(self.__class__):  # type: ignore[attr-defined]
        if hasattr(self, slot):
            setattr(replica, slot, copy.deepcopy(getattr(self, slot), memo))
    return replica
60+
61+
def getstate(self):
    """Reject pickling-style serialization of a parametrized module.

    Raises:
        RuntimeError: Always; the message directs callers to ``state_dict()``.
    """
    message = (
        "Serialization of parametrized modules is only "
        "supported through state_dict(). See:\n "
        "https://pytorch.org/tutorials/beginner/saving_loading_models.html"
        "#saving-loading-a-general-checkpoint-for-inference-and-or-resuming-training"
    )
    raise RuntimeError(message)
1869
1970
@@ -103,27 +154,17 @@ def apply_dtype_cast(model, mp_policy: MixedPrecisionPolicy):
103154 params_dict = dict (mod .named_parameters (recurse = False ))
104155
105156 # Create new class for this module with all parametrized parameters
106- param_properties = {}
107- for p_name , p in params_dict .items ():
108-
109- def getter (
110- self_mod = mod ,
111- _param_name = p_name ,
112- _dtype = mp_policy .param_dtype ,
113- ):
114- _param = self_mod ._parameters [_param_name ]
115- if not active_param ():
116- return _param
117- return torch .ops .autoparallel .dtype_cast (_param , _dtype )
118-
119- param_properties [p_name ] = getter
120-
121157 cls = mod .__class__
122- param_properties_key = "#" .join (sorted (param_properties .keys ()))
158+ param_properties_key = "#" .join (sorted (params_dict .keys ()))
123159 new_cls = cls_key_to_dtype_cast_cls .get ((cls , param_properties_key ), None )
124160 if not new_cls :
125- namespace = {"__deepcopy__" : _unimplemented_deepcopy }
126- for p_name in param_properties :
161+ namespace = {"__getstate__" : getstate }
162+ # We don't allow serialization of parametrized modules but should still allow deepcopying.
163+ # Default 'deepcopy' function invokes __deepcopy__ method instead of __getstate__ when it exists.
164+ if not hasattr (cls , "__deepcopy__" ):
165+ namespace ["__deepcopy__" ] = default_deepcopy # type: ignore[assignment]
166+
167+ for p_name in params_dict .keys ():
127168 # NOTE: it's important to have this indirection, to make sure that:
128169 # Different instances of the same class can resolve their parameter access to instance-specific getters
129170 # (which contains unique objects used in that instance-specific parameter's unshard operation).
@@ -132,6 +173,11 @@ def getter(
132173 new_cls = type (f"DTypeCast{ cls .__name__ } " , cls_t , namespace )
133174 cls_key_to_dtype_cast_cls [(cls , param_properties_key )] = new_cls
134175 mod .__class__ = new_cls
176+
177+ param_properties = {}
178+ for p_name in params_dict .keys ():
179+ param_properties [p_name ] = make_getter (mod , p_name , mp_policy )
180+
135181 mod ._name_to_dtype_cast_managed_attr_getter = param_properties
136182 mod ._mp_policy = mp_policy
137183
0 commit comments