Skip to content

Commit

Permalink
Add and endpoints (#1181)
Browse files Browse the repository at this point in the history
  • Loading branch information
Wauplin authored Nov 14, 2022
1 parent 8f7d0d0 commit 337351d
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 5 deletions.
4 changes: 4 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,14 @@
"ModelSearchArguments",
"change_discussion_status",
"comment_discussion",
"create_branch",
"create_commit",
"create_discussion",
"create_pull_request",
"create_repo",
"create_tag",
"dataset_info",
"delete_branch",
"delete_file",
"delete_folder",
"delete_repo",
Expand Down Expand Up @@ -329,12 +331,14 @@ def __dir__():
from .hf_api import ModelSearchArguments # noqa: F401
from .hf_api import change_discussion_status # noqa: F401
from .hf_api import comment_discussion # noqa: F401
from .hf_api import create_branch # noqa: F401
from .hf_api import create_commit # noqa: F401
from .hf_api import create_discussion # noqa: F401
from .hf_api import create_pull_request # noqa: F401
from .hf_api import create_repo # noqa: F401
from .hf_api import create_tag # noqa: F401
from .hf_api import dataset_info # noqa: F401
from .hf_api import delete_branch # noqa: F401
from .hf_api import delete_file # noqa: F401
from .hf_api import delete_folder # noqa: F401
from .hf_api import delete_repo # noqa: F401
Expand Down
108 changes: 103 additions & 5 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2425,6 +2425,103 @@ def delete_folder(
parent_commit=parent_commit,
)

@validate_hf_hub_args
def create_branch(
self,
repo_id: str,
*,
branch: str,
token: Optional[str] = None,
repo_type: Optional[str] = None,
) -> None:
"""
Create a new branch from `main` on a repo on the Hub.
Args:
repo_id (`str`):
The repository in which the branch will be created.
Example: `"user/my-cool-model"`.
branch (`str`):
The name of the branch to create.
token (`str`, *optional*):
Authentication token. Will default to the stored token.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if creating a branch on a dataset or
space, `None` or `"model"` if tagging a model. Default is `None`.
Raises:
[`~utils.RepositoryNotFoundError`]:
If repository is not found (error 404): wrong repo_id/repo_type, private
but not authenticated or repo does not exist.
[`~utils.BadRequestError`]:
If invalid reference for a branch. Ex: `refs/pr/5` or 'refs/foo/bar'.
[`~utils.HfHubHTTPError`]:
If the branch already exists on the repo (error 409).
"""
if repo_type is None:
repo_type = REPO_TYPE_MODEL
branch = quote(branch, safe="")

# Prepare request
branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}"
headers = self._build_hf_headers(token=token, is_write_action=True)

# Create branch
response = requests.post(url=branch_url, headers=headers)
hf_raise_for_status(response)

@validate_hf_hub_args
def delete_branch(
self,
repo_id: str,
*,
branch: str,
token: Optional[str] = None,
repo_type: Optional[str] = None,
) -> None:
"""
Delete a branch from a repo on the Hub.
Args:
repo_id (`str`):
The repository in which a branch will be deleted.
Example: `"user/my-cool-model"`.
branch (`str`):
The name of the branch to delete.
token (`str`, *optional*):
Authentication token. Will default to the stored token.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if creating a branch on a dataset or
space, `None` or `"model"` if tagging a model. Default is `None`.
Raises:
[`~utils.RepositoryNotFoundError`]:
If repository is not found (error 404): wrong repo_id/repo_type, private
but not authenticated or repo does not exist.
[`~utils.HfHubHTTPError`]:
If trying to delete a protected branch. Ex: `main` cannot be deleted.
[`~utils.HfHubHTTPError`]:
If trying to delete a branch that does not exist.
"""
if repo_type is None:
repo_type = REPO_TYPE_MODEL
branch = quote(branch, safe="")

# Prepare request
branch_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/branch/{branch}"
headers = self._build_hf_headers(token=token, is_write_action=True)

# Delete branch
response = requests.delete(url=branch_url, headers=headers)
hf_raise_for_status(response)

@validate_hf_hub_args
def create_tag(
self,
Expand Down Expand Up @@ -2501,7 +2598,7 @@ def delete_tag(
Args:
repo_id (`str`):
The repository in which a commit will be deleted.
The repository in which a tag will be deleted.
Example: `"user/my-cool-model"`.
tag (`str`):
Expand All @@ -2511,20 +2608,19 @@ def delete_tag(
Authentication token. Will default to the stored token.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if tagging a dataset or
space, `None` or `"model"` if tagging a model. Default is
`None`.
Set to `"dataset"` or `"space"` if tagging a dataset or space, `None` or
`"model"` if tagging a model. Default is `None`.
Raises:
[`~utils.RepositoryNotFoundError`]:
If repository is not found (error 404): wrong repo_id/repo_type, private
but not authenticated or repo does not exist.
[`~utils.RevisionNotFoundError`]:
If tag is not found.
"""
if repo_type is None:
repo_type = REPO_TYPE_MODEL
tag = quote(tag, safe="")

# Prepare request
tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{tag}"
Expand Down Expand Up @@ -3443,6 +3539,8 @@ def _warn_if_truncated(
upload_folder = api.upload_folder
delete_file = api.delete_file
delete_folder = api.delete_folder
create_branch = api.create_branch
delete_branch = api.delete_branch
create_tag = api.create_tag
delete_tag = api.delete_tag
get_full_repo_name = api.get_full_repo_name
Expand Down
70 changes: 70 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
repo_type_and_id_from_hf_id,
)
from huggingface_hub.utils import (
BadRequestError,
EntryNotFoundError,
HfFolder,
HfHubHTTPError,
Expand Down Expand Up @@ -1329,6 +1330,75 @@ def test_delete_tag_with_branch_name(self) -> None:
self._api.delete_tag(self._repo_id, tag="main")


class HfApiBranchEndpointTest(HfApiCommonTestWithLogin):
_user = USER
_repo_id: str

@retry_endpoint
@use_tmp_repo()
def test_create_and_delete_branch(self) -> None:
"""Test `create_branch` from main branch."""
self._api.create_branch(self._repo_id, branch="cool-branch")

# Check `cool-branch` branch exists
self._api.model_info(self._repo_id, revision="cool-branch")

# Delete it
self._api.delete_branch(self._repo_id, branch="cool-branch")

# Check doesn't exist anymore
with self.assertRaises(RevisionNotFoundError):
self._api.model_info(self._repo_id, revision="cool-branch")

@retry_endpoint
@use_tmp_repo()
def test_create_branch_existing_branch_fails(self) -> None:
"""Test `create_branch` on existing branch."""
self._api.create_branch(self._repo_id, branch="cool-branch")

with self.assertRaisesRegex(HfHubHTTPError, "Reference already exists"):
self._api.create_branch(self._repo_id, branch="cool-branch")

with self.assertRaisesRegex(HfHubHTTPError, "Reference already exists"):
self._api.create_branch(self._repo_id, branch="main")

@retry_endpoint
@use_tmp_repo()
def test_create_branch_existing_tag_does_not_fail(self) -> None:
"""Test `create_branch` on existing tag."""
self._api.create_tag(self._repo_id, tag="tag")
self._api.create_branch(self._repo_id, branch="tag")

@retry_endpoint
@use_tmp_repo()
def test_create_branch_forbidden_ref_branch_fails(self) -> None:
"""Test `create_branch` on forbidden ref branch."""
with self.assertRaisesRegex(BadRequestError, "Invalid reference for a branch"):
self._api.create_branch(self._repo_id, branch="refs/pr/5")

with self.assertRaisesRegex(BadRequestError, "Invalid reference for a branch"):
self._api.create_branch(self._repo_id, branch="refs/something/random")

@retry_endpoint
@use_tmp_repo()
def test_delete_branch_on_protected_branch_fails(self) -> None:
"""Test `delete_branch` fails on protected branch."""
with self.assertRaisesRegex(HfHubHTTPError, "Cannot delete refs/heads/main"):
self._api.delete_branch(self._repo_id, branch="main")

@retry_endpoint
@use_tmp_repo()
def test_delete_branch_on_missing_branch_fails(self) -> None:
"""Test `delete_branch` fails on missing branch."""
with self.assertRaisesRegex(HfHubHTTPError, "Reference does not exist"):
self._api.delete_branch(self._repo_id, branch="cool-branch")

# Using a tag instead of branch -> fails
self._api.create_tag(self._repo_id, tag="cool-tag")
with self.assertRaisesRegex(HfHubHTTPError, "Reference does not exist"):
self._api.delete_branch(self._repo_id, branch="cool-tag")


class HfApiPublicStagingTest(unittest.TestCase):
def setUp(self) -> None:
self._api = HfApi()
Expand Down

0 comments on commit 337351d

Please sign in to comment.