Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safetensors offload #20321

Merged
merged 6 commits into from
Nov 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 82 additions & 17 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import gc
import json
import os
Expand All @@ -28,7 +28,7 @@

import torch
from packaging import version
from torch import Tensor, device, nn
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This import was dangerous with so many local variables named device. It was only used in a type annotation.

from torch import Tensor, nn
from torch.nn import CrossEntropyLoss

from transformers.utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
Expand Down Expand Up @@ -545,6 +545,7 @@ def _load_state_dict_into_meta_model(
state_dict_index=None,
dtype=None,
load_in_8bit=False,
is_safetensors=False,
):
"""
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
Expand Down Expand Up @@ -609,7 +610,8 @@ def _load_state_dict_into_meta_model(
raise ValueError(f"{param_name} doesn't have any device set.")
param_device = device_map[module_name]
if param_device == "disk":
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
if not is_safetensors:
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
elif param_device == "cpu" and state_dict_index is not None:
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
elif not load_in_8bit:
Expand Down Expand Up @@ -673,7 +675,7 @@ def reset_memory_hooks_state(self):
module.mem_rss_pre_forward = 0

@property
def device(self) -> device:
def device(self) -> torch.device:
"""
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
Expand Down Expand Up @@ -2331,7 +2333,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
if dtype_orig is not None:
torch.set_default_dtype(dtype_orig)

model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
(
model,
missing_keys,
unexpected_keys,
mismatched_keys,
offload_index,
error_msgs,
) = cls._load_pretrained_model(
model,
state_dict,
loaded_state_dict_keys, # XXX: rename?
Expand All @@ -2358,7 +2367,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# Dispatch model with hooks on all devices if necessary
if device_map is not None:
dispatch_model(model, device_map=device_map, offload_dir=offload_folder)
dispatch_model(model, device_map=device_map, offload_dir=offload_folder, offload_index=offload_index)

if output_loading_info:
if loading_info is None:
Expand Down Expand Up @@ -2390,16 +2399,23 @@ def _load_pretrained_model(
dtype=None,
load_in_8bit=False,
):
is_safetensors = False
if load_in_8bit:
from .utils.bitsandbytes import set_module_8bit_tensor_to_device

if device_map is not None and "disk" in device_map.values():
if offload_folder is None:
archive_file = (
resolved_archive_file[0] if isinstance(resolved_archive_file, (list, tuple)) else resolved_archive_file
)
is_safetensors = archive_file.endswith(".safetensors")
if offload_folder is None and not is_safetensors:
raise ValueError(
"The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
" for them."
" for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
" offers the weights in this format."
)
os.makedirs(offload_folder, exist_ok=True)
if offload_folder is not None:
os.makedirs(offload_folder, exist_ok=True)
if offload_state_dict is None:
offload_state_dict = True

Expand Down Expand Up @@ -2516,6 +2532,17 @@ def _find_mismatched_keys(
del state_dict[checkpoint_key]
return mismatched_keys

folder = os.path.sep.join(resolved_archive_file[0].split(os.path.sep)[:-1])
if device_map is not None and is_safetensors:
param_device_map = expand_device_map(device_map, sharded_metadata["all_checkpoint_keys"])

str_dtype = str(dtype).replace("torch.", "")
offload_index = {
p: {"safetensors_file": os.path.join(folder, f), "weight_name": p, "dtype": str_dtype}
for p, f in sharded_metadata["weight_map"].items()
if param_device_map[p] == "disk"
}
Comment on lines +2540 to +2544
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We build the offload_index here when the checkpoin is safetensors, this way we can skip loading all the shards and gain time.


if state_dict is not None:
# Whole checkpoint
mismatched_keys = _find_mismatched_keys(
Expand All @@ -2527,6 +2554,7 @@ def _find_mismatched_keys(
ignore_mismatched_sizes,
)
error_msgs = _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
offload_index = None
else:
# Sharded checkpoint or whole but low_cpu_mem_usage==True

Expand All @@ -2536,15 +2564,25 @@ def _find_mismatched_keys(

error_msgs = []
mismatched_keys = []
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
if not is_safetensors:
offload_index = {} if device_map is not None and "disk" in device_map.values() else None
if offload_state_dict:
state_dict_folder = tempfile.mkdtemp()
state_dict_index = {}
else:
state_dict_folder = None
state_dict_index = None

if is_safetensors:
disk_only_shard_files = get_disk_only_shard_files(device_map, sharded_metadata=sharded_metadata)
disk_only_shard_files = [os.path.join(folder, f) for f in disk_only_shard_files]
else:
disk_only_shard_files = []

for shard_file in resolved_archive_file:
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
if shard_file in disk_only_shard_files:
continue
state_dict = load_state_dict(shard_file)

# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
Expand Down Expand Up @@ -2572,6 +2610,7 @@ def _find_mismatched_keys(
state_dict_index=state_dict_index,
dtype=dtype,
load_in_8bit=load_in_8bit,
is_safetensors=is_safetensors,
)
error_msgs += new_error_msgs
else:
Expand All @@ -2585,13 +2624,16 @@ def _find_mismatched_keys(
if model != model_to_load:
# We need to add the prefix of the base model
prefix = cls.base_model_prefix
for weight_name in offload_index:
shutil.move(
os.path.join(offload_folder, f"{weight_name}.dat"),
os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
)
if not is_safetensors:
for weight_name in offload_index:
shutil.move(
os.path.join(offload_folder, f"{weight_name}.dat"),
os.path.join(offload_folder, f"{prefix}.{weight_name}.dat"),
)
offload_index = {f"{prefix}.{key}": value for key, value in offload_index.items()}
save_offload_index(offload_index, offload_folder)
if not is_safetensors:
save_offload_index(offload_index, offload_folder)
offload_index = None

if offload_state_dict:
# Load back temporarily offloaded state dict
Expand Down Expand Up @@ -2645,7 +2687,7 @@ def _find_mismatched_keys(
" to use it for predictions and inference."
)

return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
return model, missing_keys, unexpected_keys, mismatched_keys, offload_index, error_msgs

def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
module_keys = set([".".join(key.split(".")[:-1]) for key in names])
Expand Down Expand Up @@ -3158,3 +3200,26 @@ def unwrap_model(model: nn.Module) -> nn.Module:
return unwrap_model(model.module)
else:
return model


def expand_device_map(device_map, param_names):
"""
Expand a device map to return the correspondance parameter name to device.
"""
new_device_map = {}
for module, device in device_map.items():
new_device_map.update({p: device for p in param_names if p == module or p.startswith(f"{module}.")})
return new_device_map


def get_disk_only_shard_files(device_map, sharded_metadata):
"""
Returns the list of shard files containing only weights offloaded to disk.
"""
files_content = collections.defaultdict(list)
for weight_name, filename in sharded_metadata["weight_map"].items():
while len(weight_name) > 0 and weight_name not in device_map:
weight_name = ".".join(weight_name.split(".")[:-1])
files_content[filename].append(device_map[weight_name])

return [fname for fname, devices in files_content.items() if set(devices) == {"disk"}]
1 change: 1 addition & 0 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,7 @@ def get_checkpoint_shard_files(
shard_filenames = sorted(list(set(index["weight_map"].values())))
sharded_metadata = index["metadata"]
sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys())
sharded_metadata["weight_map"] = index["weight_map"].copy()

# First, let's deal with local folder.
if os.path.isdir(pretrained_model_name_or_path):
Expand Down