Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: Celina Hanouti <hanouticelina@gmail.com>
  • Loading branch information
hlky and hanouticelina committed Oct 3, 2024
1 parent 87bca1c commit afadc22
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 37 deletions.
85 changes: 49 additions & 36 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1463,7 +1463,7 @@ class PaperInfo:
Contains information about a paper on the Hub.
Attributes:
paper_id (`str`):
id (`str`):
arXiv paper ID.
authors (`List[str]`, **optional**):
Names of paper authors
Expand All @@ -1487,7 +1487,7 @@ class PaperInfo:
Information about who submitted the daily paper.
"""

paper_id: str
id: str
authors: Optional[List[str]]
published_at: Optional[datetime]
title: Optional[str]
Expand All @@ -1501,38 +1501,25 @@ class PaperInfo:

def __init__(self, **kwargs) -> None:
paper = kwargs.pop("paper", {})
self.paper_id = kwargs.pop("id", None) or paper.pop("id", "")
self.authors = [author.pop("name", "") for author in paper.pop("authors", [])]
published_at = paper.pop("publishedAt", None)
self.id = kwargs.pop("id", None) or paper.pop("id", None)
authors = paper.pop("authors", None) or kwargs.pop("authors", None)
self.authors = [author.pop("name", None) for author in authors] if authors else None
published_at = paper.pop("publishedAt", None) or kwargs.pop("publishedAt", None)
self.published_at = parse_datetime(published_at) if published_at else None
self.title = kwargs.pop("title", "")
self.source = kwargs.pop("source", "")
self.summary = paper.pop("summary", "")
self.upvotes = paper.pop("upvotes", "")
self.discussion_id = paper.pop("discussionId", "")
self.title = kwargs.pop("title", None)
self.source = kwargs.pop("source", None)
self.summary = paper.pop("summary", None) or kwargs.pop("summary", None)
self.upvotes = paper.pop("upvotes", None) or kwargs.pop("upvotes", None)
self.discussion_id = paper.pop("discussionId", None) or kwargs.pop("discussionId", None)
self.comments = kwargs.pop("numComments", 0)
submitted_at = kwargs.pop("publishedAt", None)
submitted_at = kwargs.pop("publishedAt", None) or kwargs.pop("submittedOnDailyAt", None)
self.submitted_at = parse_datetime(submitted_at) if submitted_at else None
submitted_by = kwargs.pop("submittedBy", None)
submitted_by = kwargs.pop("submittedBy", None) or kwargs.pop("submittedOnDailyBy", None)
self.submitted_by = User(**submitted_by) if submitted_by else None

# forward compatibility
self.__dict__.update(**kwargs)

@staticmethod
def from_paper_info(**kwargs) -> PaperInfo:
daily_paper = {}
published_at = kwargs.pop("submittedOnDailyAt", None)
submitted_by = kwargs.pop("submittedOnDailyBy", {})
title = kwargs["title"] if "title" in kwargs else ""
daily_paper = {
"paper": kwargs,
"publishedAt": published_at,
"submittedBy": submitted_by,
"title": title,
}
return PaperInfo(**daily_paper)


def future_compatible(fn: CallableT) -> CallableT:
"""Wrap a method of `HfApi` to handle `run_as_future=True`.
Expand Down Expand Up @@ -9757,18 +9744,45 @@ def list_papers(
query: Optional[str] = None,
) -> Iterable[PaperInfo]:
"""
Get daily papers on the Hub.
List daily papers on the Hugging Face Hub, given a date or a search query.
Args:
date (`str`):
Date to get papers for in format YYYY-MM-DD.
date (`str`, *optional*):
The date to retrieve papers for, in the format 'YYYY-MM-DD'.
If provided, returns papers submitted on this date.
query (`str`, *optional*):
A search query string to find papers.
If provided, returns papers that match the query.
Returns:
`Iterable[PaperInfo]`: A list of [`PaperInfo`] objects.
Raises:
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
HTTP 400 if the date is invalid.
[`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError):
If neither `date` nor `query` is provided.
Example usage with the `date` argument:
```python
>>> from huggingface_hub import HfApi
>>> api = HfApi()
# List all papers submitted on a specific date
>>> api.list_papers(date="2024-09-17")
```
Example usage with the `query` argument:
```python
>>> from huggingface_hub import HfApi
>>> api = HfApi()
# List all papers with "attention" in their title
>>> api.list_papers(query="attention")
```
"""
if date is None and query is None:
raise ValueError("Provide one of `date` or `query`.")
Expand All @@ -9780,13 +9794,13 @@ def list_papers(
for paper in r.json():
yield PaperInfo(**paper)

def paper_info(self, paper_id: str) -> PaperInfo:
def paper_info(self, id: str) -> PaperInfo:
"""
Get information for a paper on the Hub.
Args:
paper_id (`str`, **optional**):
ID of the paper.
id (`str`, **optional**):
ArXiv id of the paper.
Returns:
`PaperInfo`: A `PaperInfo` object.
Expand All @@ -9795,10 +9809,9 @@ def paper_info(self, paper_id: str) -> PaperInfo:
[`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError):
HTTP 404 If the paper does not exist on the Hub.
"""
r = get_session().get(f"{constants.ENDPOINT}/api/papers/{paper_id}")
r = get_session().get(f"{constants.ENDPOINT}/api/papers/{id}")
hf_raise_for_status(r)
data = r.json()
return PaperInfo.from_paper_info(**data)
return PaperInfo(**r.json())

def auth_check(
self, repo_id: str, *, repo_type: Optional[str] = None, token: Union[bool, str, None] = None
Expand Down
2 changes: 1 addition & 1 deletion tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4122,7 +4122,7 @@ def test_papers_by_query(self) -> None:

def test_get_paper_by_id(self) -> None:
paper_id = "1706.03762"
paper = self.api.paper_info(paper_id=paper_id)
paper = self.api.paper_info(id=paper_id)
assert paper.title == "Attention Is All You Need"


Expand Down

0 comments on commit afadc22

Please sign in to comment.