Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make using safetensors files automated. #27571

Merged
merged 16 commits into from
Dec 1, 2023
13 changes: 11 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
prune_layer,
prune_linear_layer,
)
from .safetensors_conversion import auto_conversion
from .utils import (
ADAPTER_SAFE_WEIGHTS_NAME,
ADAPTER_WEIGHTS_NAME,
Expand Down Expand Up @@ -3054,9 +3055,14 @@ def from_pretrained(
if resolved_archive_file is not None:
is_sharded = True
elif use_safetensors:
raise EnvironmentError(
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
resolved_archive_file, revision = auto_conversion(
pretrained_model_name_or_path, filename, **cached_file_kwargs
)
cached_file_kwargs["revision"] = revision
if resolved_archive_file is None:
raise EnvironmentError(
f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} and thus cannot be loaded with `safetensors`. Please make sure that the model has been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
LysandreJik marked this conversation as resolved.
Show resolved Hide resolved
)
LysandreJik marked this conversation as resolved.
Show resolved Hide resolved
else:
# This repo has no safetensors file of any kind, we switch to PyTorch.
filename = _add_variant(WEIGHTS_NAME, variant)
Expand All @@ -3080,6 +3086,9 @@ def from_pretrained(
"proxies": proxies,
"token": token,
}
import ipdb

ipdb.set_trace()
if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named"
Expand Down
99 changes: 99 additions & 0 deletions src/transformers/safetensors_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Optional

from huggingface_hub import Discussion, HfApi

from .utils import cached_file, logging


logger = logging.get_logger(__name__)


def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
LysandreJik marked this conversation as resolved.
Show resolved Hide resolved
try:
main_commit = api.list_repo_commits(model_id)[0].commit_id
discussions = api.get_repo_discussions(repo_id=model_id)
except Exception:
return None
LysandreJik marked this conversation as resolved.
Show resolved Hide resolved
for discussion in discussions:
if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
commits = api.list_repo_commits(model_id, revision=discussion.git_reference)

if main_commit == commits[1].commit_id:
return discussion
return None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can implem a Hub API to replace this whole function IMO cc @SBrandeis

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Courtesy of @SBrandeis (huggingface/huggingface_hub#1845):

for discussion in get_repo_discussions(
    repo_id="openai/whisper-large-v3",
    author="sanchit-gandhi",
    discussion_type="pull_request",
    discussion_status="open",
):
    ...

Copy link
Member

@LysandreJik LysandreJik Nov 24, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Simon and Lucain! Applied in d32a18e



def spawn_conversion(token: str, model_id: str):
print("Sending conversion request")
import asyncio
import json
import uuid

import websockets
LysandreJik marked this conversation as resolved.
Show resolved Hide resolved

async def start(websocket, payload):
_hash = str(uuid.uuid4())
while True:
data = await websocket.recv()
print(f"<{data}")
data = json.loads(data)
if data["msg"] == "send_hash":
data = json.dumps({"fn_index": 0, "session_hash": _hash})
print(f">{data}")
await websocket.send(data)
elif data["msg"] == "send_data":
data = json.dumps({"fn_index": 0, "session_hash": _hash, "data": payload})
print(f">{data}")
await websocket.send(data)
elif data["msg"] == "process_completed":
break

async def main():
print("======================")
uri = "wss://safetensors-convert.hf.space/queue/join"
async with websockets.connect(uri) as websocket:
# inputs and parameters are classic, "id" is a way to track that query
data = [token, model_id]
try:
await start(websocket, data)
except Exception as e:
print(f"Error during space conversion: {e}")

asyncio.run(main())


def get_sha(model_id: str, filename: str, **kwargs):
api = HfApi(token=kwargs.get("token"))
# model_info = api.model_info(model_id)
# refs = api.list_repo_refs(model_id)

# main_refs = [branch.target_commit for branch in refs.branches if branch.ref == "refs/heads/main"]
# main_sha = None
# if main_refs:
# main_sha = main_refs[0]

logger.info("Attempting to create safetensors variant")
pr_title = "Adding `safetensors` variant of this model"
pr = previous_pr(api, model_id, pr_title)
if pr is None:
from multiprocessing import Process

process = Process(target=spawn_conversion, args=(kwargs.get("token"), model_id))
process.start()
process.join()
pr = previous_pr(api, model_id, pr_title)
sha = f"refs/pr/{pr.num}"
else:
logger.info("Safetensors PR exists")
sha = f"refs/pr/{pr.num}"
LysandreJik marked this conversation as resolved.
Show resolved Hide resolved
return sha


def auto_conversion(pretrained_model_name_or_path: str, filename: str, **cached_file_kwargs):
sha = get_sha(pretrained_model_name_or_path, filename, **cached_file_kwargs)
if sha is None:
return None, None
cached_file_kwargs["revision"] = sha
del cached_file_kwargs["_commit_hash"]
resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
LysandreJik marked this conversation as resolved.
Show resolved Hide resolved
return resolved_archive_file, sha
Loading