Skip to content

Commit

Permalink
Merge pull request #95 from dmoklaf/master
Browse files Browse the repository at this point in the history
Fixed issue #92
  • Loading branch information
danielnsilva authored Oct 23, 2024
2 parents 1d94dff + 16d223f commit 4cafc2b
Show file tree
Hide file tree
Showing 9 changed files with 17,015 additions and 15,941 deletions.
13 changes: 11 additions & 2 deletions semanticscholar/PaginatedResults.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any, Union, List
import asyncio
import nest_asyncio

from semanticscholar.ApiRequester import ApiRequester
from semanticscholar.SemanticScholarException import NoMorePagesException
Expand Down Expand Up @@ -40,7 +39,6 @@ def __init__(
self._parameters = ''
self._items = []
self._continuation_token = None
nest_asyncio.apply()

@classmethod
async def create(
Expand Down Expand Up @@ -113,6 +111,13 @@ def __iter__(self) -> Any:
while self._has_next_page():
yield from self._get_next_page()

async def __aiter__(self) -> Any:
for item in self._items:
yield item
while self._has_next_page():
for item in await self._async_get_next_page():
yield item

def __len__(self) -> int:
return len(self._items)

Expand All @@ -134,6 +139,10 @@ async def _request_data(self) -> Union[dict, List[dict]]:
)

async def _async_get_next_page(self) -> Union[dict, List[dict]]:

if not self._has_next_page():
raise NoMorePagesException('No more pages to fetch.')

self._build_params()

results = await self._request_data()
Expand Down
3,246 changes: 1,645 additions & 1,601 deletions tests/data/test_get_author_papers_async.yaml

Large diffs are not rendered by default.

9,065 changes: 4,625 additions & 4,440 deletions tests/data/test_get_paper_citations_async.yaml

Large diffs are not rendered by default.

3,629 changes: 1,843 additions & 1,786 deletions tests/data/test_search_paper_bulk_retrieval_next_page_async.yaml

Large diffs are not rendered by default.

3,562 changes: 1,856 additions & 1,706 deletions tests/data/test_search_paper_bulk_retrieval_sorted_results_asc_async.yaml

Large diffs are not rendered by default.

3,562 changes: 1,856 additions & 1,706 deletions tests/data/test_search_paper_bulk_retrieval_sorted_results_default_order_async.yaml

Large diffs are not rendered by default.

3,549 changes: 1,841 additions & 1,708 deletions tests/data/test_search_paper_bulk_retrieval_sorted_results_desc_async.yaml

Large diffs are not rendered by default.

6,301 changes: 3,322 additions & 2,979 deletions tests/data/test_search_paper_bulk_retrieval_traversing_results_async.yaml

Large diffs are not rendered by default.

29 changes: 16 additions & 13 deletions tests/test_semanticscholar.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ async def test_get_paper_authors_async(self):
data = await self.sch.get_paper_authors('10.2139/ssrn.2250500')
self.assertEqual(data.offset, 0)
self.assertEqual(data.next, 0)
self.assertEqual(len([item for item in data]), 4)
self.assertEqual(len([item async for item in data]), 4)
self.assertEqual(data[0].name, 'E. Duflo')

@test_vcr.use_cassette
Expand Down Expand Up @@ -743,7 +743,7 @@ async def test_get_author_papers_async(self):
1723755, limit=100, fields=['title'])
self.assertEqual(data.offset, 0)
self.assertEqual(data.next, 100)
self.assertEqual(len([item for item in data]), 875)
self.assertEqual(len([item async for item in data]), 875)
self.assertEqual(
data[0].title,
'SARS-CoV-2 hijacks p38\u03b2/MAPK11 to promote virus replication')
Expand All @@ -754,7 +754,7 @@ async def test_get_paper_citations_async(self):
'10.2139/ssrn.2250500', fields=['title'])
self.assertEqual(data.offset, 0)
self.assertEqual(data.next, 100)
self.assertEqual(len([item.paper.title for item in data]), 2135)
self.assertEqual(len([item.paper.title async for item in data]), 2167)
self.assertEqual(
data[0].paper.title,
'Financial inclusion and roof quality: '
Expand Down Expand Up @@ -793,9 +793,9 @@ async def test_search_paper_next_page_async(self):
@test_vcr.use_cassette
async def test_search_paper_traversing_results_async(self):
data = await self.sch.search_paper('sublinear near optimal edit distance')
all_results = [item.title for item in data]
all_results = [item.title async for item in data]
with self.assertRaises(NoMorePagesException):
await data.next_page()
await data.async_next_page()
self.assertEqual(len(all_results), len(data.items))

@test_vcr.use_cassette
Expand All @@ -811,7 +811,7 @@ async def test_search_paper_year_async(self):
@test_vcr.use_cassette
async def test_search_paper_year_range_async(self):
data = await self.sch.search_paper('turing', year='1936-1937')
self.assertTrue(all([1936 <= item.year <= 1937 for item in data]))
self.assertTrue(all([1936 <= item.year <= 1937 async for item in data]))

@test_vcr.use_cassette
async def test_search_paper_publication_types_async(self):
Expand Down Expand Up @@ -870,7 +870,7 @@ async def test_search_paper_publication_date_or_year_invalid_async(self):
@test_vcr.use_cassette
async def test_search_paper_min_citation_count_async(self):
data = await self.sch.search_paper('turing', min_citation_count=1000)
self.assertTrue(all([item.citationCount >= 1000 for item in data]))
self.assertTrue(all([item.citationCount >= 1000 async for item in data]))

@test_vcr.use_cassette
async def test_search_paper_bulk_retrieval_async(self):
Expand All @@ -887,15 +887,18 @@ async def test_search_paper_bulk_retrieval_async(self):
async def test_search_paper_bulk_retrieval_next_page_async(self):
data = await self.sch.search_paper(
'kubernetes', bulk=True, fields=['title'])
data.next_page()
await data.async_next_page()
self.assertEqual(len(data), 2000)

@test_vcr.use_cassette
async def test_search_paper_bulk_retrieval_traversing_results_async(self):
data = await self.sch.search_paper(
'kubernetes', bulk=True, fields=['title'])
all_results = [item.title for item in data]
self.assertRaises(NoMorePagesException, data.next_page)
all_results = [item.title async for item in data]
print("XXX DATA", type(data))
print(data.async_next_page)
with self.assertRaises(NoMorePagesException):
await data.async_next_page()
self.assertEqual(len(all_results), len(data.items))

@test_vcr.use_cassette
Expand All @@ -905,7 +908,7 @@ async def test_search_paper_bulk_retrieval_sorted_results_default_order_async(se
bulk=True,
sort='citationCount',
fields=['citationCount'])
all_data = [item.citationCount for item in data]
all_data = [item.citationCount async for item in data]
self.assertTrue(sorted(all_data) == all_data)

@test_vcr.use_cassette
Expand All @@ -915,7 +918,7 @@ async def test_search_paper_bulk_retrieval_sorted_results_asc_async(self):
bulk=True,
sort='citationCount:asc',
fields=['citationCount'])
all_data = [item.citationCount for item in data]
all_data = [item.citationCount async for item in data]
self.assertTrue(sorted(all_data) == all_data)

@test_vcr.use_cassette
Expand All @@ -925,7 +928,7 @@ async def test_search_paper_bulk_retrieval_sorted_results_desc_async(self):
bulk=True,
sort='citationCount:desc',
fields=['citationCount'])
all_data = [item.citationCount for item in data]
all_data = [item.citationCount async for item in data]
self.assertTrue(sorted(all_data, reverse=True) == all_data)

@test_vcr.use_cassette
Expand Down

0 comments on commit 4cafc2b

Please sign in to comment.