Skip to content

Commit

Permalink
Fix disk offload for full safetensors checkpoints (huggingface#20497)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgugger authored and Magnus Pierrau committed Dec 15, 2022
1 parent 64fc338 commit 6aa8f88
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,9 @@ def _load_state_dict_into_meta_model(
# in int/uint/bool and not cast them.
if dtype is not None and torch.is_floating_point(param):
param = param.to(dtype)
# For compatibility with PyTorch which loads float16/bfloat16 weights in fp32
if is_safetensors and dtype is None and torch.is_floating_point(param):
param = param.to(torch.float32)

if device_map is None:
param_device = "cpu"
Expand Down Expand Up @@ -2452,6 +2455,7 @@ def _load_pretrained_model(
if offload_state_dict is None:
offload_state_dict = True

is_sharded_safetensors = is_safetensors and sharded_metadata is not None
# Retrieve missing & unexpected_keys
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
Expand Down Expand Up @@ -2567,12 +2571,21 @@ def _find_mismatched_keys(

folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
if device_map is not None and is_safetensors:
param_device_map = expand_device_map(device_map, sharded_metadata["all_checkpoint_keys"])

str_dtype = str(dtype).replace("torch.", "")
param_device_map = expand_device_map(device_map, original_loaded_keys)

str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
if sharded_metadata is None:
archive_file = (
resolved_archive_file[0]
if isinstance(resolved_archive_file, (list, tuple))
else resolved_archive_file
)
weight_map = {p: archive_file for p in original_loaded_keys}
else:
weight_map = {p: os.path.join(folder, f) for p, f in sharded_metadata["weight_map"].items()}
offload_index = {
p: {"safetensors_file": os.path.join(folder, f), "weight_name": p, "dtype": str_dtype}
for p, f in sharded_metadata["weight_map"].items()
p: {"safetensors_file": f, "weight_name": p, "dtype": str_dtype}
for p, f in weight_map.items()
if param_device_map[p] == "disk"
}

Expand Down Expand Up @@ -2606,7 +2619,7 @@ def _find_mismatched_keys(
state_dict_folder = None
state_dict_index = None

if is_safetensors:
if is_sharded_safetensors:
disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
else:
Expand Down

0 comments on commit 6aa8f88

Please sign in to comment.