diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b943b5e7989f03..eccc8537bdb865 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -958,6 +958,9 @@ def _load_state_dict_into_meta_model( ) ) ): + if is_fsdp_enabled(): + param_device = "cpu" if is_local_dist_rank_0() else "meta" + # For backward compatibility with older versions of `accelerate` and for non-quantized params set_module_tensor_to_device(model, param_name, param_device, **set_module_kwargs) else: @@ -968,7 +971,10 @@ def _load_state_dict_into_meta_model( if is_fsdp_enabled() or is_deepspeed_zero3_enabled(): module, tensor_name = get_module_from_name(model, param_name) value = getattr(module, tensor_name) - value = type(value)(value.data.to("cpu"), **value.__dict__) + param_to = "cpu" + if is_fsdp_enabled() and not is_local_dist_rank_0(): + param_to = "meta" + value = type(value)(value.data.to(param_to), **value.__dict__) setattr(module, tensor_name, value) # TODO: consider removing used param_parts from state_dict before return