Skip to content

Commit

Permalink
List branches and tags from a repo (#1276)
Browse files Browse the repository at this point in the history
* List branches and tags from a repo

* lazy import RefIfno and RepoRefs

* wrong type

* docs

* requested changes

* code style
  • Loading branch information
Wauplin authored Dec 21, 2022
1 parent c580233 commit 3c422cf
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/source/package_reference/hf_api.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ hf_api = HfApi(

[[autodoc]] huggingface_hub.hf_api.RepoFile

### GitRefs

[[autodoc]] huggingface_hub.hf_api.GitRefs

### GitRefInfo

[[autodoc]] huggingface_hub.hf_api.GitRefInfo

### CommitInfo

[[autodoc]] huggingface_hub.hf_api.CommitInfo
Expand Down
6 changes: 6 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@
"CommitOperationAdd",
"CommitOperationDelete",
"DatasetSearchArguments",
"GitRefInfo",
"GitRefs",
"HfApi",
"ModelSearchArguments",
"RepoUrl",
Expand Down Expand Up @@ -139,6 +141,7 @@
"list_metrics",
"list_models",
"list_repo_files",
"list_repo_refs",
"list_spaces",
"merge_pull_request",
"model_info",
Expand Down Expand Up @@ -344,6 +347,8 @@ def __dir__():
from .hf_api import CommitOperationAdd # noqa: F401
from .hf_api import CommitOperationDelete # noqa: F401
from .hf_api import DatasetSearchArguments # noqa: F401
from .hf_api import GitRefInfo # noqa: F401
from .hf_api import GitRefs # noqa: F401
from .hf_api import HfApi # noqa: F401
from .hf_api import ModelSearchArguments # noqa: F401
from .hf_api import RepoUrl # noqa: F401
Expand Down Expand Up @@ -377,6 +382,7 @@ def __dir__():
from .hf_api import list_metrics # noqa: F401
from .hf_api import list_models # noqa: F401
from .hf_api import list_repo_files # noqa: F401
from .hf_api import list_repo_refs # noqa: F401
from .hf_api import list_spaces # noqa: F401
from .hf_api import merge_pull_request # noqa: F401
from .hf_api import model_info # noqa: F401
Expand Down
108 changes: 108 additions & 0 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,52 @@ def clean(s: str):
self["author"] = author_dict


@dataclass
class GitRefInfo:
"""
Contains information about a git reference for a repo on the Hub.
Args:
name (`str`):
Name of the reference (e.g. tag name or branch name).
ref (`str`):
Full git ref on the Hub (e.g. `"refs/heads/main"` or `"refs/tags/v1.0"`).
target_commit (`str`):
OID of the target commit for the ref (e.g. `"e7da7f221d5bf496a48136c0cd264e630fe9fcc8"`)
"""

name: str
ref: str
target_commit: str

def __init__(self, data: Dict) -> None:
self.name = data["name"]
self.ref = data["ref"]
self.target_commit = data["targetCommit"]


@dataclass
class GitRefs:
"""
Contains information about all git references for a repo on the Hub.
Object is returned by [`list_repo_refs`].
Args:
branches (`List[GitRefInfo]`):
A list of [`GitRefInfo`] containing information about branches on the repo.
converts (`List[GitRefInfo]`):
A list of [`GitRefInfo`] containing information about "convert" refs on the repo.
Converts are refs used (internally) to push preprocessed data in Dataset repos.
tags (`List[GitRefInfo]`):
A list of [`GitRefInfo`] containing information about tags on the repo.
"""

branches: List[GitRefInfo]
converts: List[GitRefInfo]
tags: List[GitRefInfo]


@dataclass
class UserLikes:
"""
Expand Down Expand Up @@ -1772,6 +1818,67 @@ def list_repo_files(
)
return [f.rfilename for f in repo_info.siblings]

@validate_hf_hub_args
def list_repo_refs(
self,
repo_id: str,
*,
repo_type: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
) -> GitRefs:
"""
Get the list of refs of a given repo (both tags and branches).
Args:
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if listing refs from a dataset or a Space,
`None` or `"model"` if listing from a model. Default is `None`.
token (`bool` or `str`, *optional*):
A valid authentication token (see https://huggingface.co/settings/token).
If `None` or `True` and machine is logged in (through `huggingface-cli login`
or [`~huggingface_hub.login`]), token will be retrieved from the cache.
If `False`, token is not sent in the request header.
Example:
```py
>>> from huggingface_hub import HfApi
>>> api = HfApi()
>>> api.list_repo_refs("gpt2")
GitRefs(branches=[GitRefInfo(name='main', ref='refs/heads/main', target_commit='e7da7f221d5bf496a48136c0cd264e630fe9fcc8')], converts=[], tags=[])
>>> api.list_repo_refs("bigcode/the-stack", repo_type='dataset')
GitRefs(
branches=[
GitRefInfo(name='main', ref='refs/heads/main', target_commit='18edc1591d9ce72aa82f56c4431b3c969b210ae3'),
GitRefInfo(name='v1.1.a1', ref='refs/heads/v1.1.a1', target_commit='f9826b862d1567f3822d3d25649b0d6d22ace714')
],
converts=[],
tags=[
GitRefInfo(name='v1.0', ref='refs/tags/v1.0', target_commit='c37a8cd1e382064d8aced5e05543c5f7753834da')
]
)
```
Returns:
[`GitRefs`]: object containing all information about branches and tags for a
repo on the Hub.
"""
repo_type = repo_type or REPO_TYPE_MODEL
response = requests.get(
f"{self.endpoint}/api/{repo_type}s/{repo_id}/refs",
headers=self._build_hf_headers(token=token),
)
hf_raise_for_status(response)
data = response.json()
return GitRefs(
branches=[GitRefInfo(item) for item in data["branches"]],
converts=[GitRefInfo(item) for item in data["converts"]],
tags=[GitRefInfo(item) for item in data["tags"]],
)

@validate_hf_hub_args
def create_repo(
self,
Expand Down Expand Up @@ -3915,6 +4022,7 @@ def _parse_revision_from_pr_url(pr_url: str) -> str:

repo_info = api.repo_info
list_repo_files = api.list_repo_files
list_repo_refs = api.list_repo_refs

list_metrics = api.list_metrics

Expand Down
35 changes: 35 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2449,6 +2449,41 @@ def test_request_space_hardware(self, post_mock: Mock) -> None:
)


class ListGitRefsTest(unittest.TestCase):
@classmethod
@with_production_testing
def setUpClass(cls) -> None:
cls.api = HfApi()
return super().setUpClass()

def test_list_refs_gpt2(self) -> None:
refs = self.api.list_repo_refs("gpt2")
self.assertGreater(len(refs.branches), 0)
main_branch = [branch for branch in refs.branches if branch.name == "main"][0]
self.assertEqual(main_branch.ref, "refs/heads/main")
# Can get info by revision
self.api.repo_info("gpt2", revision=main_branch.target_commit)

def test_list_refs_bigcode(self) -> None:
refs = self.api.list_repo_refs("bigcode/evaluation", repo_type="dataset")
self.assertGreater(len(refs.branches), 0)
self.assertGreater(len(refs.converts), 0)
main_branch = [branch for branch in refs.branches if branch.name == "main"][0]
self.assertEqual(main_branch.ref, "refs/heads/main")

convert_branch = [
branch for branch in refs.converts if branch.name == "parquet"
][0]
self.assertEqual(convert_branch.ref, "refs/convert/parquet")

# Can get info by convert revision
self.api.repo_info(
"bigcode/evaluation",
repo_type="dataset",
revision=convert_branch.target_commit,
)


@patch("huggingface_hub.hf_api.build_hf_headers")
class HfApiTokenAttributeTest(unittest.TestCase):
def test_token_passed(self, mock_build_hf_headers: Mock) -> None:
Expand Down

0 comments on commit 3c422cf

Please sign in to comment.