Skip to content
This repository was archived by the owner on Dec 20, 2023. It is now read-only.

Context tests and contextualize #70

Merged
merged 13 commits into from
Nov 8, 2023
9 changes: 4 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,16 @@ jobs:
npm install
npm run build
working-directory: ${{ github.workspace }}/authpage

- name: Pytest
run: poetry run pytest
env:
QUERY_TEST: true
- name: Black
run: poetry run black src tests
- name: Ruff
run: poetry run ruff src tests
- name: Mypy
run: poetry run mypy
- name: Pytest
run: poetry run pytest
env:
QUERY_TEST: true
services:
postgres:
image: ghcr.io/dsav-dodeka/postgres:localdev
Expand Down
4 changes: 4 additions & 0 deletions src/apiserver/app/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Modules encapsulate the different context functions in cases where the router function would become too complex.
They should not know about HTTP (that is for the actual routers and possible router helper functions), hence they
should raise only AppError exceptions and take in only model objects, a context and data source. They should NOT open
connections directly, that is for the lower data context functions to do."""
127 changes: 43 additions & 84 deletions src/apiserver/app/modules/ranking.py
Original file line number Diff line number Diff line change
@@ -1,90 +1,21 @@
from typing import Literal, Optional, TypeGuard
from typing import Optional

from pydantic import BaseModel
from datetime import date

from apiserver.app.error import ErrorKeys, AppError
from apiserver.data import Source
from apiserver import data
from apiserver.data.api.classifications import UserPoints
from store.error import DataError


class NewEvent(BaseModel):
users: list[UserPoints]
class_type: Literal["points", "training"]
date: date
event_id: str
category: str
description: str = ""


async def add_new_event(dsrc: Source, new_event: NewEvent) -> None:
"""Add a new event and recompute points. Display points will be updated to not include any events after the hidden
date. Use the 'publish' function to force them to be equal."""
async with data.get_conn(dsrc) as conn:
try:
classification = await data.classifications.most_recent_class_of_type(
conn, new_event.class_type
)
except DataError as e:
if e.key != "incorrect_class_type":
raise e
raise AppError(ErrorKeys.RANKING_UPDATE, e.message, "incorrect_class_type")
if classification.start_date > new_event.date:
desc = "Event cannot happen before start of classification!"
raise AppError(ErrorKeys.RANKING_UPDATE, desc, "ranking_date_before_start")

event_id = await data.classifications.add_class_event(
conn,
new_event.event_id,
classification.classification_id,
new_event.category,
new_event.date,
new_event.description,
)

try:
await data.classifications.add_users_to_event(
conn, event_id=event_id, points=new_event.users
)
except DataError as e:
if e.key != "database_integrity":
raise e
raise AppError(
ErrorKeys.RANKING_UPDATE,
e.message,
"add_event_users_violates_integrity",
)

await data.special.update_class_points(
conn,
classification.classification_id,
)


async def sync_publish_ranking(dsrc: Source, publish: bool) -> None:
async with data.get_conn(dsrc) as conn:
training_class = await data.classifications.most_recent_class_of_type(
conn, "training"
)
points_class = await data.classifications.most_recent_class_of_type(
conn, "points"
)
await data.special.update_class_points(
conn, training_class.classification_id, publish
)
await data.special.update_class_points(
conn, points_class.classification_id, publish
)


def is_rank_type(rank_type: str) -> TypeGuard[Literal["training", "points"]]:
return rank_type in {"training", "points"}
from apiserver.data.context.app_context import RankingContext
from apiserver.data.context.ranking import (
context_events_in_class,
context_most_recent_class_id_of_type,
context_user_events_in_class,
)
from apiserver.data.source import source_session
from apiserver.lib.logic.ranking import is_rank_type
from apiserver.lib.model.entities import ClassEvent, UserEvent


async def class_id_or_recent(
dsrc: Source, class_id: Optional[int], rank_type: Optional[str]
dsrc: Source, ctx: RankingContext, class_id: Optional[int], rank_type: Optional[str]
) -> int:
if class_id is not None:
return class_id
Expand All @@ -103,9 +34,37 @@ async def class_id_or_recent(
debug_key="user_events_bad_ranking",
)
else:
async with data.get_conn(dsrc) as conn:
class_id = (
await data.classifications.most_recent_class_of_type(conn, rank_type)
).classification_id
class_id = await context_most_recent_class_id_of_type(ctx, dsrc, rank_type)

return class_id


async def mod_user_events_in_class(
dsrc: Source,
ctx: RankingContext,
user_id: str,
class_id: Optional[int],
rank_type: Optional[str],
) -> list[UserEvent]:
async with source_session(dsrc) as session:
sure_class_id = await class_id_or_recent(session, ctx, class_id, rank_type)

user_events = await context_user_events_in_class(
ctx, session, user_id, sure_class_id
)

return user_events


async def mod_events_in_class(
dsrc: Source,
ctx: RankingContext,
class_id: Optional[int] = None,
rank_type: Optional[str] = None,
) -> list[ClassEvent]:
async with source_session(dsrc) as session:
sure_class_id = await class_id_or_recent(session, ctx, class_id, rank_type)

events = await context_events_in_class(ctx, session, sure_class_id)

return events
77 changes: 43 additions & 34 deletions src/apiserver/app/routers/ranking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,34 @@
from fastapi import APIRouter
from starlette.requests import Request

from apiserver import data
from apiserver.app.error import ErrorResponse, AppError
from apiserver.app.modules.ranking import (
add_new_event,
NewEvent,
sync_publish_ranking,
class_id_or_recent,
is_rank_type,
mod_events_in_class,
mod_user_events_in_class,
)
from apiserver.app.ops.header import Authorization
from apiserver.app.response import RawJSONResponse
from apiserver.app.routers.helper import require_admin, require_member
from apiserver.data.api.classifications import get_event_user_points
from apiserver.data.context.app_context import Code, RankingContext, conn_wrap
from apiserver.data.context.authorize import require_admin as ctx_require_admin
from apiserver.data.context.ranking import (
add_new_event,
context_most_recent_class_points,
sync_publish_ranking,
)
from apiserver.lib.logic.ranking import is_rank_type
from apiserver.lib.model.entities import (
ClassEvent,
NewEvent,
UserEvent,
UserPointsNames,
UserPointsNamesList,
UserEventsList,
EventsList,
)
from apiserver.data import Source, get_conn
from apiserver.data.api.classifications import events_in_class, get_event_user_points
from apiserver.data.special import user_events_in_class
from apiserver.data import Source
from datacontext.context import ctxlize

router = APIRouter()

Expand All @@ -34,16 +39,17 @@ async def admin_update_ranking(
new_event: NewEvent, request: Request, authorization: Authorization
) -> None:
dsrc: Source = request.state.dsrc
cd: Code = request.state.cd
await require_admin(authorization, dsrc)

try:
await add_new_event(dsrc, new_event)
await add_new_event(cd.app_context.rank_ctx, dsrc, new_event)
except AppError as e:
raise ErrorResponse(400, "invalid_ranking_update", e.err_desc, e.debug_key)


async def get_classification(
dsrc: Source, rank_type: str, admin: bool = False
dsrc: Source, ctx: RankingContext, rank_type: str, admin: bool = False
) -> RawJSONResponse:
if not is_rank_type(rank_type):
reason = f"Ranking {rank_type} is unknown!"
Expand All @@ -53,13 +59,8 @@ async def get_classification(
err_desc=reason,
debug_key="bad_ranking",
)
async with data.get_conn(dsrc) as conn:
class_view = await data.classifications.most_recent_class_of_type(
conn, rank_type
)
user_points = await data.classifications.all_points_in_class(
conn, class_view.classification_id, admin
)

user_points = await context_most_recent_class_points(ctx, dsrc, rank_type, admin)
return RawJSONResponse(UserPointsNamesList.dump_json(user_points))


Expand All @@ -70,30 +71,33 @@ async def member_classification(
rank_type: str, request: Request, authorization: Authorization
) -> RawJSONResponse:
dsrc: Source = request.state.dsrc
cd: Code = request.state.cd
await require_member(authorization, dsrc)

return await get_classification(dsrc, rank_type, False)
return await get_classification(dsrc, cd.app_context.rank_ctx, rank_type, False)


@router.get("/admin/classification/{rank_type}/", response_model=list[UserPointsNames])
async def member_classification_admin(
rank_type: str, request: Request, authorization: Authorization
) -> RawJSONResponse:
dsrc: Source = request.state.dsrc
cd: Code = request.state.cd
await require_admin(authorization, dsrc)

return await get_classification(dsrc, rank_type, True)
return await get_classification(dsrc, cd.app_context.rank_ctx, rank_type, True)


@router.post("/admin/class/sync/")
async def sync_publish_classification(
request: Request, authorization: Authorization, publish: Optional[str] = None
) -> None:
dsrc: Source = request.state.dsrc
cd: Code = request.state.cd
await require_admin(authorization, dsrc)

do_publish = publish == "publish"
await sync_publish_ranking(dsrc, do_publish)
await sync_publish_ranking(cd.app_context.rank_ctx, dsrc, do_publish)


@router.get("/admin/class/events/user/{user_id}/", response_model=list[UserEvent])
Expand All @@ -105,48 +109,53 @@ async def get_user_events_in_class(
rank_type: Optional[str] = None,
) -> RawJSONResponse:
dsrc: Source = request.state.dsrc
cd: Code = request.state.cd
await require_admin(authorization, dsrc)

try:
sure_class_id = await class_id_or_recent(dsrc, class_id, rank_type)
user_events = await mod_user_events_in_class(
dsrc, cd.app_context.rank_ctx, user_id, class_id, rank_type
)
except AppError as e:
raise ErrorResponse(400, e.err_type, e.err_desc, e.debug_key)

async with get_conn(dsrc) as conn:
user_events = await user_events_in_class(conn, user_id, sure_class_id)

return RawJSONResponse(UserEventsList.dump_json(user_events))


@router.get("/admin/class/events/", response_model=list[ClassEvent])
@router.get("/admin/class/events/all/", response_model=list[ClassEvent])
async def get_events_in_class(
request: Request,
authorization: Authorization,
class_id: Optional[int] = None,
rank_type: Optional[str] = None,
) -> RawJSONResponse:
dsrc: Source = request.state.dsrc
cd: Code = request.state.cd
await require_admin(authorization, dsrc)

try:
sure_class_id = await class_id_or_recent(dsrc, class_id, rank_type)
events = await mod_events_in_class(
dsrc, cd.app_context.rank_ctx, class_id, rank_type
)
except AppError as e:
raise ErrorResponse(400, e.err_type, e.err_desc, e.debug_key)

async with get_conn(dsrc) as conn:
events = await events_in_class(conn, sure_class_id)

return RawJSONResponse(EventsList.dump_json(events))


@router.get("/admin/class/events/{event_id}/", response_model=list[UserPointsNames])
@router.get(
"/admin/class/users/event/{event_id}/", response_model=list[UserPointsNames]
)
async def get_event_users(
event_id: str, request: Request, authorization: Authorization
) -> RawJSONResponse:
dsrc: Source = request.state.dsrc
await require_admin(authorization, dsrc)
cd: Code = request.state.cd
await ctx_require_admin(cd.app_context.authrz_ctx, authorization, dsrc)

async with get_conn(dsrc) as conn:
event_users = await get_event_user_points(conn, event_id)
# Result could be empty!
event_users = await ctxlize(get_event_user_points, conn_wrap)(
cd.wrap, dsrc, event_id
)

return RawJSONResponse(UserPointsNamesList.dump_json(event_users))
13 changes: 12 additions & 1 deletion src/apiserver/app_lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
from apiserver.data.context import Code, SourceContexts
from apiserver.data.context.register import ctx_reg as register_app_reg
from apiserver.data.context.update import ctx_reg as update_reg
from apiserver.data.context.ranking import ctx_reg as ranking_reg
from apiserver.data.context.authorize import ctx_reg as authrz_app_reg
from apiserver.define import LOGGER_NAME, DEFINE
from apiserver.env import load_config, Config
from apiserver.resources import res_path
from datacontext.context import WrapContext

logger = logging.getLogger(LOGGER_NAME)

Expand Down Expand Up @@ -91,8 +94,16 @@ def register_and_define_code() -> Code:
source_data_context = SourceContexts()
source_data_context.include_registry(register_app_reg)
source_data_context.include_registry(update_reg)
source_data_context.include_registry(ranking_reg)
source_data_context.include_registry(authrz_app_reg)

return Code(auth_context=data_context, app_context=source_data_context)
wrap_data_context = WrapContext()

return Code(
auth_context=data_context,
app_context=source_data_context,
wrap=wrap_data_context,
)


AppLifespan = Callable[[FastAPI], AsyncContextManager[State]]
Expand Down
Loading