From cfe58dfbb79de1eca110f199ee3117f8d8258391 Mon Sep 17 00:00:00 2001 From: Zachary Mueller Date: Mon, 24 Jan 2022 12:03:23 -0600 Subject: [PATCH] Improvements non git mixin (#618) Add in commit_message + tests Co-authored-by: Lysandre --- src/huggingface_hub/hf_api.py | 12 +++++++-- src/huggingface_hub/hub_mixin.py | 12 ++++++--- tests/test_hubmixin.py | 42 ++++++++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 5 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 404bee95cb..591eb5b864 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -1213,6 +1213,7 @@ def upload_file( path_or_fileobj: Union[str, bytes, IO], path_in_repo: str, repo_id: str, + commit_message: Optional[str] = None, token: Optional[str] = None, repo_type: Optional[str] = None, revision: Optional[str] = None, @@ -1232,7 +1233,10 @@ def upload_file( repo_id (``str``): The repository to which the file will be uploaded, for example: :obj:`"username/custom_transformers"` - token (``str``): + commit_message (``str``, Optional): + A commit message to be logged in the revision history when pushing this file. Defaults to "Upload" or "Update" + + token (``str``, Optional): Authentication token, obtained with :function:`HfApi.login` method. Will default to the stored token. repo_type (``str``, Optional): @@ -1332,7 +1336,11 @@ def upload_file( path = f"{self.endpoint}/api/{repo_id}/upload/{revision}/{path_in_repo}" - headers = {"authorization": f"Bearer {token}"} if token is not None else None + headers = {} + if token is not None: + headers["authorization"] = f"Bearer {token}" + if commit_message is not None: + headers["Commit-Summary"] = commit_message if isinstance(path_or_fileobj, str): with open(path_or_fileobj, "rb") as bytestream: diff --git a/src/huggingface_hub/hub_mixin.py b/src/huggingface_hub/hub_mixin.py index d670fb313a..0ca8d087c1 100644 --- a/src/huggingface_hub/hub_mixin.py +++ b/src/huggingface_hub/hub_mixin.py @@ -58,7 +58,7 @@ def save_pretrained( json.dump(config, f) # saving model weights/files - files = self._save_pretrained(save_directory, **kwargs) + files = self._save_pretrained(save_directory) if push_to_hub: return self.push_to_hub(save_directory, **kwargs) @@ -259,8 +259,8 @@ def push_to_hub( if repo_url is None: repo_name = Path(repo_path_or_name).name repo_url = HfApi(endpoint=api_endpoint).create_repo( + token, repo_name, - token=token, organization=organization, private=private, repo_type=None, @@ -285,7 +285,13 @@ def push_to_hub( for file in files: common_prefix = os.path.commonprefix([saved_path, file]) relative_path = os.path.relpath(file, common_prefix) - api.upload_file(token, file, path_in_repo=relative_path, repo_id=name) + api.upload_file( + file, + path_in_repo=relative_path, + token=token, + repo_id=name, + commit_message=commit_message, + ) class PyTorchModelHubMixin(ModelHubMixin): diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index e9f673d572..f08d7c301b 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -3,6 +3,7 @@ import time import unittest import uuid +from io import BytesIO from huggingface_hub import HfApi from huggingface_hub.file_download import is_torch_available @@ -131,3 +132,44 @@ def test_push_to_hub(self): self.assertEqual(model_info.modelId, f"{USER}/{REPO_NAME}") self._api.delete_repo(name=f"{REPO_NAME}", token=self._token) + + def test_push_to_hub_with_other_files(self): + REPO_A = repo_name("with_files") + REPO_B = repo_name("without_files") + self._api.create_repo(token=self._token, name=REPO_A) + self._api.create_repo(token=self._token, name=REPO_B) + for i in range(5): + self._api.upload_file( + # Each are .5mb in size + path_or_fileobj=BytesIO(os.urandom(500000)), + path_in_repo=f"temp/new_file_{i}.bytes", + repo_id=f"{USER}/{REPO_A}", + token=self._token, + ) + model = DummyModel() + start_time = time.time() + model.push_to_hub( + repo_path_or_name=f"{WORKING_REPO_SUBDIR}/{REPO_A}", + api_endpoint=ENDPOINT_STAGING, + use_auth_token=self._token, + git_user="ci", + git_email="ci@dummy.com", + config={"num": 7, "act": "gelu_fast"}, + ) + REPO_A_TIME = start_time - time.time() + + start_time = time.time() + model.push_to_hub( + repo_path_or_name=f"{USER}/{REPO_B}", + api_endpoint=ENDPOINT_STAGING, + use_auth_token=self._token, + git_user="ci", + git_email="ci@dummy.com", + config={"num": 7, "act": "gelu_fast"}, + ) + REPO_B_TIME = start_time - time.time() + # Less than half a second from each other + self.assertLess(REPO_A_TIME - REPO_B_TIME, 0.5) + + self._api.delete_repo(name=REPO_A, token=self._token) + self._api.delete_repo(name=REPO_B, token=self._token)