Skip to content

Commit

Permalink
Merge pull request #84 from qiankunli/feat/custom-some-args
Browse files Browse the repository at this point in the history
feat: custom some args
  • Loading branch information
danielnsilva authored Mar 12, 2024
2 parents 0fa8aac + a18c60f commit 5f6954e
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 23 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ MANIFEST.in
.coverage
docs/
.devcontainer/
.idea
34 changes: 23 additions & 11 deletions semanticscholar/ApiRequester.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@

class ApiRequester:

def __init__(self, timeout, debug) -> None:
def __init__(self, timeout, debug, retry: bool = True) -> None:
'''
:param float timeout: an exception is raised \
if the server has not issued a response for timeout seconds.
:param bool debug: enable debug mode.
:param bool retry: enable retry mode.
'''
self.timeout = timeout
self.debug = debug
self.retry = retry

@property
def timeout(self) -> int:
Expand All @@ -36,21 +38,21 @@ def timeout(self, timeout: int) -> None:
:param int timeout:
'''
self._timeout = timeout

@property
def debug(self) -> bool:
'''
:type: :class:`bool`
'''
return self._debug

@debug.setter
def debug(self, debug: bool) -> None:
'''
:param bool debug:
'''
self._debug = debug

def _curl_cmd(self, url: str, method: str, headers: dict, payload: dict = None) -> str:
curl_cmd = f'curl -X {method}'
for key, value in headers.items():
Expand All @@ -68,18 +70,29 @@ def _print_debug(self, url, headers, payload, method) -> None:
print(f'cURL command:\n{self._curl_cmd(url, method, headers, payload)}')
print('-' * 80)

@retry(
wait=wait_fixed(30),
retry=retry_if_exception_type(ConnectionRefusedError),
stop=stop_after_attempt(10)
)
async def get_data_async(
self,
url: str,
parameters: str,
headers: dict,
payload: dict = None
) -> Union[dict, List[dict]]:
if self.retry:
return await self._get_data_async(url, parameters, headers, payload)
return await self._get_data_async.retry_with(stop=stop_after_attempt(1))(self, url=url, parameters=parameters, headers=headers, payload=payload)

@retry(
wait=wait_fixed(30),
retry=retry_if_exception_type(ConnectionRefusedError),
stop=stop_after_attempt(10)
)
async def _get_data_async(
self,
url: str,
parameters: str,
headers: dict,
payload: dict = None
) -> Union[dict, List[dict]]:
'''Get data from Semantic Scholar API
:param str url: absolute URL to API endpoint.
Expand Down Expand Up @@ -120,7 +133,7 @@ async def get_data_async(
raise Exception(data['message'])

return data

def get_data(
self,
url: str,
Expand All @@ -143,4 +156,3 @@ def get_data(
payload=payload
)
)

21 changes: 11 additions & 10 deletions semanticscholar/AsyncSemanticScholar.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ def __init__(
timeout: int = 30,
api_key: str = None,
api_url: str = None,
debug: bool = False
debug: bool = False,
retry: bool = True,
) -> None:
'''
:param float timeout: (optional) an exception is raised\
if the server has not issued a response for timeout seconds.
:param str api_key: (optional) private API key.
:param str api_url: (optional) custom API url.
:param bool debug: (optional) enable debug mode.
:param bool retry: enable retry mode.
'''

if api_url:
Expand All @@ -48,7 +50,8 @@ def __init__(

self._timeout = timeout
self._debug = debug
self._requester = ApiRequester(self._timeout, self._debug)
self._retry = retry
self._requester = ApiRequester(self._timeout, self._debug, self._retry)

@property
def timeout(self) -> int:
Expand All @@ -64,14 +67,14 @@ def timeout(self, timeout: int) -> None:
'''
self._timeout = timeout
self._requester.timeout = timeout

@property
def debug(self) -> bool:
'''
:type: :class:`bool`
'''
return self._debug

@debug.setter
def debug(self, debug: bool) -> None:
'''
Expand Down Expand Up @@ -168,7 +171,7 @@ async def get_papers(
data = await self._requester.get_data_async(
url, parameters, self.auth_header, payload)
papers = [Paper(item) for item in data if item is not None]

not_found_ids = self._get_not_found_ids(paper_ids, papers)

if not_found_ids:
Expand All @@ -177,7 +180,7 @@ async def get_papers(
return papers if not return_not_found else (papers, not_found_ids)

def _get_not_found_ids(self, paper_ids, papers):

prefix_mapping = {
'ARXIV': 'ArXiv',
'MAG': 'MAG',
Expand All @@ -197,9 +200,8 @@ def _get_not_found_ids(self, paper_ids, papers):
found_ids.add(
f'{prefix_mapping[prefix.lower()]}:{value}')
else:
found_ids.add(f'{value}')
found_ids = {id.lower() for id in found_ids}

found_ids.add(f'{value}')
found_ids = {id.lower() for id in found_ids}
not_found_ids = [id for id in paper_ids if id.lower() not in found_ids]

return not_found_ids
Expand Down Expand Up @@ -433,7 +435,6 @@ async def search_paper(
if fields_of_study:
fields_of_study = ','.join(fields_of_study)
query += f'&fieldsOfStudy={fields_of_study}'

if publication_date_or_year:
single_date_regex = r'\d{4}(-\d{2}(-\d{2})?)?'
full_regex = r'^({0})?(:({0})?)?$'.format(single_date_regex)
Expand Down
7 changes: 5 additions & 2 deletions semanticscholar/SemanticScholar.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,16 @@ def __init__(
timeout: int = 30,
api_key: str = None,
api_url: str = None,
debug: bool = False
debug: bool = False,
retry: bool = True,
) -> None:
'''
:param float timeout: (optional) an exception is raised\
if the server has not issued a response for timeout seconds.
:param str api_key: (optional) private API key.
:param str api_url: (optional) custom API url.
:param bool debug: (optional) enable debug mode.
:param bool retry: enable retry mode.
'''
nest_asyncio.apply()
self._timeout = timeout
Expand All @@ -34,7 +36,8 @@ def __init__(
timeout=timeout,
api_key=api_key,
api_url=api_url,
debug=debug
debug=debug,
retry=retry
)

@property
Expand Down

0 comments on commit 5f6954e

Please sign in to comment.