Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add filters to HfApi.get_repo_discussions #1845

Merged
merged 9 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/huggingface_hub/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ def _as_int(value: Optional[str]) -> Optional[int]:
"models": REPO_TYPE_MODEL,
}

# Values accepted by the `discussion_type` filter of `HfApi.get_repo_discussions`.
# "all" selects both discussions and pull requests.
DISCUSSION_TYPES = [
    "all",
    "discussion",
    "pull_request",
]
# Values accepted by the `discussion_status` filter of `HfApi.get_repo_discussions`.
# "all" selects both open and closed discussions.
DISCUSSION_STATUS = [
    "all",
    "open",
    "closed",
]

# default cache
default_home = os.path.join(os.path.expanduser("~"), ".cache")
Expand Down
33 changes: 31 additions & 2 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
Union,
overload,
)
from urllib.parse import quote
from urllib.parse import quote, urlencode

import requests
from requests.exceptions import HTTPError
Expand Down Expand Up @@ -93,6 +93,8 @@
from .constants import (
DEFAULT_ETAG_TIMEOUT,
DEFAULT_REVISION,
DISCUSSION_STATUS,
DISCUSSION_TYPES,
ENDPOINT,
INFERENCE_ENDPOINTS_ENDPOINT,
REGEX_COMMIT_OID,
Expand Down Expand Up @@ -5191,6 +5193,9 @@ def get_repo_discussions(
self,
repo_id: str,
*,
author: Optional[str] = None,
discussion_type: Optional[str] = None,
discussion_status: Optional[str] = None,
SBrandeis marked this conversation as resolved.
Show resolved Hide resolved
repo_type: Optional[str] = None,
token: Optional[str] = None,
) -> Iterator[Discussion]:
Expand All @@ -5201,6 +5206,18 @@ def get_repo_discussions(
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
author (`str`, *optional*):
Pass a value to filter by discussion author. `None` means no filter.
Default is `None`.
discussion_type (`str`, *optional*):
Set to `"pull_request"` to fetch only pull requests, `"discussion"`
to fetch only discussions. Set to `"all"` or `None` to fetch both.
Default is `None`.
discussion_status (`str`, *optional*):
Set to `"open"` (respectively `"closed"`) to fetch only open
(respectively closed) discussions. Set to `"all"` or `None`
to fetch both.
Default is `None`.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if fetching from a dataset or
space, `None` or `"model"` if fetching from a model. Default is
Expand Down Expand Up @@ -5231,11 +5248,23 @@ def get_repo_discussions(
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
if repo_type is None:
repo_type = REPO_TYPE_MODEL
if discussion_type is not None and discussion_type not in DISCUSSION_TYPES:
raise ValueError(f"Invalid discussion_type, must be one of {DISCUSSION_TYPES}")
if discussion_status is not None and discussion_status not in DISCUSSION_STATUS:
raise ValueError(f"Invalid discussion_status, must be one of {DISCUSSION_STATUS}")

headers = self._build_hf_headers(token=token)
query_dict = {}
if discussion_type is not None:
query_dict["type"] = discussion_type
if discussion_status is not None:
query_dict["status"] = discussion_status
if author is not None:
query_dict["author"] = author

def _fetch_discussion_page(page_index: int):
path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?p={page_index}"
query_string = urlencode({**query_dict, "page_index": page_index})
path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions?{query_string}"
resp = get_session().get(path, headers=headers)
hf_raise_for_status(resp)
paginated_discussions = resp.json()
Expand Down
Loading