Skip to content

Commit

Permalink
✅ Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
SBrandeis committed Jun 27, 2022
1 parent 9e7d8a5 commit 4979ff0
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 13 deletions.
33 changes: 20 additions & 13 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@


REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$")
REGEX_HEXADECIMAL = re.compile(r"[a-fA-F0-9]")
USERNAME_PLACEHOLDER = "hf_user"

logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -1138,7 +1137,9 @@ def model_info(
path = (
f"{self.endpoint}/api/models/{repo_id}"
if revision is None
else f"{self.endpoint}/api/models/{repo_id}/revision/{revision}"
else (
f"{self.endpoint}/api/models/{repo_id}/revision/{quote(revision, safe='')}"
)
)
headers = {"authorization": f"Bearer {token}"} if token is not None else None
status_query_param = {"securityStatus": True} if securityStatus else None
Expand Down Expand Up @@ -1195,7 +1196,9 @@ def dataset_info(
path = (
f"{self.endpoint}/api/datasets/{repo_id}"
if revision is None
else f"{self.endpoint}/api/datasets/{repo_id}/revision/{revision}"
else (
f"{self.endpoint}/api/datasets/{repo_id}/revision/{quote(revision, safe='')}"
)
)
headers = {"authorization": f"Bearer {token}"} if token is not None else None
r = requests.get(path, headers=headers, timeout=timeout)
Expand Down Expand Up @@ -1249,7 +1252,9 @@ def space_info(
path = (
f"{self.endpoint}/api/spaces/{repo_id}"
if revision is None
else f"{self.endpoint}/api/spaces/{repo_id}/revision/{revision}"
else (
f"{self.endpoint}/api/spaces/{repo_id}/revision/{quote(revision, safe='')}"
)
)
headers = {"authorization": f"Bearer {token}"} if token is not None else None
r = requests.get(path, headers=headers, timeout=timeout)
Expand Down Expand Up @@ -1802,7 +1807,9 @@ def create_commit(
if repo_type not in REPO_TYPES:
raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}")
token, name = self._validate_or_retrieve_token(token)
revision = revision if revision is not None else DEFAULT_REVISION
revision = (
quote(revision, safe="") if revision is not None else DEFAULT_REVISION
)
create_pr = create_pr if create_pr is not None else False

operations = list(operations)
Expand Down Expand Up @@ -2501,6 +2508,7 @@ def create_discussion(
"description": description,
"pullRequest": pull_request,
},
headers={"Authorization": f"Bearer {token}"},
)
_raise_for_status(resp)
num = resp.json()["num"]
Expand Down Expand Up @@ -2749,13 +2757,16 @@ def change_discussion_status(
"""
if new_status not in ["open", "closed"]:
raise ValueError("Invalid status, valid statuses are: 'open' and 'closed'")
body = {"status": new_status}
if comment and comment.strip():
body["comment"] = comment.strip()
resp = self._post_discussion_changes(
repo_id=repo_id,
repo_type=repo_type,
discussion_num=discussion_num,
token=token,
resource="status",
body={"status": new_status, "comment": comment},
body=body,
)
return deserialize_event(resp.json()["newStatus"])

Expand Down Expand Up @@ -2808,7 +2819,7 @@ def merge_pull_request(
discussion_num=discussion_num,
token=token,
resource="merge",
body={"comment": comment},
body={"comment": comment.strip()} if comment and comment.strip() else None,
)

def edit_discussion_comment(
Expand All @@ -2830,7 +2841,7 @@ def edit_discussion_comment(
discussion_num (`int`):
The number of the discussion or pull request. Must be a strictly positive integer.
comment_id (`str`):
The ID of the comment to edit. ID is an hexadecimal string.
The ID of the comment to edit.
new_content (`str`):
The new content of the comment. Comments support markdown formatting.
repo_type (`str`, *optional*):
Expand All @@ -2857,8 +2868,6 @@ def edit_discussion_comment(
</Tip>
"""
if not REGEX_HEXADECIMAL.fullmatch(comment_id):
raise ValueError("Invalid comment_id: must be an hexadecimal string")
resp = self._post_discussion_changes(
repo_id=repo_id,
repo_type=repo_type,
Expand Down Expand Up @@ -2889,7 +2898,7 @@ def hide_discussion_comment(
discussion_num (`int`):
The number of the discussion or pull request. Must be a strictly positive integer.
comment_id (`str`):
The ID of the comment to edit. ID is an hexadecimal string.
The ID of the comment to edit.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if uploading to a dataset or
space, `None` or `"model"` if uploading to a model. Default is
Expand All @@ -2914,8 +2923,6 @@ def hide_discussion_comment(
</Tip>
"""
if not REGEX_HEXADECIMAL.fullmatch(comment_id):
raise ValueError("Invalid comment_id: must be an hexadecimal string")
resp = self._post_discussion_changes(
repo_id=repo_id,
repo_type=repo_type,
Expand Down
174 changes: 174 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import requests
from huggingface_hub._commit_api import CommitOperationAdd, CommitOperationDelete
from huggingface_hub.commands.user import _login
from huggingface_hub.community import DiscussionComment, DiscussionWithDetails
from huggingface_hub.constants import (
REPO_TYPE_DATASET,
REPO_TYPE_MODEL,
Expand All @@ -55,6 +56,7 @@
ModelFilter,
_filter_emissions,
)
from huggingface_hub.utils.pagination import Pagination
from requests.exceptions import HTTPError

from .testing_constants import (
Expand Down Expand Up @@ -1452,3 +1454,175 @@ def test_repo_type_and_id_from_hf_id(self):
repo_type_and_id_from_hf_id(key, hub_url="https://huggingface.co"),
tuple(value),
)


class HfApiDiscussionsTest(HfApiCommonTestWithLogin):
def setUp(self):
super().setUp()
self.repo_name = f"{USER}/{repo_name()}"
self._api.create_repo(repo_id=self.repo_name, token=self._token)
self.pull_request = self._api.create_discussion(
repo_id=self.repo_name,
pull_request=True,
title="Test Pull Request",
token=self._token,
)
self.discussion = self._api.create_discussion(
repo_id=self.repo_name,
pull_request=False,
title="Test Discussion",
token=self._token,
)

def tearDown(self):
self._api.delete_repo(repo_id=self.repo_name, token=self._token)
super().tearDown()

def test_create_discussion(self):
discussion = self._api.create_discussion(
repo_id=self.repo_name,
title=" Test discussion ! ",
token=self._token,
)
self.assertEqual(discussion.num, 3)
self.assertEqual(discussion.author, USER)
self.assertEqual(discussion.is_pull_request, False)
self.assertEqual(discussion.title, "Test discussion !")

def test_create_pull_request(self):
discussion = self._api.create_discussion(
repo_id=self.repo_name,
title=" Test PR ! ",
token=self._token,
pull_request=True,
)
self.assertEqual(discussion.num, 3)
self.assertEqual(discussion.author, USER)
self.assertEqual(discussion.is_pull_request, True)
self.assertEqual(discussion.title, "Test PR !")

model_info = self._api.repo_info(
repo_id=self.repo_name,
revision="refs/pr/1",
)
self.assertIsInstance(model_info, ModelInfo)

def test_get_repo_discussion(self):
paginated = self._api.get_repo_discussions(repo_id=self.repo_name)
self.assertIsInstance(paginated, Pagination)
self.assertListEqual(
[d.num for d in paginated.page],
[self.discussion.num, self.pull_request.num],
)
self.assertEqual(paginated.has_next, False)
self.assertEqual(paginated.next_page, None)
self.assertEqual(paginated.total, 2)
self.assertEqual(paginated.page_num, 0)

def test_get_discussion_details(self):
retrieved = self._api.get_discussion_details(
repo_id=self.repo_name, discussion_num=2
)
self.assertEqual(retrieved, self.discussion)

def test_edit_discussion_comment(self):
def get_first_comment(discussion: DiscussionWithDetails) -> DiscussionComment:
return [evt for evt in discussion.events if evt.type == "comment"][0]

edited_comment = self._api.edit_discussion_comment(
repo_id=self.repo_name,
discussion_num=self.pull_request.num,
comment_id=get_first_comment(self.pull_request).id,
new_content="**Edited** comment 🤗",
token=self._token,
)
retrieved = self._api.get_discussion_details(
repo_id=self.repo_name, discussion_num=self.pull_request.num
)
self.assertEqual(get_first_comment(retrieved).edited, True)
self.assertEqual(
get_first_comment(retrieved).id, get_first_comment(self.pull_request).id
)
self.assertEqual(get_first_comment(retrieved).content, "**Edited** comment 🤗")

self.assertEqual(get_first_comment(retrieved), edited_comment)

def test_comment_discussion(self):
new_comment = self._api.comment_discussion(
repo_id=self.repo_name,
discussion_num=self.discussion.num,
comment="""\
# Multi-line comment
**With formatting**, including *italic text* & ~strike through~
And even [links](http://hf.co)! 💥🤯
""",
token=self._token,
)
retrieved = self._api.get_discussion_details(
repo_id=self.repo_name, discussion_num=self.discussion.num
)
self.assertEqual(len(retrieved.events), 2)
self.assertIn(new_comment, retrieved.events)

def test_rename_discussion(self):
rename_event = self._api.rename_discussion(
repo_id=self.repo_name,
discussion_num=self.discussion.num,
new_title="New titlee",
token=self._token,
)
retrieved = self._api.get_discussion_details(
repo_id=self.repo_name, discussion_num=self.discussion.num
)
self.assertIn(rename_event, retrieved.events)
self.assertEqual(rename_event.old_title, self.discussion.title)
self.assertEqual(rename_event.new_title, "New titlee")

def test_change_discussion_status(self):
status_change_event = self._api.change_discussion_status(
repo_id=self.repo_name,
discussion_num=self.discussion.num,
new_status="closed",
token=self._token,
)
retrieved = self._api.get_discussion_details(
repo_id=self.repo_name, discussion_num=self.discussion.num
)
self.assertIn(status_change_event, retrieved.events)
self.assertEqual(status_change_event.new_status, "closed")

with self.assertRaises(ValueError):
self._api.change_discussion_status(
repo_id=self.repo_name,
discussion_num=self.discussion.num,
new_status="published",
token=self._token,
)

# @unittest.skip("To unskip when create_commit works for arbitrary references")
def test_merge_pull_request(self):
self._api.create_commit(
repo_id=self.repo_name,
commit_message="Commit some file",
operations=[
CommitOperationAdd(path_in_repo="file.test", path_or_fileobj=b"Content")
],
revision=self.pull_request.git_reference,
token=self._token,
)
self._api.change_discussion_status(
repo_id=self.repo_name,
discussion_num=self.pull_request.num,
new_status="open",
token=self._token,
)
self._api.merge_pull_request(
self.repo_name, self.pull_request.num, token=self._token
)

retrieved = self._api.get_discussion_details(
repo_id=self.repo_name, discussion_num=self.pull_request.num
)
self.assertEqual(retrieved.status, "merged")
self.assertIsNotNone(retrieved.merge_commit_oid)

0 comments on commit 4979ff0

Please sign in to comment.