Skip to content

Commit

Permalink
Fix asyncify for users client where token is not required (#506)
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjmcgrath authored Jul 24, 2023
2 parents d2ab498 + 6e3b2a9 commit c5131b6
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
16 changes: 15 additions & 1 deletion auth0/asyncify.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import aiohttp

from auth0.authentication import Users
from auth0.authentication.base import AuthenticationBase
from auth0.rest import RestClientOptions
from auth0.rest_async import AsyncRestClient
Expand All @@ -21,6 +22,17 @@ def asyncify(cls):
if callable(getattr(cls, func)) and not func.startswith("_")
]

class UsersAsyncClient(cls):
def __init__(
self,
domain,
telemetry=True,
timeout=5.0,
protocol="https",
):
super().__init__(domain, telemetry, timeout, protocol)
self.client = AsyncRestClient(None, telemetry=telemetry, timeout=timeout)

class AsyncManagementClient(cls):
def __init__(
self,
Expand Down Expand Up @@ -68,7 +80,9 @@ def __init__(
class Wrapper(cls):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if AuthenticationBase in cls.__bases__:
if cls == Users:
self._async_client = UsersAsyncClient(*args, **kwargs)
elif AuthenticationBase in cls.__bases__:
self._async_client = AsyncAuthenticationClient(*args, **kwargs)
else:
self._async_client = AsyncManagementClient(*args, **kwargs)
Expand Down
1 change: 0 additions & 1 deletion auth0/authentication/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ def userinfo(self, access_token: str) -> dict[str, Any]:
Returns:
The user profile.
"""

data: dict[str, Any] = self.client.get(
url=f"{self.protocol}://{self.domain}/userinfo",
headers={"Authorization": f"Bearer {access_token}"},
Expand Down
19 changes: 18 additions & 1 deletion auth0/test_async/test_asyncify.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from callee import Attrs

from auth0.asyncify import asyncify
from auth0.authentication import GetToken
from auth0.authentication import GetToken, Users
from auth0.management import Clients, Guardian, Jobs

clients = re.compile(r"^https://example\.com/api/v2/clients.*")
token = re.compile(r"^https://example\.com/oauth/token.*")
user_info = re.compile(r"^https://example\.com/userinfo.*")
factors = re.compile(r"^https://example\.com/api/v2/guardian/factors.*")
users_imports = re.compile(r"^https://example\.com/api/v2/jobs/users-imports.*")
payload = {"foo": "bar"}
Expand Down Expand Up @@ -111,6 +112,22 @@ async def test_post_auth(self, mocked):
timeout=ANY,
)

@aioresponses()
async def test_user_info(self, mocked):
callback, mock = get_callback()
mocked.get(user_info, callback=callback)
c = asyncify(Users)(domain="example.com")
self.assertEqual(
await c.userinfo_async(access_token="access-token-example"), payload
)
mock.assert_called_with(
Attrs(path="/userinfo"),
headers={**headers, "Authorization": "Bearer access-token-example"},
timeout=ANY,
allow_redirects=True,
params=None,
)

@aioresponses()
async def test_file_post(self, mocked):
callback, mock = get_callback()
Expand Down

0 comments on commit c5131b6

Please sign in to comment.