diff --git a/docs/source/en/guides/community.md b/docs/source/en/guides/community.md index ba6979fe80..0baa34e4e5 100644 --- a/docs/source/en/guides/community.md +++ b/docs/source/en/guides/community.md @@ -14,7 +14,7 @@ The `HfApi` class allows you to retrieve Discussions and Pull Requests on a give ```python >>> from huggingface_hub import get_repo_discussions ->>> for discussion in get_repo_discussions(repo_id="bigscience/bloom-1b3"): +>>> for discussion in get_repo_discussions(repo_id="bigscience/bloom"): ... print(f"{discussion.num} - {discussion.title}, pr: {discussion.is_pull_request}") # 11 - Add Flax weights, pr: True @@ -25,6 +25,21 @@ The `HfApi` class allows you to retrieve Discussions and Pull Requests on a give [...] ``` +`HfApi.get_repo_discussions` supports filtering by author, type (Pull Request or Discussion) and status (`open` or `closed`): + +```python +>>> from huggingface_hub import get_repo_discussions +>>> for discussion in get_repo_discussions( +... repo_id="bigscience/bloom", +... author="ArthurZ", +... discussion_type="pull_request", +... discussion_status="open", +... ): +... print(f"{discussion.num} - {discussion.title} by {discussion.author}, pr: {discussion.is_pull_request}") + +# 19 - Add Flax weights by ArthurZ, pr: True +``` + `HfApi.get_repo_discussions` returns a [generator](https://docs.python.org/3.7/howto/functional.html#generators) that yields [`Discussion`] objects. To get all the Discussions in a single list, run: diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index 26bf8a410d..9a55b1d3f3 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -1,6 +1,7 @@ import os import re -from typing import Optional +import typing +from typing import Literal, Optional, Tuple # Possible values for env variables @@ -79,6 +80,10 @@ def _as_int(value: Optional[str]) -> Optional[int]: "models": REPO_TYPE_MODEL, } +DiscussionTypeFilter = Literal["all", "discussion", "pull_request"] +DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter) +DiscussionStatusFilter = Literal["all", "open", "closed"] +DISCUSSION_STATUS: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionStatusFilter) # default cache default_home = os.path.join(os.path.expanduser("~"), ".cache") diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 0a42de34c4..de1cd8f8fc 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -40,7 +40,7 @@ Union, overload, ) -from urllib.parse import quote +from urllib.parse import quote, urlencode import requests from requests.exceptions import HTTPError @@ -93,6 +93,8 @@ from .constants import ( DEFAULT_ETAG_TIMEOUT, DEFAULT_REVISION, + DISCUSSION_STATUS, + DISCUSSION_TYPES, ENDPOINT, INFERENCE_ENDPOINTS_ENDPOINT, REGEX_COMMIT_OID, @@ -101,6 +103,8 @@ REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES, SPACES_SDK_TYPES, + DiscussionStatusFilter, + DiscussionTypeFilter, ) from .file_download import ( get_hf_file_metadata, @@ -5196,6 +5200,9 @@ def get_repo_discussions( self, repo_id: str, *, + author: Optional[str] = None, + discussion_type: Optional[DiscussionTypeFilter] = None, + discussion_status: Optional[DiscussionStatusFilter] = None, repo_type: Optional[str] = None, token: Optional[str] = None, ) -> Iterator[Discussion]: @@ -5206,6 +5213,18 @@ def get_repo_discussions( repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. + author (`str`, *optional*): + Pass a value to filter by discussion author. `None` means no filter. + Default is `None`. + discussion_type (`str`, *optional*): + Set to `"pull_request"` to fetch only pull requests, `"discussion"` + to fetch only discussions. Set to `"all"` or `None` to fetch both. + Default is `None`. + discussion_status (`str`, *optional*): + Set to `"open"` (respectively `"closed"`) to fetch only open + (respectively closed) discussions. Set to `"all"` or `None` + to fetch both. + Default is `None`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if fetching from a dataset or space, `None` or `"model"` if fetching from a model. Default is @@ -5236,11 +5255,23 @@ def get_repo_discussions( raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") if repo_type is None: repo_type = REPO_TYPE_MODEL + if discussion_type is not None and discussion_type not in DISCUSSION_TYPES: + raise ValueError(f"Invalid discussion_type, must be one of {DISCUSSION_TYPES}") + if discussion_status is not None and discussion_status not in DISCUSSION_STATUS: + raise ValueError(f"Invalid discussion_status, must be one of {DISCUSSION_STATUS}") headers = self._build_hf_headers(token=token) + query_dict: Dict[str, str] = {} + if discussion_type is not None: + query_dict["type"] = discussion_type + if discussion_status is not None: + query_dict["status"] = discussion_status + if author is not None: + query_dict["author"] = author def _fetch_discussion_page(page_index: int): - path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}" + query_string = urlencode({**query_dict, "page_index": page_index}) + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?{query_string}" resp = get_session().get(path, headers=headers) hf_raise_for_status(resp) paginated_discussions = resp.json() diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index a7107c93e7..0d4fdfedf8 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -2427,6 +2427,43 @@ def test_get_repo_discussion(self): list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] ) + def test_get_repo_discussion_by_type(self): + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="pull_request") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.pull_request.num]) + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="discussion") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num]) + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="all") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual( + list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] + ) + + def test_get_repo_discussion_by_author(self): + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author="unknown") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), []) + + def test_get_repo_discussion_by_status(self): + self._api.change_discussion_status(self.repo_id, self.discussion.num, "closed") + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="open") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.pull_request.num]) + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="closed") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num]) + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="all") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual( + list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] + ) + def test_get_discussion_details(self): retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=2) self.assertEqual(retrieved, self.discussion)