Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Types to use Option[T] #609

Merged
merged 20 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions supabase_auth/_async/gotrue_admin_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from ..helpers import model_validate, parse_link_response, parse_user_response
from ..http_clients import AsyncClient
Expand All @@ -28,7 +28,7 @@ def __init__(
*,
url: str = "",
headers: Dict[str, str] = {},
http_client: Union[AsyncClient, None] = None,
http_client: Optional[AsyncClient] = None,
verify: bool = True,
proxy: Optional[str] = None,
) -> None:
Expand Down
48 changes: 24 additions & 24 deletions supabase_auth/_async/gotrue_base_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Dict, Optional, TypeVar, Union, overload
from typing import Any, Callable, Dict, Optional, TypeVar, overload

from httpx import Response
from pydantic import BaseModel
Expand All @@ -19,7 +19,7 @@ def __init__(
*,
url: str,
headers: Dict[str, str],
http_client: Union[AsyncClient, None],
http_client: Optional[AsyncClient],
verify: bool = True,
proxy: Optional[str] = None,
):
Expand Down Expand Up @@ -47,11 +47,11 @@ async def _request(
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
path: str,
*,
jwt: Union[str, None] = None,
redirect_to: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
jwt: Optional[str] = None,
redirect_to: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
query: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
no_resolve_json: Literal[False] = False,
xform: Callable[[Any], T],
) -> T: ... # pragma: no cover
Expand All @@ -62,11 +62,11 @@ async def _request(
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
path: str,
*,
jwt: Union[str, None] = None,
redirect_to: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
jwt: Optional[str] = None,
redirect_to: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
query: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
no_resolve_json: Literal[True],
xform: Callable[[Response], T],
) -> T: ... # pragma: no cover
Expand All @@ -77,11 +77,11 @@ async def _request(
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
path: str,
*,
jwt: Union[str, None] = None,
redirect_to: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
jwt: Optional[str] = None,
redirect_to: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
query: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
no_resolve_json: bool = False,
) -> None: ... # pragma: no cover

Expand All @@ -90,14 +90,14 @@ async def _request(
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
path: str,
*,
jwt: Union[str, None] = None,
redirect_to: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
jwt: Optional[str] = None,
redirect_to: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
query: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
no_resolve_json: bool = False,
xform: Union[Callable[[Any], T], None] = None,
) -> Union[T, None]:
xform: Optional[Callable[[Any], T]] = None,
) -> Optional[T]:
url = f"{self._url}/{path}"
headers = {**self._headers, **(headers or {})}
if API_VERSION_HEADER_NAME not in headers:
Expand Down
44 changes: 22 additions & 22 deletions supabase_auth/_async/gotrue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import partial
from json import loads
from time import time
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse
from uuid import uuid4

Expand Down Expand Up @@ -87,13 +87,13 @@ class AsyncGoTrueClient(AsyncGoTrueBaseAPI):
def __init__(
self,
*,
url: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
storage_key: Union[str, None] = None,
url: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
storage_key: Optional[str] = None,
auto_refresh_token: bool = True,
persist_session: bool = True,
storage: Union[AsyncSupportedStorage, None] = None,
http_client: Union[AsyncClient, None] = None,
storage: Optional[AsyncSupportedStorage] = None,
http_client: Optional[AsyncClient] = None,
flow_type: AuthFlowType = "implicit",
verify: bool = True,
proxy: Optional[str] = None,
Expand All @@ -110,8 +110,8 @@ def __init__(
self._auto_refresh_token = auto_refresh_token
self._persist_session = persist_session
self._storage = storage or AsyncMemoryStorage()
self._in_memory_session: Union[Session, None] = None
self._refresh_token_timer: Union[Timer, None] = None
self._in_memory_session: Optional[Session] = None
self._refresh_token_timer: Optional[Timer] = None
self._network_retries = 0
self._state_change_emitters: Dict[str, Subscription] = {}
self._flow_type = flow_type
Expand All @@ -134,7 +134,7 @@ def __init__(

# Initializations

async def initialize(self, *, url: Union[str, None] = None) -> None:
async def initialize(self, *, url: Optional[str] = None) -> None:
if url and self._is_implicit_grant_flow(url):
await self.initialize_from_url(url)
else:
Expand All @@ -158,7 +158,7 @@ async def initialize_from_url(self, url: str) -> None:
# Public methods

async def sign_in_anonymously(
self, credentials: Union[SignInAnonymouslyCredentials, None] = None
self, credentials: Optional[SignInAnonymouslyCredentials] = None
) -> AuthResponse:
"""
Creates a new anonymous user.
Expand Down Expand Up @@ -591,14 +591,14 @@ async def reauthenticate(self) -> AuthResponse:
xform=parse_auth_response,
)

async def get_session(self) -> Union[Session, None]:
async def get_session(self) -> Optional[Session]:
"""
Returns the session, refreshing it if necessary.

The session returned can be null if the session is not detected which
can happen in the event a user is not signed-in or has logged out.
"""
current_session: Union[Session, None] = None
current_session: Optional[Session] = None
if self._persist_session:
maybe_session = await self._storage.get_item(self._storage_key)
current_session = self._get_valid_session(maybe_session)
Expand All @@ -620,7 +620,7 @@ async def get_session(self) -> Union[Session, None]:
else current_session
)

async def get_user(self, jwt: Union[str, None] = None) -> Union[UserResponse, None]:
async def get_user(self, jwt: Optional[str] = None) -> Optional[UserResponse]:
"""
Gets the current user details if there is an existing session.

Expand Down Expand Up @@ -672,7 +672,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon
time_now = round(time())
expires_at = time_now
has_expired = True
session: Union[Session, None] = None
session: Optional[Session] = None
if access_token and access_token.split(".")[1]:
payload = self._decode_jwt(access_token)
exp = payload.get("exp")
Expand Down Expand Up @@ -701,7 +701,7 @@ async def set_session(self, access_token: str, refresh_token: str) -> AuthRespon
return AuthResponse(session=session, user=response.user)

async def refresh_session(
self, refresh_token: Union[str, None] = None
self, refresh_token: Optional[str] = None
) -> AuthResponse:
"""
Returns a new session, regardless of expiry status.
Expand Down Expand Up @@ -743,7 +743,7 @@ async def sign_out(self, options: SignOutOptions = {"scope": "global"}) -> None:

def on_auth_state_change(
self,
callback: Callable[[AuthChangeEvent, Union[Session, None]], None],
callback: Callable[[AuthChangeEvent, Optional[Session]], None],
) -> Subscription:
"""
Receive a notification every time an auth event happens.
Expand Down Expand Up @@ -889,7 +889,7 @@ async def _get_authenticator_assurance_level(
current_authentication_methods=[],
)
payload = self._decode_jwt(session.access_token)
current_level: Union[AuthenticatorAssuranceLevels, None] = None
current_level: Optional[AuthenticatorAssuranceLevels] = None
if payload.get("aal"):
current_level = payload.get("aal")
verified_factors = [
Expand Down Expand Up @@ -917,7 +917,7 @@ async def _remove_session(self) -> None:
async def _get_session_from_url(
self,
url: str,
) -> Tuple[Session, Union[str, None]]:
) -> Tuple[Session, Optional[str]]:
if not self._is_implicit_grant_flow(url):
raise AuthImplicitGrantRedirectError("Not a valid implicit grant flow url.")
result = urlparse(url)
Expand Down Expand Up @@ -1062,15 +1062,15 @@ async def refresh_token_function():
def _notify_all_subscribers(
self,
event: AuthChangeEvent,
session: Union[Session, None],
session: Optional[Session],
) -> None:
for subscription in self._state_change_emitters.values():
subscription.callback(event, session)

def _get_valid_session(
self,
raw_session: Union[str, None],
) -> Union[Session, None]:
raw_session: Optional[str],
) -> Optional[Session]:
if not raw_session:
return None
data = loads(raw_session)
Expand All @@ -1096,7 +1096,7 @@ def _get_param(
self,
query_params: Dict[str, List[str]],
name: str,
) -> Union[str, None]:
) -> Optional[str]:
return query_params[name][0] if name in query_params else None

def _is_implicit_grant_flow(self, url: str) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions supabase_auth/_sync/gotrue_admin_api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from ..helpers import model_validate, parse_link_response, parse_user_response
from ..http_clients import SyncClient
Expand All @@ -28,7 +28,7 @@ def __init__(
*,
url: str = "",
headers: Dict[str, str] = {},
http_client: Union[SyncClient, None] = None,
http_client: Optional[SyncClient] = None,
verify: bool = True,
proxy: Optional[str] = None,
) -> None:
Expand Down
48 changes: 24 additions & 24 deletions supabase_auth/_sync/gotrue_base_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Dict, Optional, TypeVar, Union, overload
from typing import Any, Callable, Dict, Optional, TypeVar, overload

from httpx import Response
from pydantic import BaseModel
Expand All @@ -19,7 +19,7 @@ def __init__(
*,
url: str,
headers: Dict[str, str],
http_client: Union[SyncClient, None],
http_client: Optional[SyncClient],
verify: bool = True,
proxy: Optional[str] = None,
):
Expand Down Expand Up @@ -47,11 +47,11 @@ def _request(
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
path: str,
*,
jwt: Union[str, None] = None,
redirect_to: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
jwt: Optional[str] = None,
redirect_to: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
query: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
no_resolve_json: Literal[False] = False,
xform: Callable[[Any], T],
) -> T: ... # pragma: no cover
Expand All @@ -62,11 +62,11 @@ def _request(
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
path: str,
*,
jwt: Union[str, None] = None,
redirect_to: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
jwt: Optional[str] = None,
redirect_to: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
query: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
no_resolve_json: Literal[True],
xform: Callable[[Response], T],
) -> T: ... # pragma: no cover
Expand All @@ -77,11 +77,11 @@ def _request(
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
path: str,
*,
jwt: Union[str, None] = None,
redirect_to: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
jwt: Optional[str] = None,
redirect_to: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
query: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
no_resolve_json: bool = False,
) -> None: ... # pragma: no cover

Expand All @@ -90,14 +90,14 @@ def _request(
method: Literal["GET", "OPTIONS", "HEAD", "POST", "PUT", "PATCH", "DELETE"],
path: str,
*,
jwt: Union[str, None] = None,
redirect_to: Union[str, None] = None,
headers: Union[Dict[str, str], None] = None,
query: Union[Dict[str, str], None] = None,
body: Union[Any, None] = None,
jwt: Optional[str] = None,
redirect_to: Optional[str] = None,
headers: Optional[Dict[str, str]] = None,
query: Optional[Dict[str, str]] = None,
body: Optional[Any] = None,
no_resolve_json: bool = False,
xform: Union[Callable[[Any], T], None] = None,
) -> Union[T, None]:
xform: Optional[Callable[[Any], T]] = None,
) -> Optional[T]:
url = f"{self._url}/{path}"
headers = {**self._headers, **(headers or {})}
if API_VERSION_HEADER_NAME not in headers:
Expand Down
Loading