Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add filters to HfApi.get_repo_discussions #1845

Merged
merged 9 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion docs/source/en/guides/community.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ The `HfApi` class allows you to retrieve Discussions and Pull Requests on a give

```python
>>> from huggingface_hub import get_repo_discussions
>>> for discussion in get_repo_discussions(repo_id="bigscience/bloom-1b3"):
>>> for discussion in get_repo_discussions(repo_id="bigscience/bloom"):
... print(f"{discussion.num} - {discussion.title}, pr: {discussion.is_pull_request}")

# 11 - Add Flax weights, pr: True
Expand All @@ -25,6 +25,21 @@ The `HfApi` class allows you to retrieve Discussions and Pull Requests on a give
[...]
```

`HfApi.get_repo_discussions` supports by author, type (Pull Request or Discussion) and status (`open` or `closed`):
Wauplin marked this conversation as resolved.
Show resolved Hide resolved

```python
>>> from huggingface_hub import get_repo_discussions
>>> for discussion in get_repo_discussions(
... repo_id="bigscience/bloom",
... author="ArthurZ",
... discussion_type="pull_request",
... discussion_status="open",
... ):
... print(f"{discussion.num} - {discussion.title} by {discussion.author}, pr: {discussion.is_pull_request}")

# 19 - Add Flax weights by ArthurZ, pr: True
```

`HfApi.get_repo_discussions` returns a [generator](https://docs.python.org/3.7/howto/functional.html#generators) that yields
[`Discussion`] objects. To get all the Discussions in a single list, run:

Expand Down
7 changes: 6 additions & 1 deletion src/huggingface_hub/constants.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import re
from typing import Optional
import typing
from typing import Literal, Optional, Tuple


# Possible values for env variables
Expand Down Expand Up @@ -79,6 +80,10 @@ def _as_int(value: Optional[str]) -> Optional[int]:
"models": REPO_TYPE_MODEL,
}

DiscussionTypeFilter = Literal["all", "discussion", "pull_request"]
DISCUSSION_TYPES: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionTypeFilter)
DiscussionStatusFilter = Literal["all", "open", "closed"]
DISCUSSION_STATUS: Tuple[DiscussionTypeFilter, ...] = typing.get_args(DiscussionStatusFilter)
Comment on lines +83 to +86
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🤯


# default cache
default_home = os.path.join(os.path.expanduser("~"), ".cache")
Expand Down
35 changes: 33 additions & 2 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
Union,
overload,
)
from urllib.parse import quote
from urllib.parse import quote, urlencode

import requests
from requests.exceptions import HTTPError
Expand Down Expand Up @@ -93,6 +93,8 @@
from .constants import (
DEFAULT_ETAG_TIMEOUT,
DEFAULT_REVISION,
DISCUSSION_STATUS,
DISCUSSION_TYPES,
ENDPOINT,
INFERENCE_ENDPOINTS_ENDPOINT,
REGEX_COMMIT_OID,
Expand All @@ -101,6 +103,8 @@
REPO_TYPES_MAPPING,
REPO_TYPES_URL_PREFIXES,
SPACES_SDK_TYPES,
DiscussionStatusFilter,
DiscussionTypeFilter,
)
from .file_download import (
get_hf_file_metadata,
Expand Down Expand Up @@ -5196,6 +5200,9 @@ def get_repo_discussions(
self,
repo_id: str,
*,
author: Optional[str] = None,
discussion_type: Optional[DiscussionTypeFilter] = None,
discussion_status: Optional[DiscussionStatusFilter] = None,
repo_type: Optional[str] = None,
token: Optional[str] = None,
) -> Iterator[Discussion]:
Expand All @@ -5206,6 +5213,18 @@ def get_repo_discussions(
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
author (`str`, *optional*):
Pass a value to filter by discussion author. `None` means no filter.
Default is `None`.
discussion_type (`str`, *optional*):
Set to `"pull_request"` to fetch only pull requests, `"discussion"`
to fetch only discussions. Set to `"all"` or `None` to fetch both.
Default is `None`.
discussion_status (`str`, *optional*):
Set to `"open"` (respectively `"closed"`) to fetch only open
(respectively closed) discussions. Set to `"all"` or `None`
to fetch both.
Default is `None`.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if fetching from a dataset or
space, `None` or `"model"` if fetching from a model. Default is
Expand Down Expand Up @@ -5236,11 +5255,23 @@ def get_repo_discussions(
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
if repo_type is None:
repo_type = REPO_TYPE_MODEL
if discussion_type is not None and discussion_type not in DISCUSSION_TYPES:
raise ValueError(f"Invalid discussion_type, must be one of {DISCUSSION_TYPES}")
if discussion_status is not None and discussion_status not in DISCUSSION_STATUS:
raise ValueError(f"Invalid discussion_status, must be one of {DISCUSSION_STATUS}")

headers = self._build_hf_headers(token=token)
query_dict: Dict[str, str] = {}
if discussion_type is not None:
query_dict["type"] = discussion_type
if discussion_status is not None:
query_dict["status"] = discussion_status
if author is not None:
query_dict["author"] = author

def _fetch_discussion_page(page_index: int):
path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}"
query_string = urlencode({**query_dict, "page_index": page_index})
path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?{query_string}"
resp = get_session().get(path, headers=headers)
hf_raise_for_status(resp)
paginated_discussions = resp.json()
Expand Down
46 changes: 46 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2427,6 +2427,52 @@ def test_get_repo_discussion(self):
list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]
)

def test_get_repo_discussion_by_type(self):
discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="pull_request")
self.assertIsInstance(discussions_generator, types.GeneratorType)
self.assertListEqual(list([d.num for d in discussions_generator]), [self.pull_request.num])

discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="discussion")
self.assertIsInstance(discussions_generator, types.GeneratorType)
self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num])

discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_type="all")
self.assertIsInstance(discussions_generator, types.GeneratorType)
self.assertListEqual(
list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]
)

@unittest.skip(
"__DUMMY_TRANSFORMERS_USER__ is not a valid username anymore. This test fails the Hub input validation in consequence."
)
def test_get_repo_discussion_by_author(self):
discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author="no-discussion")
self.assertIsInstance(discussions_generator, types.GeneratorType)
self.assertListEqual(list([d.num for d in discussions_generator]), [])

discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, author=USER)
self.assertIsInstance(discussions_generator, types.GeneratorType)
self.assertListEqual(
list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]
)
Wauplin marked this conversation as resolved.
Show resolved Hide resolved

def test_get_repo_discussion_by_status(self):
self._api.change_discussion_status(self.repo_id, self.discussion.num, "closed")

discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="open")
self.assertIsInstance(discussions_generator, types.GeneratorType)
self.assertListEqual(list([d.num for d in discussions_generator]), [self.pull_request.num])

discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="closed")
self.assertIsInstance(discussions_generator, types.GeneratorType)
self.assertListEqual(list([d.num for d in discussions_generator]), [self.discussion.num])

discussions_generator = self._api.get_repo_discussions(repo_id=self.repo_id, discussion_status="all")
self.assertIsInstance(discussions_generator, types.GeneratorType)
self.assertListEqual(
list([d.num for d in discussions_generator]), [self.discussion.num, self.pull_request.num]
)

def test_get_discussion_details(self):
retrieved = self._api.get_discussion_details(repo_id=self.repo_id, discussion_num=2)
self.assertEqual(retrieved, self.discussion)
Expand Down
Loading