Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

List branches and tags from a repo #1276

Merged
merged 6 commits into from
Dec 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/package_reference/hf_api.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ hf_api = HfApi(

[[autodoc]] huggingface_hub.hf_api.RepoFile

### GitRefs

[[autodoc]] huggingface_hub.hf_api.GitRefs

### GitRefInfo

[[autodoc]] huggingface_hub.hf_api.GitRefInfo

### CommitInfo

[[autodoc]] huggingface_hub.hf_api.CommitInfo
Expand Down
6 changes: 6 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@
"CommitOperationAdd",
"CommitOperationDelete",
"DatasetSearchArguments",
"GitRefInfo",
"GitRefs",
"HfApi",
"ModelSearchArguments",
"RepoUrl",
Expand Down Expand Up @@ -139,6 +141,7 @@
"list_metrics",
"list_models",
"list_repo_files",
"list_repo_refs",
"list_spaces",
"merge_pull_request",
"model_info",
Expand Down Expand Up @@ -344,6 +347,8 @@ def __dir__():
from .hf_api import CommitOperationAdd # noqa: F401
from .hf_api import CommitOperationDelete # noqa: F401
from .hf_api import DatasetSearchArguments # noqa: F401
from .hf_api import GitRefInfo # noqa: F401
from .hf_api import GitRefs # noqa: F401
from .hf_api import HfApi # noqa: F401
from .hf_api import ModelSearchArguments # noqa: F401
from .hf_api import RepoUrl # noqa: F401
Expand Down Expand Up @@ -377,6 +382,7 @@ def __dir__():
from .hf_api import list_metrics # noqa: F401
from .hf_api import list_models # noqa: F401
from .hf_api import list_repo_files # noqa: F401
from .hf_api import list_repo_refs # noqa: F401
from .hf_api import list_spaces # noqa: F401
from .hf_api import merge_pull_request # noqa: F401
from .hf_api import model_info # noqa: F401
Expand Down
108 changes: 108 additions & 0 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,52 @@ def clean(s: str):
self["author"] = author_dict


@dataclass
class GitRefInfo:
    """
    Contains information about a git reference for a repo on the Hub.

    Instances are built from a single raw mapping (see `__init__`), not from the
    individual fields listed below.

    Args:
        name (`str`):
            Name of the reference (e.g. tag name or branch name).
        ref (`str`):
            Full git ref on the Hub (e.g. `"refs/heads/main"` or `"refs/tags/v1.0"`).
        target_commit (`str`):
            OID of the target commit for the ref (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`).
    """

    name: str
    ref: str
    target_commit: str

    def __init__(self, data: Dict) -> None:
        """
        Build the ref info from a raw mapping.

        Args:
            data (`Dict`):
                Mapping that must contain the keys `"name"`, `"ref"` and
                `"targetCommit"` (camelCase); the latter is stored as `target_commit`.

        Raises:
            `KeyError`: if one of the three keys is missing from `data`.
        """
        # `@dataclass` does not replace an `__init__` defined in the class body, so
        # this constructor is the one used; `__repr__`/`__eq__` are still generated.
        self.name = data["name"]
        self.ref = data["ref"]
        self.target_commit = data["targetCommit"]


@dataclass
class GitRefs:
    """
    Groups every git reference of a Hub repo by kind.

    This is the object returned by [`list_repo_refs`].

    Args:
        branches (`List[GitRefInfo]`):
            One [`GitRefInfo`] per branch of the repo.
        converts (`List[GitRefInfo]`):
            One [`GitRefInfo`] per "convert" ref of the repo. Converts are refs
            used (internally) to push preprocessed data in Dataset repos.
        tags (`List[GitRefInfo]`):
            One [`GitRefInfo`] per tag of the repo.
    """

    # Field order is part of the interface (positional construction and repr).
    branches: List[GitRefInfo]
    converts: List[GitRefInfo]
    tags: List[GitRefInfo]


@dataclass
class UserLikes:
"""
Expand Down Expand Up @@ -1772,6 +1818,67 @@ def list_repo_files(
)
return [f.rfilename for f in repo_info.siblings]

@validate_hf_hub_args
def list_repo_refs(
    self,
    repo_id: str,
    *,
    repo_type: Optional[str] = None,
    token: Optional[Union[bool, str]] = None,
) -> GitRefs:
    """
    Fetch every git ref (branches, converts and tags) of a repo on the Hub.

    Args:
        repo_id (`str`):
            A namespace (user or an organization) and a repo name separated
            by a `/`.
        repo_type (`str`, *optional*):
            Set to `"dataset"` or `"space"` if listing refs from a dataset or a Space,
            `None` or `"model"` if listing from a model. Default is `None`.
        token (`bool` or `str`, *optional*):
            A valid authentication token (see https://huggingface.co/settings/token).
            If `None` or `True` and machine is logged in (through `huggingface-cli login`
            or [`~huggingface_hub.login`]), token will be retrieved from the cache.
            If `False`, token is not sent in the request header.

    Example:
    ```py
    >>> from huggingface_hub import HfApi
    >>> api = HfApi()
    >>> api.list_repo_refs("gpt2")
    GitRefs(branches=[GitRefInfo(name='main', ref='refs/heads/main', target_commit='e7da7f221d5bf496a48136c0cd264e630fe9fcc8')], converts=[], tags=[])

    >>> api.list_repo_refs("bigcode/the-stack", repo_type='dataset')
    GitRefs(
        branches=[
            GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'),
            GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714')
        ],
        converts=[],
        tags=[
            GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da')
        ]
    )
    ```

    Returns:
        [`GitRefs`]: object containing all information about branches and tags for a
        repo on the Hub.
    """
    # A falsy `repo_type` (e.g. `None`) falls back to the model repo type.
    if not repo_type:
        repo_type = REPO_TYPE_MODEL

    refs_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/refs"
    request_headers = self._build_hf_headers(token=token)
    response = requests.get(refs_url, headers=request_headers)
    hf_raise_for_status(response)
    payload = response.json()

    def _parse_refs(kind: str) -> List[GitRefInfo]:
        # Wrap each raw item listed under `kind` into a `GitRefInfo`.
        return [GitRefInfo(raw_ref) for raw_ref in payload[kind]]

    return GitRefs(
        branches=_parse_refs("branches"),
        converts=_parse_refs("converts"),
        tags=_parse_refs("tags"),
    )

@validate_hf_hub_args
def create_repo(
self,
Expand Down Expand Up @@ -3915,6 +4022,7 @@ def _parse_revision_from_pr_url(pr_url: str) -> str:

# Bind attributes of `api` (defined outside this excerpt) to module-level names
# so they can be referenced without going through `api`.
repo_info = api.repo_info
list_repo_files = api.list_repo_files
list_repo_refs = api.list_repo_refs

list_metrics = api.list_metrics

Expand Down
35 changes: 35 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2449,6 +2449,41 @@ def test_request_space_hardware(self, post_mock: Mock) -> None:
)


class ListGitRefsTest(unittest.TestCase):
@classmethod
@with_production_testing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Smart to only define it on the setUpClass for it to act on all subsequent methods. I wonder if it wouldn't be cleaner to put it on the class itself so that it's very clear that it should affect the entire class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I did that as a workaround because

@with_production_testing
class ListRepoRefsTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):

was not working for me 😢
But I have to admit my debug time on that was approximately 30s, just the time to try some possibilities and take the first one that worked 🙄

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Haha in that case we can also keep it there, it was just a nit, to be fair.

# Class-level fixture: create one `HfApi` client and store it on the class so
# the test methods can reach it through `self.api`.
def setUpClass(cls) -> None:
cls.api = HfApi()
# Delegate to the parent implementation and pass its return value through.
return super().setUpClass()

def test_list_refs_gpt2(self) -> None:
    """List the refs of the `gpt2` model repo and check its `main` branch."""
    repo_refs = self.api.list_repo_refs("gpt2")
    self.assertGreater(len(repo_refs.branches), 0)

    main_candidates = [ref for ref in repo_refs.branches if ref.name == "main"]
    main_ref = main_candidates[0]
    self.assertEqual(main_ref.ref, "refs/heads/main")

    # Can get info by revision
    self.api.repo_info("gpt2", revision=main_ref.target_commit)

def test_list_refs_bigcode(self) -> None:
    """List the refs of a dataset repo and check its `main` and `parquet` refs."""
    repo_refs = self.api.list_repo_refs("bigcode/evaluation", repo_type="dataset")
    self.assertGreater(len(repo_refs.branches), 0)
    self.assertGreater(len(repo_refs.converts), 0)

    main_candidates = [ref for ref in repo_refs.branches if ref.name == "main"]
    main_ref = main_candidates[0]
    self.assertEqual(main_ref.ref, "refs/heads/main")

    parquet_candidates = [ref for ref in repo_refs.converts if ref.name == "parquet"]
    parquet_ref = parquet_candidates[0]
    self.assertEqual(parquet_ref.ref, "refs/convert/parquet")

    # Can get info by convert revision
    self.api.repo_info(
        "bigcode/evaluation",
        repo_type="dataset",
        revision=parquet_ref.target_commit,
    )


@patch("huggingface_hub.hf_api.build_hf_headers")
class HfApiTokenAttributeTest(unittest.TestCase):
def test_token_passed(self, mock_build_hf_headers: Mock) -> None:
Expand Down