@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
 from functools import wraps
 from typing import Any, Callable, Optional
 
 import torch
-from compressed_tensors.utils.helpers import getattr_chain
 
 
 try:
-    from accelerate.hooks import AlignDevicesHook
+    from accelerate.hooks import (
+        AlignDevicesHook,
+        add_hook_to_module,
+        remove_hook_from_module,
+    )
     from accelerate.utils import (
         OffloadedWeightsLoader,
         PrefixedDataset,
@@ -42,6 +46,8 @@
     "update_offload_data",
     "delete_offload_parameter",
     "has_offloaded_params",
+    "disable_hf_hook",
+    "align_module_device",
 ]
 
 
@@ -167,6 +173,7 @@ def update_offload_data(
     :param data: tensor to update parameter with
     """
     param = getattr(module, name)
+    data = data.to(param.dtype)
 
     # copy data into onloaded parameter if applicable
     if param.device != "meta":
@@ -178,23 +185,34 @@ def update_offload_data(
 
     # for upstreaming, better to add write capabilities to weight map classes first
     if isinstance(weights_map, PrefixedDataset):
-        dataset = getattr_chain(module, "module._hf_hook.weights_map.dataset", None)
+        dataset = getattr(weights_map, "dataset", None)
         if dataset is not None:
             prefix = module._hf_hook.weights_map.prefix
             key = f"{prefix}{name}"
 
             offload_device = (
                 dataset[key].device
                 if key in dataset
-                else next(dataset.values()).device
+                else next(iter(dataset.values())).device
             )
-            dataset[key] = param.data.to(device=offload_device)
+            dataset[key] = data.to(device=offload_device)
+
+    elif isinstance(weights_map, dict):
+        offload_device = (
+            weights_map[name].device
+            if name in weights_map
+            else next(iter(weights_map.values())).device
+        )
+        weights_map[name] = data.to(device=offload_device)
 
-    if isinstance(weights_map, OffloadedWeightsLoader):
+    elif isinstance(weights_map, OffloadedWeightsLoader):
         raise NotImplementedError()
 
     else:
-        raise NotImplementedError()
+        raise NotImplementedError(
+            "Updating offload data not implemented for weights_map of type "
+            f"{type(weights_map)}"
+        )
 
 
 def delete_offload_parameter(module: torch.nn.Module, name: str):
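
A minimal sketch of how the new dict branch gets exercised. It assumes accelerate's `attach_align_device_hook`, which builds a plain-dict weights map when offloading without an explicit map (the case this hunk adds support for); the import path and the `Linear` module are illustrative assumptions, not part of this diff:

import torch
from accelerate.hooks import attach_align_device_hook

from compressed_tensors.utils.offload import update_offload_data  # assumed path

linear = torch.nn.Linear(4, 4)
attach_align_device_hook(linear, execution_device=torch.device("cpu"), offload=True)

# Parameters now live on the meta device; their data sits in the hook's plain
# dict weights_map. update_offload_data writes new data into that map, first
# casting to the parameter's dtype via the new `data = data.to(param.dtype)`.
update_offload_data(linear, "weight", torch.zeros(4, 4, dtype=torch.float64))
assert linear._hf_hook.weights_map["weight"].dtype == torch.float32
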
@@ -216,6 +234,9 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):
         if dataset is not None:
             del dataset[f"{prefix}{name}"]
 
+    elif isinstance(weights_map, dict):
+        del weights_map[name]
+
     elif isinstance(weights_map, OffloadedWeightsLoader):
         raise NotImplementedError()
 
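
The matching dict branch in `delete_offload_parameter` can be sketched the same way, under the same assumptions (setup and import path as in the sketch above):

import torch
from accelerate.hooks import attach_align_device_hook

from compressed_tensors.utils.offload import delete_offload_parameter  # assumed path

linear = torch.nn.Linear(4, 4)
attach_align_device_hook(linear, execution_device=torch.device("cpu"), offload=True)

delete_offload_parameter(linear, "bias")
assert "bias" not in linear._hf_hook.weights_map  # removed from the dict map
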
@@ -225,6 +246,20 @@ def delete_offload_parameter(module: torch.nn.Module, name: str):
         )
 
 
+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def disable_hf_hook(module: torch.nn.Module, recurse: bool = False):
+    offloaded = has_offloaded_params(module)
+    if offloaded:
+        hook = module._hf_hook
+        remove_hook_from_module(module, recurse=recurse)
+
+    yield
+
+    if offloaded:
+        add_hook_to_module(module, hook)
+
+
 """ Upstreamed Functions """
 
 
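
A usage sketch for the new `disable_hf_hook` context manager, assuming an offloaded module as in the earlier sketches: the accelerate hook (and its wrapped forward) is detached for the duration of the block and re-attached on exit, which is useful when the module must temporarily behave like a plain torch module:

import torch
from accelerate.hooks import attach_align_device_hook

from compressed_tensors.utils.offload import disable_hf_hook  # assumed path

linear = torch.nn.Linear(4, 4)
attach_align_device_hook(linear, execution_device=torch.device("cpu"), offload=True)

with disable_hf_hook(linear):
    assert not hasattr(linear, "_hf_hook")  # hook removed, original forward restored
assert hasattr(linear, "_hf_hook")  # hook re-attached on exit
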
@@ -247,3 +282,48 @@ def has_offloaded_params(module: torch.nn.Module) -> bool:
         and isinstance(module._hf_hook, AlignDevicesHook)
         and module._hf_hook.offload
     )
+
+
+# introduced in accelerate v1.1.0
+@check_accelerate(fallback=contextlib.nullcontext())
+@contextlib.contextmanager
+def align_module_device(
+    module: torch.nn.Module, execution_device: Optional[torch.device] = None
+):
+    """
+    Context manager that moves a module's parameters to the specified execution device.
+
+    Args:
+        module (`torch.nn.Module`):
+            Module with parameters to align.
+        execution_device (`torch.device`, *optional*):
+            If provided, overrides the module's execution device within the context.
+            Otherwise, use the hook's execution device, or leave the parameters on
+            their current device if the module is not offloaded.
+    """
+    if has_offloaded_params(module):
+        if execution_device is not None:
+            original_device = module._hf_hook.execution_device
+            module._hf_hook.execution_device = execution_device
+
+        try:
+            module._hf_hook.pre_forward(module)
+            yield
+        finally:
+            module._hf_hook.post_forward(module, None)
+            if execution_device is not None:
+                module._hf_hook.execution_device = original_device
+
+    elif execution_device is not None:
+        devices = {
+            name: param.device for name, param in module.named_parameters(recurse=False)
+        }
+        try:
+            for name in devices:
+                set_module_tensor_to_device(module, name, execution_device)
+            yield
+        finally:
+            for name, device in devices.items():
+                set_module_tensor_to_device(module, name, device)
+
+    else:
+        yield
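
Finally, a hedged sketch of `align_module_device` on an offloaded module, under the same assumed setup: the hook's `pre_forward` onloads the weights for the duration of the block and `post_forward` returns them to the meta device on exit:

import torch
from accelerate.hooks import attach_align_device_hook

from compressed_tensors.utils.offload import align_module_device  # assumed path

linear = torch.nn.Linear(4, 4)
attach_align_device_hook(linear, execution_device=torch.device("cpu"), offload=True)
assert linear.weight.device == torch.device("meta")  # offloaded at rest

with align_module_device(linear):
    assert linear.weight.device == torch.device("cpu")  # onloaded via pre_forward
assert linear.weight.device == torch.device("meta")  # offloaded again via post_forward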