Skip to content
This repository was archived by the owner on Dec 20, 2023. It is now read-only.

Use FastAPI Depends, improve ctxlize, context tests and header tests #71

Merged
merged 5 commits into from
Nov 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions actions/local_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@
import pytest
import pytest_asyncio
from faker import Faker
from apiserver.data.context.ranking import add_new_event

import apiserver.lib.utilities as util
import auth.core.util
from apiserver import data
from apiserver.app.modules.ranking import add_new_event, NewEvent
from apiserver.app.ops.startup import get_keystate
from apiserver.data import Source, get_conn
from apiserver.data.api.classifications import insert_classification, UserPoints
from apiserver.data.api.ud.userdata import new_userdata
from apiserver.data.special import update_class_points
from apiserver.define import DEFINE
from apiserver.env import load_config
from apiserver.lib.model.entities import SignedUp, UserNames
from apiserver.lib.model.entities import NewEvent, SignedUp, UserNames
from auth.data.authentication import get_apake_setup
from auth.data.keys import get_keys
from auth.data.relational.opaque import get_setup
Expand All @@ -28,7 +28,7 @@
from store import Store


@pytest.fixture(scope="module", autouse=True)
@pytest.fixture(scope="session", autouse=True)
def event_loop():
"""Necessary for async tests with module-scoped fixtures"""
loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -212,4 +212,4 @@ async def test_add_event(local_dsrc, faker: Faker):
description="desc",
)

await add_new_event(local_dsrc, new_event)
await add_new_event(DontReplaceContext(), local_dsrc, new_event)
211 changes: 107 additions & 104 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ pytest-mock = "^3.7.0"
pre-commit = "^2.20.0"
httpx = "^0.24.1"
alembic = "^1.12.0"
coverage = "^6.3.2"
coverage = "^7.3.2"
black = "^23.9.1"
mypy = "^1.5.1"
#types-redis = "^4.3.21"
Expand Down
109 changes: 109 additions & 0 deletions src/apiserver/app/dependencies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import Annotated
from fastapi import Depends, Request
from apiserver.app.error import ErrorResponse
from apiserver.app.ops.header import auth_header, verify_token_header
from apiserver.data.context.app_context import Code, SourceContexts

from apiserver.data.source import Source
from apiserver.lib.model.entities import AccessToken
from apiserver.lib.resource.error import ResourceError, resource_error_code
from auth.data.context import Contexts as AuthContexts
from datacontext.context import ctxlize


def dep_source(request: Request) -> Source:
dsrc: Source = request.state.dsrc
return dsrc


def dep_app_context(request: Request) -> SourceContexts:
cd: Code = request.state.cd
return cd.app_context


def dep_auth_context(request: Request) -> AuthContexts:
cd: Code = request.state.cd
return cd.auth_context


SourceDep = Annotated[Source, Depends(dep_source)]
AppContext = Annotated[SourceContexts, Depends(dep_app_context)]
AuthContext = Annotated[AuthContexts, Depends(dep_auth_context)]

Authorization = Annotated[str, Depends(auth_header)]


async def dep_header_token(
authorization: Authorization, dsrc: SourceDep, app_ctx: AppContext
) -> AccessToken:
try:
return await ctxlize(verify_token_header)(
app_ctx.authrz_ctx, authorization, dsrc
)
except ResourceError as e:
code = resource_error_code(e.err_type)

raise ErrorResponse(
code,
err_type=e.err_type,
err_desc=e.err_desc,
debug_key=e.debug_key,
)


AccessDep = Annotated[AccessToken, Depends(dep_header_token)]


def verify_user(acc: AccessToken, user_id: str) -> bool:
"""Verifies if the user in the access token corresponds to the provided user_id.

Args:
acc: AccessToken object.
user_id: user_id that will be compared against.

Returns:
True if user_id = acc.sub.

Raises:
ErrorResponse: If access token subject does not correspond to user_id.
"""
if acc.sub != user_id:
reason = "Resource not available to this subject."
raise ErrorResponse(
403, err_type="wrong_subject", err_desc=reason, debug_key="bad_sub"
)

return True


def has_scope(scopes: str, required: set[str]) -> bool:
scope_set = set(scopes.split())
return required.issubset(scope_set)


def require_admin(acc: AccessDep) -> AccessToken:
if not has_scope(acc.scope, {"admin"}):
raise ErrorResponse(
403,
err_type="insufficient_scope",
err_desc="Insufficient permissions to access this resource.",
debug_key="low_perms",
)

return acc


def require_member(acc: AccessDep) -> AccessToken:
if not has_scope(acc.scope, {"member"}):
raise ErrorResponse(
403,
err_type="insufficient_scope",
err_desc="Insufficient permissions to access this resource.",
debug_key="low_perms",
)

return acc


RequireMember = Annotated[AccessToken, Depends(require_member)]
RequireAdmin = Annotated[AccessToken, Depends(require_admin)]
7 changes: 7 additions & 0 deletions src/apiserver/app/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,22 @@ class ErrorKeys(StrEnum):
RANKING_UPDATE = "invalid_ranking_update"
DATA = "invalid_data_load"
GET_CLASS = "invalid_get_class"
CHECK = "invalid_code_check"
UPDATE = "invalid_update"


class AppError(Exception):
err_type: ErrorKeys
err_desc: str
debug_key: Optional[str]

def __init__(
self,
err_type: ErrorKeys,
err_desc: str,
debug_key: Optional[str] = None,
):
super().__init__(err_desc)
self.err_type = err_type
self.err_desc = err_desc
self.debug_key = debug_key
Expand Down
17 changes: 9 additions & 8 deletions src/apiserver/app/modules/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@

from apiserver.app.error import ErrorKeys, AppError
from apiserver.data import Source
from apiserver.data.context.app_context import RankingContext
from apiserver.data.context.ranking import (
context_events_in_class,
context_most_recent_class_id_of_type,
context_user_events_in_class,
)
from apiserver.data.api.classifications import events_in_class
from apiserver.data.context.app_context import RankingContext, conn_wrap
from apiserver.data.context.ranking import context_most_recent_class_id_of_type
from apiserver.data.source import source_session
from apiserver.data.special import user_events_in_class
from apiserver.lib.logic.ranking import is_rank_type
from apiserver.lib.model.entities import ClassEvent, UserEvent
from datacontext.context import ctxlize_wrap


async def class_id_or_recent(
Expand Down Expand Up @@ -49,7 +48,7 @@ async def mod_user_events_in_class(
async with source_session(dsrc) as session:
sure_class_id = await class_id_or_recent(session, ctx, class_id, rank_type)

user_events = await context_user_events_in_class(
user_events = await ctxlize_wrap(user_events_in_class, conn_wrap)(
ctx, session, user_id, sure_class_id
)

Expand All @@ -65,6 +64,8 @@ async def mod_events_in_class(
async with source_session(dsrc) as session:
sure_class_id = await class_id_or_recent(session, ctx, class_id, rank_type)

events = await context_events_in_class(ctx, session, sure_class_id)
events = await ctxlize_wrap(events_in_class, conn_wrap)(
ctx, session, sure_class_id
)

return events
56 changes: 56 additions & 0 deletions src/apiserver/app/modules/update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from apiserver.app.error import AppError, ErrorKeys
from apiserver.data.context.app_context import UpdateContext
from apiserver.data.source import Source
from apiserver.data.trs.trs import pop_string
from auth.data.authentication import pop_flow_user
from auth.data.context import LoginContext
from datacontext.context import ctxlize
from store.error import NoDataError


async def verify_delete_account(
auth_code: str,
flow_id: str,
dsrc: Source,
update_ctx: UpdateContext,
login_ctx: LoginContext,
) -> str:
# First we check if the auth code exists. This proves the flow_user logged in with the correct password.
# Note that such a code can be generated by any user that logs in, so this is not proof that this user wants to
# delete their account
try:
flow_user = await pop_flow_user(login_ctx, dsrc.store, auth_code)
except NoDataError:
# logger.debug(e.message)
reason = "Expired or missing auth code"
raise AppError(
err_type=ErrorKeys.CHECK, err_desc=reason, debug_key="empty_flow"
)
# We check if the user who authenticated with their password has the same flow_id as the one requested.
# Note that we have not yet checked if this is a valid flow ID.
if flow_user.flow_id != flow_id:
reason = "Flow for auth code does not match requested delete flow!"
raise AppError(
err_type=ErrorKeys.UPDATE,
err_desc=reason,
debug_key="update_flows_dont_match",
)

# If the flow ID exists, then we know that a delete request was indeed initiated.
stored_user_id = await ctxlize(pop_string)(update_ctx, dsrc, flow_id)
if stored_user_id is None:
reason = "Delete request has expired, please try again!"
# logger.debug(reason + f" {flow_user.user_id}")
raise AppError(err_type=ErrorKeys.UPDATE, err_desc=reason)

# If the user ID's are equal, we know that the person who requested the deletion is the same person who
# logged in. Therefore, we have proved that the delete account initiator has entered their password.
if flow_user.user_id != stored_user_id:
reason = "Flow user does not match requested delete flow's user!"
raise AppError(err_type=ErrorKeys.UPDATE, err_desc=reason)
if stored_user_id is None:
reason = "Delete request has expired, please try again!"
# logger.debug(reason + f" {flow_user.user_id}")
raise AppError(err_type=ErrorKeys.UPDATE, err_desc=reason)

return stored_user_id
27 changes: 0 additions & 27 deletions src/apiserver/app/ops/errors.py

This file was deleted.

72 changes: 35 additions & 37 deletions src/apiserver/app/ops/header.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,55 @@
from typing import Annotated

from fastapi.params import Security
from fastapi.security.api_key import APIKeyHeader
from fastapi import HTTPException, Request
from fastapi.datastructures import Headers

from apiserver.define import DEFINE, grace_period
from apiserver import data
from apiserver.app.ops.errors import BadAuth
from apiserver.lib.hazmat.tokens import (
verify_access_token,
BadVerification,
get_kid,
)
from apiserver.data import Source
from apiserver.lib.resource.error import ResourceError
from apiserver.lib.resource.header import (
AccessSettings,
extract_token_and_kid,
resource_verify_token,
)
from store.error import NoDataError
from apiserver.lib.model.entities import AccessToken

scheme = "Bearer"
www_authenticate = f"Bearer realm={DEFINE.realm}"

# TODO modify APIKeyHeader for better status code
auth_header = APIKeyHeader(name="Authorization", scheme_name=scheme, auto_error=True)

Authorization = Annotated[str, Security(auth_header)]
def auth_header(request: Request) -> str:
# This is so we don't have to instantiate a Request object, which can be annoying
return parse_auth_header(request.headers)


async def handle_header(authorization: str, dsrc: Source) -> AccessToken:
if authorization is None:
raise BadAuth(err_type="invalid_request", err_desc="No authorization provided.")
if "Bearer " not in authorization:
raise BadAuth(
err_type="invalid_request",
err_desc="Authorization must follow 'Bearer' scheme",
def parse_auth_header(headers: Headers) -> str:
authorization = headers.get("Authorization")
if not authorization:
# Conforms to RFC6750 https://www.rfc-editor.org/rfc/rfc6750.html
raise HTTPException(
status_code=400, headers={"WWW-Authenticate": www_authenticate}
)
token = authorization.removeprefix("Bearer ")

return authorization


async def verify_token_header(authorization: str, dsrc: Source) -> AccessToken:
# THROWS ResourceError
token, kid = extract_token_and_kid(authorization)

try:
kid = get_kid(token)
public_key = (await data.trs.key.get_pem_key(dsrc, kid)).public
return verify_access_token(
public_key,
token,
grace_period,
DEFINE.issuer,
DEFINE.backend_client_id,
)
except NoDataError as e:
raise BadAuth(
raise ResourceError(
err_type="invalid_token",
err_desc="Key does not exist!",
debug_key=e.key,
)
except BadVerification as e:
raise BadAuth(
err_type="invalid_token",
err_desc="Token verification failed!",
debug_key=e.err_key,
)

access_settings = AccessSettings(
grace_period=grace_period,
issuer=DEFINE.issuer,
aud_client_ids=[DEFINE.backend_client_id],
)

# THROWS ResourceError
return resource_verify_token(token, public_key, access_settings)
Loading