From 33baffe4d29e1a9cf12786aed2f227b43268d720 Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Mon, 20 Nov 2023 14:00:45 +0000 Subject: [PATCH 1/8] =?UTF-8?q?=E2=9C=A8=20Add=20filters=20to=20HfApi.get?= =?UTF-8?q?=5Frepo=5Fdiscussions?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/huggingface_hub/constants.py | 2 ++ src/huggingface_hub/hf_api.py | 35 +++++++++++++++++++++++++++++--- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index 26bf8a410d..22243b7db8 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -79,6 +79,8 @@ def _as_int(value: Optional[str]) -> Optional[int]: "models": REPO_TYPE_MODEL, } +DISCUSSION_TYPES = ["all", "discussion", "pull_request"] +DISCUSSION_STATUS = ["all", "open", "closed"] # default cache default_home = os.path.join(os.path.expanduser("~"), ".cache") diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index b9ef66b4ed..f0e941f576 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -40,7 +40,7 @@ Union, overload, ) -from urllib.parse import quote +from urllib.parse import quote, urlencode import requests from requests.exceptions import HTTPError @@ -93,6 +93,8 @@ from .constants import ( DEFAULT_ETAG_TIMEOUT, DEFAULT_REVISION, + DISCUSSION_STATUS, + DISCUSSION_TYPES, ENDPOINT, INFERENCE_ENDPOINTS_ENDPOINT, REGEX_COMMIT_OID, @@ -3675,7 +3677,7 @@ def create_commits_on_pr( logger.info(f"Multi-commits strategy with ID {strategy.id}.") # 2. 
Get or create a PR with this strategy ID - for discussion in self.get_repo_discussions(repo_id=repo_id, repo_type=repo_type, token=token): + for discussion in self.get_repo_discussions(repo_id=repo_id, repo_type=repo_type, token=token, discussion_type="pull_request"): # search for a draft PR with strategy ID if discussion.is_pull_request and discussion.status == "draft" and strategy.id in discussion.title: pr = self.get_discussion_details( @@ -5191,6 +5193,9 @@ def get_repo_discussions( self, repo_id: str, *, + author: Optional[str] = None, + discussion_type: Optional[str] = None, + discussion_status: Optional[str] = None, repo_type: Optional[str] = None, token: Optional[str] = None, ) -> Iterator[Discussion]: @@ -5201,6 +5206,18 @@ def get_repo_discussions( repo_id (`str`): A namespace (user or an organization) and a repo name separated by a `/`. + author (`str`, *optional*): + Pass a value to filter by discussion author. `None` means no filter. + Default is `None`. + discussion_type (`str`, *optional*): + Set to `"pull_request"` to fetch only pull requests, `"discussion"` + to fetch only discussions. Set to `"all"` or `None` to fetch both. + Default is `None`. + discussion_status (`str`, *optional*): + Set to `"open"` (respectively `"closed"`) to fetch only open + (respectively closed) discussions. Set to `"all"` or `None` + to fetch both. + Default is `None`. repo_type (`str`, *optional*): Set to `"dataset"` or `"space"` if fetching from a dataset or space, `None` or `"model"` if fetching from a model. 
Default is @@ -5231,11 +5248,23 @@ def get_repo_discussions( raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") if repo_type is None: repo_type = REPO_TYPE_MODEL + if discussion_type is not None and discussion_type not in DISCUSSION_TYPES: + raise ValueError(f"Invalid discussion_type, must be one of {DISCUSSION_TYPES}") + if discussion_status is not None and discussion_status not in DISCUSSION_STATUS: + raise ValueError(f"Invalid discussion_status, must be one of {DISCUSSION_STATUS}") headers = self._build_hf_headers(token=token) + query_dict = { } + if discussion_type is not None: + query_dict["type"] = discussion_type + if discussion_status is not None: + query_dict["status"] = discussion_status + if author is not None: + query_dict["author"] = author def _fetch_discussion_page(page_index: int): - path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}" + query_string = urlencode({ **query_dict, "page_index": page_index }) + path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?{query_string}" resp = get_session().get(path, headers=headers) hf_raise_for_status(resp) paginated_discussions = resp.json() From 75f9e73a378ff6a50daff58cab5c5bb104729edd Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Mon, 20 Nov 2023 14:17:32 +0000 Subject: [PATCH 2/8] =?UTF-8?q?=F0=9F=92=84=20make=20style?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/huggingface_hub/hf_api.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index f0e941f576..a25712bdf8 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -3677,7 +3677,9 @@ def create_commits_on_pr( logger.info(f"Multi-commits strategy with ID {strategy.id}.") # 2. 
Get or create a PR with this strategy ID - for discussion in self.get_repo_discussions(repo_id=repo_id, repo_type=repo_type, token=token, discussion_type="pull_request"): + for discussion in self.get_repo_discussions( + repo_id=repo_id, repo_type=repo_type, token=token, discussion_type="pull_request" + ): # search for a draft PR with strategy ID if discussion.is_pull_request and discussion.status == "draft" and strategy.id in discussion.title: pr = self.get_discussion_details( @@ -5254,7 +5256,7 @@ def get_repo_discussions( raise ValueError(f"Invalid discussion_status, must be one of {DISCUSSION_STATUS}") headers = self._build_hf_headers(token=token) - query_dict = { } + query_dict = {} if discussion_type is not None: query_dict["type"] = discussion_type if discussion_status is not None: @@ -5263,7 +5265,7 @@ def get_repo_discussions( query_dict["author"] = author def _fetch_discussion_page(page_index: int): - query_string = urlencode({ **query_dict, "page_index": page_index }) + query_string = urlencode({**query_dict, "page_index": page_index}) path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?{query_string}" resp = get_session().get(path, headers=headers) hf_raise_for_status(resp) From f42b21a183e07ad97e96edc3a02f522a0b13af79 Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Mon, 20 Nov 2023 14:39:24 +0000 Subject: [PATCH 3/8] =?UTF-8?q?=E2=8F=AA=20Revert=20extraneous=20change?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/huggingface_hub/hf_api.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index a25712bdf8..67ec1f8089 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -3677,9 +3677,7 @@ def create_commits_on_pr( logger.info(f"Multi-commits strategy with ID {strategy.id}.") # 2. 
Get or create a PR with this strategy ID - for discussion in self.get_repo_discussions( - repo_id=repo_id, repo_type=repo_type, token=token, discussion_type="pull_request" - ): + for discussion in self.get_repo_discussions(repo_id=repo_id, repo_type=repo_type, token=token): # search for a draft PR with strategy ID if discussion.is_pull_request and discussion.status == "draft" and strategy.id in discussion.title: pr = self.get_discussion_details( From ad37d0d12ed1521f410f400443ef3cb01bd7d6a2 Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Wed, 22 Nov 2023 13:21:29 +0000 Subject: [PATCH 4/8] =?UTF-8?q?=F0=9F=A9=B9=20Mention=20filters=20in=20the?= =?UTF-8?q?=20documentation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/source/en/guides/community.md | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/docs/source/en/guides/community.md b/docs/source/en/guides/community.md index ba6979fe80..306b5d51c4 100644 --- a/docs/source/en/guides/community.md +++ b/docs/source/en/guides/community.md @@ -14,7 +14,7 @@ The `HfApi` class allows you to retrieve Discussions and Pull Requests on a give ```python >>> from huggingface_hub import get_repo_discussions ->>> for discussion in get_repo_discussions(repo_id="bigscience/bloom-1b3"): +>>> for discussion in get_repo_discussions(repo_id="bigscience/bloom"): ... print(f"{discussion.num} - {discussion.title}, pr: {discussion.is_pull_request}") # 11 - Add Flax weights, pr: True @@ -25,6 +25,21 @@ The `HfApi` class allows you to retrieve Discussions and Pull Requests on a give [...] ``` +`HfApi.get_repo_discussions` supports by author, type (Pull Request or Discussion) and status (`open` or `closed`): + +```python +>>> from huggingface_hub import get_repo_discussions +>>> for discussion in get_repo_discussions( +... repo_id="bigscience/bloom", +... author="ArthurZ", +... discussion_type="pull_request", +... discussion_status="open", +... ): +... 
print(f"{discussion.num} - {discussion.title} by {discussion.author}, pr: {discussion.is_pull_request}") + +# 19 - Add Flax weights by ArthurZ, pr: True +``` + `HfApi.get_repo_discussions` returns a [generator](https://docs.python.org/3.7/howto/functional.html#generators) that yields [`Discussion`] objects. To get all the Discussions in a single list, run: From 1431a5f1730832ecfc0482f3e9c13d6a3ff3619b Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Wed, 22 Nov 2023 13:28:50 +0000 Subject: [PATCH 5/8] =?UTF-8?q?=F0=9F=91=8C=20Literal=20type=20annotation?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/huggingface_hub/constants.py | 9 ++++++--- src/huggingface_hub/hf_api.py | 6 ++++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index 22243b7db8..4839cd433b 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -1,6 +1,7 @@ import os import re -from typing import Optional +from typing import Literal, Optional, Tuple +import typing # Possible values for env variables @@ -79,8 +80,10 @@ def _as_int(value: Optional[str]) -> Optional[int]: "models": REPO_TYPE_MODEL, } -DISCUSSION_TYPES = ["all", "discussion", "pull_request"] -DISCUSSION_STATUS = ["all", "open", "closed"] +DiscussionTypeFilter = Literal["all", "discussion", "pull_request"] +DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter) +DiscussionStatusFilter = Literal["all", "open", "closed"] +DISCUSSION_STATUS: Tuple[DiscussionStatusFilter, ...] 
= typing.get_args(DiscussionStatusFilter) # default cache default_home = os.path.join(os.path.expanduser("~"), ".cache") diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 31f7f75430..51f50b9c9b 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -103,6 +103,8 @@ REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES, SPACES_SDK_TYPES, + DiscussionStatusFilter, + DiscussionTypeFilter, ) from .file_download import ( get_hf_file_metadata, @@ -5199,8 +5201,8 @@ def get_repo_discussions( repo_id: str, *, author: Optional[str] = None, - discussion_type: Optional[str] = None, - discussion_status: Optional[str] = None, + discussion_type: Optional[DiscussionTypeFilter] = None, + discussion_status: Optional[DiscussionStatusFilter] = None, repo_type: Optional[str] = None, token: Optional[str] = None, ) -> Iterator[Discussion]: From 0f7e30aef3b0939710e034bf0701c6966249019c Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Wed, 22 Nov 2023 13:35:20 +0000 Subject: [PATCH 6/8] =?UTF-8?q?=E2=9C=85=20Add=20some=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_hf_api.py | 45 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index a7107c93e7..93ef7aef01 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -2427,6 +2427,51 @@ def test_get_repo_discussion(self): list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] ) + def test_get_repo_discussion_by_type(self): + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="pull_request") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.pull_request.num]) + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="discussion") + 
self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num]) + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="all") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]) + + + @unittest.skip("__DUMMY_TRANSFORMERS_USER__ is not a valid username anymore. This test fails the Hub input validation in consequence.") + def test_get_repo_discussion_by_author(self): + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author="no-discussion") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), []) + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author=USER) + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]) + + + def test_get_repo_discussion_by_status(self): + self._api.change_discussion_status(self.repo_id, self.discussion.num, "closed") + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="open") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual( + list([d.num for d in discussions_generator]), [self.pull_request.num] + ) + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="closed") + self.assertIsInstance(discussions_generator, types.GeneratorType) + self.assertListEqual( + list([d.num for d in discussions_generator]), [self.discussion.num] + ) + + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="all") + self.assertIsInstance(discussions_generator, 
types.GeneratorType) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]) + + def test_get_discussion_details(self): retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=2) self.assertEqual(retrieved, self.discussion) From 88d3708c00d69b093782dfe3014a5309e51ef30c Mon Sep 17 00:00:00 2001 From: SBrandeis Date: Wed, 22 Nov 2023 13:41:55 +0000 Subject: [PATCH 7/8] =?UTF-8?q?=F0=9F=92=84=20Code=20quality?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/huggingface_hub/constants.py | 2 +- src/huggingface_hub/hf_api.py | 2 +- tests/test_hf_api.py | 27 ++++++++++++++------------- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/huggingface_hub/constants.py b/src/huggingface_hub/constants.py index 4839cd433b..9a55b1d3f3 100644 --- a/src/huggingface_hub/constants.py +++ b/src/huggingface_hub/constants.py @@ -1,7 +1,7 @@ import os import re -from typing import Literal, Optional, Tuple import typing +from typing import Literal, Optional, Tuple # Possible values for env variables diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 51f50b9c9b..de1cd8f8fc 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -5261,7 +5261,7 @@ def get_repo_discussions( raise ValueError(f"Invalid discussion_status, must be one of {DISCUSSION_STATUS}") headers = self._build_hf_headers(token=token) - query_dict = {} + query_dict: Dict[str, str] = {} if discussion_type is not None: query_dict["type"] = discussion_type if discussion_status is not None: diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 93ef7aef01..9d676df444 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -2438,10 +2438,13 @@ def test_get_repo_discussion_by_type(self): discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="all") 
self.assertIsInstance(discussions_generator, types.GeneratorType) - self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]) - + self.assertListEqual( + list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] + ) - @unittest.skip("__DUMMY_TRANSFORMERS_USER__ is not a valid username anymore. This test fails the Hub input validation in consequence.") + @unittest.skip( + "__DUMMY_TRANSFORMERS_USER__ is not a valid username anymore. This test fails the Hub input validation in consequence." + ) def test_get_repo_discussion_by_author(self): discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author="no-discussion") self.assertIsInstance(discussions_generator, types.GeneratorType) @@ -2449,28 +2452,26 @@ def test_get_repo_discussion_by_author(self): discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author=USER) self.assertIsInstance(discussions_generator, types.GeneratorType) - self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]) - + self.assertListEqual( + list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] + ) def test_get_repo_discussion_by_status(self): self._api.change_discussion_status(self.repo_id, self.discussion.num, "closed") discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="open") self.assertIsInstance(discussions_generator, types.GeneratorType) - self.assertListEqual( - list([d.num for d in discussions_generator]), [self.pull_request.num] - ) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.pull_request.num]) discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="closed") self.assertIsInstance(discussions_generator, types.GeneratorType) - self.assertListEqual( - list([d.num for d in discussions_generator]), 
[self.discussion.num] - ) + self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num]) discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="all") self.assertIsInstance(discussions_generator, types.GeneratorType) - self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]) - + self.assertListEqual( + list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] + ) def test_get_discussion_details(self): retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=2) From c67fd92e67378d4b9ce73363cf4438a358b630c2 Mon Sep 17 00:00:00 2001 From: Lucain Date: Wed, 22 Nov 2023 17:00:48 +0100 Subject: [PATCH 8/8] Apply suggestions from code review --- docs/source/en/guides/community.md | 2 +- tests/test_hf_api.py | 11 +---------- 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/docs/source/en/guides/community.md b/docs/source/en/guides/community.md index 306b5d51c4..0baa34e4e5 100644 --- a/docs/source/en/guides/community.md +++ b/docs/source/en/guides/community.md @@ -25,7 +25,7 @@ The `HfApi` class allows you to retrieve Discussions and Pull Requests on a give [...] ``` -`HfApi.get_repo_discussions` supports by author, type (Pull Request or Discussion) and status (`open` or `closed`): +`HfApi.get_repo_discussions` supports filtering by author, type (Pull Request or Discussion) and status (`open` or `closed`): ```python >>> from huggingface_hub import get_repo_discussions diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 9d676df444..0d4fdfedf8 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -2442,20 +2442,11 @@ def test_get_repo_discussion_by_type(self): list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] ) - @unittest.skip( - "__DUMMY_TRANSFORMERS_USER__ is not a valid username anymore. 
This test fails the Hub input validation in consequence." - ) def test_get_repo_discussion_by_author(self): - discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author="no-discussion") + discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author="unknown") self.assertIsInstance(discussions_generator, types.GeneratorType) self.assertListEqual(list([d.num for d in discussions_generator]), []) - discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author=USER) - self.assertIsInstance(discussions_generator, types.GeneratorType) - self.assertListEqual( - list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num] - ) - def test_get_repo_discussion_by_status(self): self._api.change_discussion_status(self.repo_id, self.discussion.num, "closed")