Skip to content

Commit

Permalink
store oauth2 state in cookies
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy committed Sep 19, 2024
1 parent b8729b5 commit c143665
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 90 deletions.
60 changes: 51 additions & 9 deletions src/phoenix/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,31 +71,53 @@ def validate_password_format(password: str) -> None:
def set_access_token_cookie(
*, response: ResponseType, access_token: str, max_age: timedelta
) -> ResponseType:
return _set_token_cookie(
return _set_cookie(
response=response,
cookie_name=PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
cookie_max_age=max_age,
token=access_token,
value=access_token,
)


def set_refresh_token_cookie(
*, response: ResponseType, refresh_token: str, max_age: timedelta
) -> ResponseType:
return _set_token_cookie(
return _set_cookie(
response=response,
cookie_name=PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
cookie_max_age=max_age,
token=refresh_token,
value=refresh_token,
)


def _set_token_cookie(
response: ResponseType, cookie_name: str, cookie_max_age: timedelta, token: str
def set_oauth2_state_cookie(
*, response: ResponseType, state: str, max_age: timedelta
) -> ResponseType:
return _set_cookie(
response=response,
cookie_name=PHOENIX_OAUTH2_STATE_COOKIE_NAME,
cookie_max_age=max_age,
value=state,
)


def set_oauth2_nonce_cookie(
*, response: ResponseType, nonce: str, max_age: timedelta
) -> ResponseType:
return _set_cookie(
response=response,
cookie_name=PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
cookie_max_age=max_age,
value=nonce,
)


def _set_cookie(
response: ResponseType, cookie_name: str, cookie_max_age: timedelta, value: str
) -> ResponseType:
response.set_cookie(
key=cookie_name,
value=token,
value=value,
secure=get_env_phoenix_use_secure_cookies(),
httponly=True,
samesite="strict",
Expand All @@ -104,16 +126,26 @@ def _set_token_cookie(
return response


def delete_access_token_cookie(response: Response) -> Response:
def delete_access_token_cookie(response: ResponseType) -> ResponseType:
response.delete_cookie(key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
return response


def delete_refresh_token_cookie(response: Response) -> Response:
def delete_refresh_token_cookie(response: ResponseType) -> ResponseType:
response.delete_cookie(key=PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
return response


def delete_oauth2_state_cookie(response: ResponseType) -> ResponseType:
response.delete_cookie(key=PHOENIX_OAUTH2_STATE_COOKIE_NAME)
return response


def delete_oauth2_nonce_cookie(response: ResponseType) -> ResponseType:
response.delete_cookie(key=PHOENIX_OAUTH2_NONCE_COOKIE_NAME)
return response


@dataclass(frozen=True)
class _PasswordRequirements:
"""
Expand Down Expand Up @@ -206,6 +238,16 @@ def validate(
"""The name of the cookie that stores the Phoenix access token."""
PHOENIX_REFRESH_TOKEN_COOKIE_NAME = "phoenix-refresh-token"
"""The name of the cookie that stores the Phoenix refresh token."""
PHOENIX_OAUTH2_STATE_COOKIE_NAME = "phoenix-oauth2-state"
"""The name of the cookie that stores the state used for the OAuth2 authorization code flow."""
PHOENIX_OAUTH2_NONCE_COOKIE_NAME = "phoenix-oauth2-nonce"
"""The name of the cookie that stores the nonce used for the OAuth2 authorization code flow."""
DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES = 15
"""
The default amount of time in minutes that can elapse between the initial
redirect to the IDP and the invocation of the callback URL during the OAuth2
authorization code flow.
"""


class Token(str): ...
Expand Down
4 changes: 4 additions & 0 deletions src/phoenix/server/api/routers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
Token,
delete_access_token_cookie,
delete_oauth2_nonce_cookie,
delete_oauth2_state_cookie,
delete_refresh_token_cookie,
is_valid_password,
set_access_token_cookie,
Expand Down Expand Up @@ -97,6 +99,8 @@ async def logout(
response = Response(status_code=HTTP_204_NO_CONTENT)
response = delete_access_token_cookie(response)
response = delete_refresh_token_cookie(response)
response = delete_oauth2_state_cookie(response)
response = delete_oauth2_nonce_cookie(response)
return response


Expand Down
85 changes: 72 additions & 13 deletions src/phoenix/server/api/routers/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,36 @@
from datetime import timedelta
from typing import Any, Dict, Optional

from authlib.common.security import generate_token
from authlib.integrations.starlette_client import OAuthError
from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client
from fastapi import APIRouter, Path, Request
from fastapi import APIRouter, Cookie, Path, Query, Request
from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from starlette.datastructures import URL
from starlette.responses import RedirectResponse
from starlette.status import HTTP_302_FOUND
from typing_extensions import Annotated

from phoenix.auth import (
DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES,
PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
PHOENIX_OAUTH2_STATE_COOKIE_NAME,
delete_oauth2_nonce_cookie,
delete_oauth2_state_cookie,
set_access_token_cookie,
set_oauth2_nonce_cookie,
set_oauth2_state_cookie,
set_refresh_token_cookie,
)
from phoenix.db import models
from phoenix.db.enums import UserRole
from phoenix.server.bearer_auth import create_access_and_refresh_tokens
from phoenix.server.jwt_store import JwtStore
from phoenix.server.oauth2 import OAuth2Client

_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+"


router = APIRouter(prefix="/oauth2", include_in_schema=False)


Expand All @@ -36,16 +44,43 @@ async def login(
oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client
):
return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
redirect_uri = request.url_for("create_tokens", idp_name=idp_name)
response: RedirectResponse = await oauth2_client.authorize_redirect(request, redirect_uri)
authorization_url_data = await oauth2_client.create_authorization_url(
redirect_uri=_get_create_tokens_endpoint(request=request, idp_name=idp_name),
state=generate_token(),
)
assert isinstance(authorization_url := authorization_url_data.get("url"), str)
assert isinstance(state := authorization_url_data.get("state"), str)
assert isinstance(nonce := authorization_url_data.get("nonce"), str)
response = RedirectResponse(url=authorization_url, status_code=HTTP_302_FOUND)
response = set_oauth2_state_cookie(
response=response,
state=state,
max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
)
response = set_oauth2_nonce_cookie(
response=response,
nonce=nonce,
max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES),
)
return response


@router.get("/{idp_name}/tokens")
async def create_tokens(
request: Request,
idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)],
state: str = Query(),
authorization_code: str = Query(alias="code"),
stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME),
stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME),
) -> RedirectResponse:
if state != stored_state:
return _redirect_to_login(
error=(
"Received invalid state parameter during "
"OAuth2 authorization code flow for IDP {idp_name}."
)
)
assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta)
assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta)
token_store: JwtStore = request.app.state.get_token_store()
Expand All @@ -54,13 +89,20 @@ async def create_tokens(
):
return _redirect_to_login(error=f"Unknown IDP: {idp_name}.")
try:
token_data = await oauth2_client.authorize_access_token(request)
token_data = await oauth2_client.fetch_access_token(
state=state,
code=authorization_code,
redirect_uri=_get_create_tokens_endpoint(request=request, idp_name=idp_name),
)
except OAuthError as error:
return _redirect_to_login(error=str(error))
if (user_info := _get_user_info(token_data)) is None:
_validate_token_data(token_data)
if "id_token" not in token_data:
return _redirect_to_login(
error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect."
)
user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce)
user_info = _parse_user_info(user_info)
try:
async with request.app.state.db() as session:
user = await _ensure_user_exists_and_is_up_to_date(
Expand All @@ -76,13 +118,15 @@ async def create_tokens(
access_token_expiry=access_token_expiry,
refresh_token_expiry=refresh_token_expiry,
)
response = RedirectResponse(url="/") # todo: sanitize a return url
response = RedirectResponse(url="/", status_code=HTTP_302_FOUND) # todo: sanitize a return url
response = set_access_token_cookie(
response=response, access_token=access_token, max_age=access_token_expiry
)
response = set_refresh_token_cookie(
response=response, refresh_token=refresh_token, max_age=refresh_token_expiry
)
response = delete_oauth2_state_cookie(response)
response = delete_oauth2_nonce_cookie(response)
return response


Expand All @@ -94,15 +138,19 @@ class UserInfo:
profile_picture_url: Optional[str]


def _get_user_info(token_data: Dict[str, Any]) -> Optional[UserInfo]:
def _validate_token_data(token_data: Dict[str, Any]) -> None:
"""
Parses token data and extracts user info if available.
Performs basic validations on the token data returned by the IDP.
"""
assert isinstance(token_data.get("access_token"), str)
assert isinstance(token_type := token_data.get("token_type"), str)
assert token_type.lower() == "bearer"
if (user_info := token_data.get("userinfo")) is None:
return None


def _parse_user_info(user_info: Dict[str, Any]) -> UserInfo:
"""
Parses user info from the IDP's ID token.
"""
assert isinstance(subject := user_info.get("sub"), (str, int))
idp_user_id = str(subject)
assert isinstance(email := user_info.get("email"), str)
Expand Down Expand Up @@ -278,4 +326,15 @@ def _redirect_to_login(*, error: str) -> RedirectResponse:
"""
Creates a RedirectResponse to the login page to display an error message.
"""
return RedirectResponse(url=URL("/login").include_query_params(error=error))
url = URL("/login").include_query_params(error=error)
response = RedirectResponse(url=url)
response = delete_oauth2_state_cookie(response)
response = delete_oauth2_nonce_cookie(response)
return response


def _get_create_tokens_endpoint(*, request: Request, idp_name: str) -> str:
"""
Gets the endpoint for create tokens route.
"""
return str(request.url_for(create_tokens.__name__, idp_name=idp_name))
88 changes: 20 additions & 68 deletions src/phoenix/server/oauth2.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
from datetime import datetime, timedelta
from typing import Any, Dict, Generic, Iterable, Optional, Tuple
from typing import Any, Dict, Iterable

from authlib.integrations.starlette_client import OAuth
from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client
from typing_extensions import TypeAlias, TypeVar
from authlib.integrations.base_client import BaseApp
from authlib.integrations.base_client.async_app import AsyncOAuth2Mixin
from authlib.integrations.base_client.async_openid import AsyncOpenIDMixin
from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncHttpxOAuth2Client

from phoenix.config import OAuth2ClientConfig


class OAuth2Client(AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp): # type:ignore[misc]
"""
An OAuth2 client class that supports OpenID Connect. Adapted from authlib's
`StarletteOAuth2App` to be useable without integration with Starlette.
https://github.com/lepture/authlib/blob/904d66bebd79bf39fb8814353a22bab7d3e092c4/authlib/integrations/starlette_client/apps.py#L58
"""

client_cls = AsyncHttpxOAuth2Client

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(framework=None, *args, **kwargs)


class OAuth2Clients:
def __init__(self) -> None:
self._clients: Dict[str, OAuth2Client] = {}
self._oauth = OAuth(cache=_OAuth2ClientTTLCache[str, Any]())

def add_client(self, config: OAuth2ClientConfig) -> None:
if (idp_name := config.idp_name) in self._clients:
raise ValueError(f"oauth client already registered: {idp_name}")
client = self._oauth.register(
idp_name,
client = OAuth2Client(
client_id=config.client_id,
client_secret=config.client_secret,
server_metadata_url=config.server_metadata_url,
Expand All @@ -37,63 +49,3 @@ def from_configs(cls, configs: Iterable[OAuth2ClientConfig]) -> "OAuth2Clients":
for config in configs:
oauth2_clients.add_client(config)
return oauth2_clients


_CacheKey = TypeVar("_CacheKey")
_CacheValue = TypeVar("_CacheValue")
_Expiry: TypeAlias = datetime
_MINUTE = timedelta(minutes=1)


class _OAuth2ClientTTLCache(Generic[_CacheKey, _CacheValue]):
"""
A TTL cache satisfying the interface required by the Authlib Starlette
integration. Provides an alternative to starlette session middleware.
"""

def __init__(self, cleanup_interval: timedelta = 1 * _MINUTE) -> None:
self._data: Dict[_CacheKey, Tuple[_CacheValue, _Expiry]] = {}
self._last_cleanup_time = datetime.now()
self._cleanup_interval = cleanup_interval

async def get(self, key: _CacheKey) -> Optional[_CacheValue]:
"""
Retrieves the value associated with the given key if it exists and has
not expired, otherwise, returns None.
"""
self._remove_expired_keys_if_cleanup_interval_exceeded()
if (value_and_expiry := self._data.get(key)) is None:
return None
value, expiry = value_and_expiry
if datetime.now() < expiry:
return value
self._data.pop(key, None)
return None

async def set(self, key: _CacheKey, value: _CacheValue, expires: int) -> None:
"""
Sets the value associated with the given key to the provided value with
the given expiry time in seconds.
"""
self._remove_expired_keys_if_cleanup_interval_exceeded()
expiry = datetime.now() + timedelta(seconds=expires)
self._data[key] = (value, expiry)

async def delete(self, key: _CacheKey) -> None:
"""
Removes the value associated with the given key if it exists.
"""
self._remove_expired_keys_if_cleanup_interval_exceeded()
self._data.pop(key, None)

def _remove_expired_keys_if_cleanup_interval_exceeded(self) -> None:
time_since_last_cleanup = datetime.now() - self._last_cleanup_time
if time_since_last_cleanup > self._cleanup_interval:
self._remove_expired_keys()

def _remove_expired_keys(self) -> None:
current_time = datetime.now()
delete_keys = [key for key, (_, expiry) in self._data.items() if expiry <= current_time]
for key in delete_keys:
self._data.pop(key, None)
self._last_cleanup_time = current_time

0 comments on commit c143665

Please sign in to comment.