Skip to content

Commit 7159116

Browse files
timu-jesse-ezellsilentworkso-santi
authored
fix(auth): return auth_response from exchange_code_for_session instead of response dict (#1288)
Co-authored-by: Andrew Smith <a.smith@silentworks.co.uk> Co-authored-by: Leonardo Santiago <leonardo.ribeiro.santiago@gmail.com>
1 parent 9ab912b commit 7159116

File tree

11 files changed

+74
-81
lines changed

11 files changed

+74
-81
lines changed

src/auth/src/supabase_auth/_async/gotrue_admin_api.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from pydantic import TypeAdapter
77

88
from ..helpers import (
9-
validate_uuid,
109
model_validate,
1110
parse_link_response,
1211
parse_user_response,
12+
validate_uuid,
1313
)
1414
from ..http_clients import AsyncClient
1515
from ..types import (
@@ -57,15 +57,15 @@ def __init__(
5757
)
5858
# TODO(@o-santi): why is is this done this way?
5959
self.mfa = AsyncGoTrueAdminMFAAPI()
60-
self.mfa.list_factors = self._list_factors # type: ignore
61-
self.mfa.delete_factor = self._delete_factor # type: ignore
60+
self.mfa.list_factors = self._list_factors # type: ignore
61+
self.mfa.delete_factor = self._delete_factor # type: ignore
6262
self.oauth = AsyncGoTrueAdminOAuthAPI()
63-
self.oauth.list_clients = self._list_oauth_clients # type: ignore
64-
self.oauth.create_client = self._create_oauth_client # type: ignore
65-
self.oauth.get_client = self._get_oauth_client # type: ignore
66-
self.oauth.update_client = self._update_oauth_client # type: ignore
67-
self.oauth.delete_client = self._delete_oauth_client # type: ignore
68-
self.oauth.regenerate_client_secret = self._regenerate_oauth_client_secret # type: ignore
63+
self.oauth.list_clients = self._list_oauth_clients # type: ignore
64+
self.oauth.create_client = self._create_oauth_client # type: ignore
65+
self.oauth.get_client = self._get_oauth_client # type: ignore
66+
self.oauth.update_client = self._update_oauth_client # type: ignore
67+
self.oauth.delete_client = self._delete_oauth_client # type: ignore
68+
self.oauth.regenerate_client_secret = self._regenerate_oauth_client_secret # type: ignore
6969

7070
async def sign_out(self, jwt: str, scope: SignOutScope = "global") -> None:
7171
"""
@@ -276,9 +276,8 @@ async def _create_oauth_client(
276276
body=params,
277277
)
278278

279-
return OAuthClientResponse(
280-
client=model_validate(OAuthClient, response.content)
281-
)
279+
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))
280+
282281
async def _get_oauth_client(
283282
self,
284283
client_id: str,
@@ -295,9 +294,7 @@ async def _get_oauth_client(
295294
"GET",
296295
f"admin/oauth/clients/{client_id}",
297296
)
298-
return OAuthClientResponse(
299-
client=model_validate(OAuthClient, response.content)
300-
)
297+
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))
301298

302299
async def _update_oauth_client(
303300
self,
@@ -317,9 +314,7 @@ async def _update_oauth_client(
317314
f"admin/oauth/clients/{client_id}",
318315
body=params,
319316
)
320-
return OAuthClientResponse(
321-
client=model_validate(OAuthClient, response.content)
322-
)
317+
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))
323318

324319
async def _delete_oauth_client(
325320
self,
@@ -354,6 +349,4 @@ async def _regenerate_oauth_client_secret(
354349
"POST",
355350
f"admin/oauth/clients/{client_id}/regenerate_secret",
356351
)
357-
return OAuthClientResponse(
358-
client=model_validate(OAuthClient, response.content)
359-
)
352+
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))

src/auth/src/supabase_auth/_async/gotrue_admin_oauth_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
from typing import Optional
2+
13
from ..types import (
24
CreateOAuthClientParams,
35
OAuthClientListResponse,
46
OAuthClientResponse,
57
PageParams,
68
UpdateOAuthClientParams,
79
)
8-
from typing import Optional
910

1011

1112
class AsyncGoTrueAdminOAuthAPI:

src/auth/src/supabase_auth/_async/gotrue_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,7 +1165,9 @@ async def _get_url_for_provider(
11651165
query = query.set("provider", provider)
11661166
return f"{url}?{query}", query
11671167

1168-
async def exchange_code_for_session(self, params: CodeExchangeParams):
1168+
async def exchange_code_for_session(
1169+
self, params: CodeExchangeParams
1170+
) -> AuthResponse:
11691171
code_verifier = params.get("code_verifier") or await self._storage.get_item(
11701172
f"{self._storage_key}-code-verifier"
11711173
)
@@ -1184,7 +1186,7 @@ async def exchange_code_for_session(self, params: CodeExchangeParams):
11841186
if auth_response.session:
11851187
await self._save_session(auth_response.session)
11861188
self._notify_all_subscribers("SIGNED_IN", auth_response.session)
1187-
return response
1189+
return auth_response
11881190

11891191
async def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
11901192
jwk: Optional[JWK] = None

src/auth/src/supabase_auth/_sync/gotrue_admin_api.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
from pydantic import TypeAdapter
77

88
from ..helpers import (
9-
validate_uuid,
109
model_validate,
1110
parse_link_response,
1211
parse_user_response,
12+
validate_uuid,
1313
)
1414
from ..http_clients import SyncClient
1515
from ..types import (
@@ -57,15 +57,15 @@ def __init__(
5757
)
5858
# TODO(@o-santi): why is is this done this way?
5959
self.mfa = SyncGoTrueAdminMFAAPI()
60-
self.mfa.list_factors = self._list_factors # type: ignore
61-
self.mfa.delete_factor = self._delete_factor # type: ignore
60+
self.mfa.list_factors = self._list_factors # type: ignore
61+
self.mfa.delete_factor = self._delete_factor # type: ignore
6262
self.oauth = SyncGoTrueAdminOAuthAPI()
63-
self.oauth.list_clients = self._list_oauth_clients # type: ignore
64-
self.oauth.create_client = self._create_oauth_client # type: ignore
65-
self.oauth.get_client = self._get_oauth_client # type: ignore
66-
self.oauth.update_client = self._update_oauth_client # type: ignore
67-
self.oauth.delete_client = self._delete_oauth_client # type: ignore
68-
self.oauth.regenerate_client_secret = self._regenerate_oauth_client_secret # type: ignore
63+
self.oauth.list_clients = self._list_oauth_clients # type: ignore
64+
self.oauth.create_client = self._create_oauth_client # type: ignore
65+
self.oauth.get_client = self._get_oauth_client # type: ignore
66+
self.oauth.update_client = self._update_oauth_client # type: ignore
67+
self.oauth.delete_client = self._delete_oauth_client # type: ignore
68+
self.oauth.regenerate_client_secret = self._regenerate_oauth_client_secret # type: ignore
6969

7070
def sign_out(self, jwt: str, scope: SignOutScope = "global") -> None:
7171
"""
@@ -276,9 +276,8 @@ def _create_oauth_client(
276276
body=params,
277277
)
278278

279-
return OAuthClientResponse(
280-
client=model_validate(OAuthClient, response.content)
281-
)
279+
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))
280+
282281
def _get_oauth_client(
283282
self,
284283
client_id: str,
@@ -295,9 +294,7 @@ def _get_oauth_client(
295294
"GET",
296295
f"admin/oauth/clients/{client_id}",
297296
)
298-
return OAuthClientResponse(
299-
client=model_validate(OAuthClient, response.content)
300-
)
297+
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))
301298

302299
def _update_oauth_client(
303300
self,
@@ -317,9 +314,7 @@ def _update_oauth_client(
317314
f"admin/oauth/clients/{client_id}",
318315
body=params,
319316
)
320-
return OAuthClientResponse(
321-
client=model_validate(OAuthClient, response.content)
322-
)
317+
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))
323318

324319
def _delete_oauth_client(
325320
self,
@@ -354,6 +349,4 @@ def _regenerate_oauth_client_secret(
354349
"POST",
355350
f"admin/oauth/clients/{client_id}/regenerate_secret",
356351
)
357-
return OAuthClientResponse(
358-
client=model_validate(OAuthClient, response.content)
359-
)
352+
return OAuthClientResponse(client=model_validate(OAuthClient, response.content))

src/auth/src/supabase_auth/_sync/gotrue_admin_oauth_api.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
from typing import Optional
2+
13
from ..types import (
24
CreateOAuthClientParams,
35
OAuthClientListResponse,
46
OAuthClientResponse,
57
PageParams,
68
UpdateOAuthClientParams,
79
)
8-
from typing import Optional
910

1011

1112
class SyncGoTrueAdminOAuthAPI:

src/auth/src/supabase_auth/_sync/gotrue_client.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -441,9 +441,7 @@ def sign_in_with_oauth(
441441
)
442442
return OAuthResponse(provider=provider, url=url_with_qs)
443443

444-
def link_identity(
445-
self, credentials: SignInWithOAuthCredentials
446-
) -> OAuthResponse:
444+
def link_identity(self, credentials: SignInWithOAuthCredentials) -> OAuthResponse:
447445
provider = credentials["provider"]
448446
options = credentials.get("options", {})
449447
redirect_to = options.get("redirect_to")
@@ -743,9 +741,7 @@ def set_session(self, access_token: str, refresh_token: str) -> AuthResponse:
743741
self._notify_all_subscribers("TOKEN_REFRESHED", session)
744742
return AuthResponse(session=session, user=session.user)
745743

746-
def refresh_session(
747-
self, refresh_token: Optional[str] = None
748-
) -> AuthResponse:
744+
def refresh_session(self, refresh_token: Optional[str] = None) -> AuthResponse:
749745
"""
750746
Returns a new session, regardless of expiry status.
751747
@@ -1153,9 +1149,7 @@ def _get_url_for_provider(
11531149
if self._flow_type == "pkce":
11541150
code_verifier = generate_pkce_verifier()
11551151
code_challenge = generate_pkce_challenge(code_verifier)
1156-
self._storage.set_item(
1157-
f"{self._storage_key}-code-verifier", code_verifier
1158-
)
1152+
self._storage.set_item(f"{self._storage_key}-code-verifier", code_verifier)
11591153
code_challenge_method = (
11601154
"plain" if code_verifier == code_challenge else "s256"
11611155
)
@@ -1165,7 +1159,7 @@ def _get_url_for_provider(
11651159
query = query.set("provider", provider)
11661160
return f"{url}?{query}", query
11671161

1168-
def exchange_code_for_session(self, params: CodeExchangeParams):
1162+
def exchange_code_for_session(self, params: CodeExchangeParams) -> AuthResponse:
11691163
code_verifier = params.get("code_verifier") or self._storage.get_item(
11701164
f"{self._storage_key}-code-verifier"
11711165
)
@@ -1184,7 +1178,7 @@ def exchange_code_for_session(self, params: CodeExchangeParams):
11841178
if auth_response.session:
11851179
self._save_session(auth_response.session)
11861180
self._notify_all_subscribers("SIGNED_IN", auth_response.session)
1187-
return response
1181+
return auth_response
11881182

11891183
def _fetch_jwks(self, kid: str, jwks: JWKSet) -> JWK:
11901184
jwk: Optional[JWK] = None

src/auth/src/supabase_auth/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,9 @@ def is_valid_uuid(value: str) -> bool:
299299
except ValueError:
300300
return False
301301

302+
302303
def validate_uuid(id: str | None) -> None:
303304
if id is None:
304305
raise ValueError("Invalid id, id is None")
305306
if not is_valid_uuid(id):
306-
raise ValueError(f"Invalid id, '{id}' is not a valid uuid")
307+
raise ValueError(f"Invalid id, '{id}' is not a valid uuid")

src/auth/src/supabase_auth/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,9 @@ class JWKSet(TypedDict):
893893
Only relevant when the OAuth 2.1 server is enabled in Supabase Auth.
894894
"""
895895

896-
OAuthClientTokenEndpointAuthMethod = Literal["none", "client_secret_basic", "client_secret_post"]
896+
OAuthClientTokenEndpointAuthMethod = Literal[
897+
"none", "client_secret_basic", "client_secret_post"
898+
]
897899
"""
898900
OAuth client token endpoint authentication method.
899901
Only relevant when the OAuth 2.1 server is enabled in Supabase Auth.
@@ -957,6 +959,7 @@ class CreateOAuthClientParams(BaseModel):
957959
scope: Optional[str] = None
958960
"""Space-separated list of scope values"""
959961

962+
960963
class UpdateOAuthClientParams(BaseModel):
961964
"""
962965
Parameters for updating an existing OAuth client.
@@ -974,6 +977,7 @@ class UpdateOAuthClientParams(BaseModel):
974977
grant_types: Optional[List[OAuthClientGrantType]] = None
975978
"""Array of allowed grant types"""
976979

980+
977981
class OAuthClientResponse(BaseModel):
978982
"""
979983
Response type for OAuth client operations.

src/auth/tests/_async/test_gotrue_admin_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
AuthWeakPasswordError,
1111
)
1212
from supabase_auth.types import CreateOAuthClientParams, UpdateOAuthClientParams
13+
1314
from .clients import (
1415
auth_client,
1516
auth_client_with_session,
@@ -649,6 +650,7 @@ async def test_get_oauth_client():
649650
assert response.client is not None
650651
assert response.client.client_id == client_id
651652

653+
652654
# Server is not yet released, so this test is not yet relevant.
653655
# async def test_update_oauth_client():
654656
# """Test updating an OAuth client."""
@@ -671,6 +673,7 @@ async def test_get_oauth_client():
671673
# assert response.client is not None
672674
# assert response.client.client_name == "Updated Test OAuth Client"
673675

676+
674677
async def test_delete_oauth_client():
675678
"""Test deleting an OAuth client."""
676679
# First create a client

src/auth/tests/_sync/test_gotrue.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,7 @@ def test_exchange_code_for_session():
331331
client._flow_type = "pkce"
332332

333333
# Test the PKCE URL generation which is needed for exchange_code_for_session
334-
url, params = client._get_url_for_provider(
335-
f"{client._url}/authorize", "github", {}
336-
)
334+
url, params = client._get_url_for_provider(f"{client._url}/authorize", "github", {})
337335

338336
# Verify PKCE parameters were added
339337
assert "code_challenge" in params

0 commit comments

Comments
 (0)