Skip to content

Commit

Permalink
Remove context manager when loading shards and handle mlx weights (#2709)
Browse files: browse the repository at this point in the history
  • Loading branch information
hanouticelina authored Dec 13, 2024
1 parent 4b0b179 commit ca3f674
Showing 1 changed file with 11 additions and 33 deletions.
44 changes: 11 additions & 33 deletions src/huggingface_hub/serialization/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@
import os
import re
from collections import defaultdict, namedtuple
from contextlib import contextmanager
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union

from packaging import version

Expand Down Expand Up @@ -538,13 +537,15 @@ def _load_sharded_checkpoint(
for shard_file in shard_files:
# Load shard into memory
shard_path = os.path.join(save_directory, shard_file)
with _load_shard_into_memory(
state_dict = load_state_dict_from_file(
shard_path,
load_fn=load_state_dict_from_file,
kwargs={"weights_only": weights_only},
) as state_dict:
# Update model with parameters from this shard
model.load_state_dict(state_dict, strict=strict)
map_location="cpu",
weights_only=weights_only,
)
# Update model with parameters from this shard
model.load_state_dict(state_dict, strict=strict)
# Explicitly remove the state dict from memory
del state_dict

# 4. Return compatibility info
loaded_keys = set(index["weight_map"].keys())
Expand Down Expand Up @@ -630,7 +631,8 @@ def load_state_dict_from_file(
# Check format of the archive
with safe_open(checkpoint_file, framework="pt") as f: # type: ignore[attr-defined]
metadata = f.metadata()
if metadata.get("format") != "pt":
# see comment: https://github.com/huggingface/transformers/blob/3d213b57fe74302e5902d68ed9478c3ad1aaa713/src/transformers/modeling_utils.py#L3966
if metadata is not None and metadata.get("format") not in ["pt", "mlx"]:
raise OSError(
f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
"you save your model with the `save_torch_model` method."
Expand Down Expand Up @@ -668,30 +670,6 @@ def load_state_dict_from_file(
# HELPERS


@contextmanager
def _load_shard_into_memory(
shard_path: str,
load_fn: Callable,
kwargs: Optional[Dict[str, Any]] = None,
):
"""
Context manager to handle loading and cleanup of model shards.
Args:
shard_path: Path to the shard file
load_fn: Function to load the shard (either torch.load or safetensors.load)
Yields:
The loaded state dict for this shard
"""
try:
state_dict = load_fn(shard_path, **kwargs) # type: ignore[arg-type]
yield state_dict
finally:
# Explicitly remove the state dict from memory
del state_dict


def _validate_keys_for_strict_loading(
model: "torch.nn.Module",
loaded_keys: Iterable[str],
Expand Down

0 comments on commit ca3f674

Please sign in to comment.