Skip to content

Commit

Permalink
Add Support for Decompressing Models from HF Hub (#2212) commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dbogunowicz committed Apr 8, 2024
1 parent 5ef7fbf commit 0f89712
Showing 1 changed file with 47 additions and 43 deletions.
90 changes: 47 additions & 43 deletions src/sparsetensors/utils/safetensors_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,57 @@


# Public API of this module.
# NOTE: the original list contained "get_safetensors_header" twice; the
# duplicate entry has been removed (each name should appear exactly once).
__all__ = [
    "get_safetensors_folder",
    "get_safetensors_header",
    "match_param_name",
    "merge_names",
    "get_weight_mappings",
    "get_nested_weight_mappings",
]


def get_safetensors_folder(
    pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
) -> str:
    """
    Resolve a local path or Hugging Face stub to the directory holding the
    model's safetensors weight files.

    :param pretrained_model_name_or_path: local path to model or HF stub
    :param cache_dir: optional cache dir to search through; if none is
        specified the model will be searched for in the default
        TRANSFORMERS_CACHE
    :return: local folder containing model data
    :raises ValueError: if no safetensors weight or index file can be located
    """
    if os.path.exists(pretrained_model_name_or_path):
        # an existing local path is already the model folder
        return pretrained_model_name_or_path

    # shared keyword arguments for both cache lookups; the missing-entry flag
    # makes cached_file return None instead of raising when a file is absent
    lookup_kwargs = dict(
        cache_dir=cache_dir,
        _raise_exceptions_for_missing_entries=False,
    )
    single_file = cached_file(
        pretrained_model_name_or_path,
        SAFE_WEIGHTS_NAME,
        **lookup_kwargs,
    )
    sharded_index = cached_file(
        pretrained_model_name_or_path,
        SAFE_WEIGHTS_INDEX_NAME,
        **lookup_kwargs,
    )

    # prefer the single consolidated weights file, then the sharded index
    for candidate in (single_file, sharded_index):
        if candidate is not None:
            return os.path.dirname(candidate)

    # model weights could not be found locally or cached from HF Hub
    raise ValueError(
        "Could not locate safetensors weight or index file from "
        f"{pretrained_model_name_or_path}."
    )


def get_safetensors_header(safetensors_path: str) -> Dict[str, str]:
"""
Extracts the metadata from a safetensors file as JSON
Expand Down Expand Up @@ -106,6 +148,10 @@ def get_weight_mappings(model_path: str) -> Dict[str, str]:
with open(index_path, "r", encoding="utf-8") as f:
index = json.load(f)
header = index["weight_map"]
else:
raise ValueError(
f"Could not find a safetensors weight or index file at {model_path}"
)

# convert weight locations to full paths
for key, value in header.items():
Expand Down Expand Up @@ -148,45 +194,3 @@ def get_nested_weight_mappings(
nested_weight_mappings[dense_param][param_name] = weight_mappings[key]

return nested_weight_mappings


def get_safetensors_folder(
    pretrained_model_name_or_path: str, cache_dir: Optional[str] = None
) -> str:
    """
    Given a Hugging Face stub or a local path, return the folder containing the
    safetensors weight files

    :param pretrained_model_name_or_path: local path to model or HF stub
    :param cache_dir: optional cache dir to search through, if none is specified the
        model will be searched for in the default TRANSFORMERS_CACHE
    :return: local folder containing model data
    :raises ValueError: if neither a safetensors weight file nor a weight
        index file can be located
    """
    if os.path.exists(pretrained_model_name_or_path):
        # argument is a path to a local folder
        return pretrained_model_name_or_path

    # look up a single consolidated safetensors file; the missing-entries flag
    # makes cached_file return None rather than raise when the file is absent
    # (presumably this also consults/populates the HF Hub cache — confirm)
    safetensors_path = cached_file(
        pretrained_model_name_or_path,
        SAFE_WEIGHTS_NAME,
        cache_dir=cache_dir,
        _raise_exceptions_for_missing_entries=False,
    )
    # likewise for the index file used by sharded safetensors checkpoints
    index_path = cached_file(
        pretrained_model_name_or_path,
        SAFE_WEIGHTS_INDEX_NAME,
        cache_dir=cache_dir,
        _raise_exceptions_for_missing_entries=False,
    )
    if safetensors_path is not None:
        # found a single cached safetensors file; return its directory
        return os.path.split(safetensors_path)[0]
    if index_path is not None:
        # found a cached safetensors weight index file; return its directory
        return os.path.split(index_path)[0]

    # model weights could not be found locally or cached from HF Hub
    raise ValueError(
        "Could not locate safetensors weight or index file from "
        f"{pretrained_model_name_or_path}."
    )

0 comments on commit 0f89712

Please sign in to comment.