This repository was archived by the owner on Dec 20, 2023. It is now read-only.

Typing #68

Merged
merged 8 commits on Nov 7, 2023
7 changes: 5 additions & 2 deletions .github/workflows/ci.yml
@@ -37,9 +37,12 @@ jobs:
npm install
npm run build
working-directory: ${{ github.workspace }}/authpage
- name: Pytest
run: poetry run pytest

- name: Black
run: poetry run black src tests
- name: Ruff
run: poetry run ruff src tests
- name: Mypy
run: poetry run mypy
- name: Pytest
run: poetry run pytest
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@
.python-version
localenv.toml
.DS_Store
.vscode
12 changes: 6 additions & 6 deletions actions/local_actions.py
@@ -18,10 +18,10 @@
from apiserver.define import DEFINE
from apiserver.env import load_config
from apiserver.lib.model.entities import SignedUp, UserNames
from auth.core.model import IdInfo
from auth.data.authentication import get_apake_setup
from auth.data.keys import get_keys
from auth.data.schemad.opaque import get_setup
from auth.data.relational.opaque import get_setup
from auth.data.relational.user import EmptyIdUserData
from auth.define import refresh_exp, access_exp, id_exp
from auth.token.build import create_tokens, finish_tokens
from datacontext.context import DontReplaceContext
@@ -57,14 +57,14 @@ async def admin_access(local_dsrc):
admin_id = "admin_test"
scope = "member admin"
utc_now = auth.core.util.utc_timestamp()
id_info = IdInfo()
id_userdata = EmptyIdUserData()
access_token_data, id_token_data, access_scope, refresh_save = create_tokens(
admin_id,
scope,
utc_now - 1,
"test_nonce",
utc_now,
id_info,
id_userdata,
DEFINE.issuer,
DEFINE.frontend_client_id,
DEFINE.backend_client_id,
@@ -79,7 +79,7 @@ async def admin_access(local_dsrc):
keys.symmetric,
access_token_data,
id_token_data,
id_info,
id_userdata,
utc_now,
keys.signing,
access_exp,
@@ -97,7 +97,7 @@ async def test_get_admin_token(admin_access):
@pytest.mark.asyncio
async def test_generate_admin(local_dsrc):
admin_password = "admin"
setup = await get_apake_setup(local_dsrc.store)
setup = await get_apake_setup(DontReplaceContext(), local_dsrc.store)

cl_req, cl_state = opq.register_client(admin_password)
serv_resp = opq.register(setup, cl_req, util.usp_hex("0_admin"))
621 changes: 295 additions & 326 deletions poetry.lock

Large diffs are not rendered by default.

20 changes: 3 additions & 17 deletions pyproject.toml
@@ -30,6 +30,7 @@ anyio = "^3.7.1"
regex = "^2023.10.3"
orjson = "^3.9.5"
yarl = "^1.9.2"
types-regex = "^2023.10.3.0"

[tool.pytest.ini_options]
asyncio_mode = "strict"
@@ -38,7 +39,6 @@ asyncio_mode = "strict"
s-api = "apiserver.dev:run"

[tool.poetry.group.dev.dependencies]
sqlalchemy = { extras = ["mypy", "asyncio"], version = "^2.0.20" }
pytest = "^7.0.1"
pytest-asyncio = "^0.20.3"
pytest-mock = "^3.7.0"
@@ -48,7 +48,7 @@ alembic = "^1.12.0"
coverage = "^6.3.2"
black = "^23.9.1"
mypy = "^1.5.1"
types-redis = "^4.3.21"
#types-redis = "^4.3.21"
# sqlalchemy-stubs = "^0.4"
faker = "^19.3.1"
ruff = "^0.0.287"
@@ -62,26 +62,12 @@ python_version = "3.11"
strict = true
files = ["src"]
plugins = [
"sqlmypy",
"pydantic.mypy",
"sqlalchemy.ext.mypy.plugin"
]
exclude = [
"src/apiserver/db/migrations/"
]

[[tool.mypy.overrides]]
module = [
"apiserver.auth.*",
"apiserver.data.*",
"apiserver.db.*",
"apiserver.db.migrations.*",
"apiserver.kv.*",
"apiserver.routers.*",
"apiserver.app",
"apiserver.dev",
"apiserver.env",

"schema.model.env"
]
ignore_errors = true

20 changes: 13 additions & 7 deletions src/apiserver/app/modules/ranking.py
@@ -1,4 +1,4 @@
from typing import Literal, Optional
from typing import Literal, Optional, TypeGuard

from pydantic import BaseModel
from datetime import date
@@ -19,7 +19,7 @@ class NewEvent(BaseModel):
description: str = ""


async def add_new_event(dsrc: Source, new_event: NewEvent):
async def add_new_event(dsrc: Source, new_event: NewEvent) -> None:
"""Add a new event and recompute points. Display points will be updated to not include any events after the hidden
date. Use the 'publish' function to force them to be equal."""
async with data.get_conn(dsrc) as conn:
@@ -63,7 +63,7 @@ async def add_new_event(dsrc: Source, new_event: NewEvent):
)


async def sync_publish_ranking(dsrc: Source, publish: bool):
async def sync_publish_ranking(dsrc: Source, publish: bool) -> None:
async with data.get_conn(dsrc) as conn:
training_class = await data.classifications.most_recent_class_of_type(
conn, "training"
@@ -79,27 +79,33 @@ async def sync_publish_ranking(dsrc: Source, publish: bool):
)


def is_rank_type(rank_type: str) -> TypeGuard[Literal["training", "points"]]:
return rank_type in {"training", "points"}


async def class_id_or_recent(
dsrc: Source, class_id: Optional[int], rank_type: Optional[str]
) -> int:
if class_id is None and rank_type is None:
if class_id is not None:
return class_id
elif rank_type is None:
reason = "Provide either class_id or rank_type query parameter!"
raise AppError(
err_type=ErrorKeys.GET_CLASS,
err_desc=reason,
debug_key="user_events_invalid_class",
)
elif class_id is None and rank_type not in {"training", "points"}:
elif not is_rank_type(rank_type):
reason = f"Ranking {rank_type} is unknown!"
raise AppError(
err_type=ErrorKeys.GET_CLASS,
err_desc=reason,
debug_key="user_events_bad_ranking",
)
elif class_id is None:
else:
async with data.get_conn(dsrc) as conn:
class_id = (
await data.classifications.most_recent_class_of_type(conn, rank_type)
).classification_id

return class_id
return class_id
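The new is_rank_type helper is the central typing change in this file: typing.TypeGuard lets mypy narrow the Optional[str] query parameter to Literal["training", "points"] once the runtime check has passed, so the data-layer call at the end of class_id_or_recent type-checks without a cast. A minimal self-contained sketch of the same pattern, with a stand-in for the data-layer lookup (most_recent_class_id below is illustrative, not the project's real signature):

from typing import Literal, Optional, TypeGuard

RankType = Literal["training", "points"]

def is_rank_type(rank_type: str) -> TypeGuard[RankType]:
    # At runtime this is a plain membership test; for mypy it narrows
    # rank_type to Literal["training", "points"] in the True branch.
    return rank_type in {"training", "points"}

def most_recent_class_id(rank_type: RankType) -> int:
    # Stand-in for the data-layer lookup, which only accepts the Literal type.
    return {"training": 1, "points": 2}[rank_type]

def class_id_or_recent(class_id: Optional[int], rank_type: Optional[str]) -> int:
    if class_id is not None:
        return class_id
    elif rank_type is None:
        raise ValueError("Provide either class_id or rank_type query parameter!")
    elif not is_rank_type(rank_type):
        raise ValueError(f"Ranking {rank_type} is unknown!")
    else:
        # Here mypy has narrowed rank_type to Literal["training", "points"].
        return most_recent_class_id(rank_type)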
2 changes: 1 addition & 1 deletion src/apiserver/app/modules/register.py
@@ -60,7 +60,7 @@ class FinishRequest(BaseModel):

async def finalize_save_register(
dsrc: Source, context: RegisterAppContext, register_finish: FinishRequest
):
) -> None:
saved_state = await get_register_state(context, dsrc, register_finish.auth_id)

# Generate password file
18 changes: 9 additions & 9 deletions src/apiserver/app/ops/mail.py
@@ -36,7 +36,7 @@ def mail_from_config(config: Config) -> Optional[MailServer]:


def send_email(
logger_sent,
logger_sent: logging.Logger,
template: str,
receiver_email: str,
mail_server: Optional[MailServer],
@@ -76,10 +76,10 @@ def send_signup_email(
mail_server: Optional[MailServer],
redirect_link: str,
signup_link: str,
):
) -> None:
add_vars = {"redirect_link": redirect_link, "signup_link": signup_link}

def send_lam():
def send_lam() -> None:
send_email(
logger,
"confirm.jinja2",
@@ -98,10 +98,10 @@ def send_register_email(
receiver: str,
mail_server: Optional[MailServer],
register_link: str,
):
) -> None:
add_vars = {"register_link": register_link}

def send_lam():
def send_lam() -> None:
org_name = loc_dict["loc"]["org_name"]
send_email(
logger,
@@ -120,12 +120,12 @@ def send_reset_email(
receiver: str,
mail_server: Optional[MailServer],
reset_link: str,
):
) -> None:
add_vars = {
"reset_link": reset_link,
}

def send_lam():
def send_lam() -> None:
send_email(
logger,
"passwordchange.jinja2",
@@ -144,13 +144,13 @@ def send_change_email_email(
mail_server: Optional[MailServer],
reset_link: str,
old_email: str,
):
) -> None:
add_vars = {
"old_email": old_email,
"reset_link": reset_link,
}

def send_lam():
def send_lam() -> None:
send_email(
logger,
"emailchange.jinja2",
41 changes: 27 additions & 14 deletions src/apiserver/app/ops/startup.py
@@ -8,8 +8,14 @@
from apiserver.data.api.classifications import insert_classification
from apiserver.data.source import KeyState
from apiserver.env import Config
from apiserver.lib.model.entities import JWKSet, User, UserData, JWKPublicEdDSA
from auth.data.schemad.opaque import insert_opaque_row
from apiserver.lib.model.entities import (
JWKSet,
User,
UserData,
JWKPublicEdDSA,
JWKSymmetricA256GCM,
)
from auth.data.relational.opaque import insert_opaque_row
from auth.hazmat.structs import A256GCMKey
from apiserver.lib.hazmat import keys
from apiserver.lib.hazmat.keys import ed448_private_to_pem
@@ -22,11 +28,12 @@
from apiserver.data import Source
from schema.model import metadata as db_model
from apiserver.data.admin import drop_recreate_database
from store.error import DataError

logger = logging.getLogger(LOGGER_NAME)


async def startup(dsrc: Source, config: Config, recreate=False):
async def startup(dsrc: Source, config: Config, recreate: bool = False) -> None:
# Checks lock: returns True if it is the first lock since at least 25 seconds (lock expire time)
first_lock = await waiting_lock(dsrc)
logger.debug(f"{first_lock} - first lock status")
@@ -54,7 +61,7 @@ async def startup(dsrc: Source, config: Config, recreate=False):
MAX_WAIT_INDEX = 15


async def waiting_lock(dsrc: Source):
async def waiting_lock(dsrc: Source) -> bool:
"""We need this lock because in production we spawn multiple processes, which each startup separately. Returns
true if it is the first lock since at least 25 seconds (lock expire time)."""
await sleep(random() + 0.1)
@@ -73,7 +80,7 @@ async def waiting_lock(dsrc: Source):
return False
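The docstring above describes the intent: every worker process sleeps a random fraction of a second, then tries to take a lock that expires after 25 seconds, and only the process that finds no existing lock runs the one-time startup work. A minimal sketch of that pattern using redis-py's asyncio client (the key name, value and expiry are illustrative assumptions, not the project's actual store API):

import asyncio
from random import random

from redis.asyncio import Redis

async def first_startup_lock(redis: Redis) -> bool:
    # Random jitter so workers started at the same moment do not race in lockstep.
    await asyncio.sleep(random() + 0.1)
    # SET with NX only succeeds if the key does not yet exist; EX expires it
    # after 25 seconds, matching the "first lock since at least 25 seconds" idea.
    got_lock = await redis.set("startup_lock", "locked", nx=True, ex=25)
    return got_lock is not None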


def drop_create_database(config: Config):
def drop_create_database(config: Config) -> None:
# runtime_key = aes_from_symmetric(config.KEY_PASS)
db_cluster = f"{config.DB_USER}:{config.DB_PASS}@{config.DB_HOST}:{config.DB_PORT}"
db_url = f"{db_cluster}/{config.DB_NAME}"
@@ -91,7 +98,7 @@ def drop_create_database(config: Config):
del sync_engine


async def initial_population(dsrc: Source, config: Config):
async def initial_population(dsrc: Source, config: Config) -> None:
kid1, kid2, kid3 = (util.random_time_hash_hex(short=True) for _ in range(3))
old_symmetric = keys.new_symmetric_key(kid1)
new_symmetric = keys.new_symmetric_key(kid2)
@@ -159,7 +166,7 @@ async def initial_population(dsrc: Source, config: Config):
await insert_classification(conn, "points")


async def get_keystate(dsrc: Source):
async def get_keystate(dsrc: Source) -> KeyState:
async with data.get_conn(dsrc) as conn:
# We get the Key IDs (kid) of the newest keys and also previous symmetric key
# These newest ones will be used for signing new tokens
@@ -173,21 +180,21 @@ async def get_keystate(dsrc: Source):
)


async def load_keys_from_jwk(dsrc: Source, config: Config):
async def load_keys_from_jwk(dsrc: Source, config: Config) -> JWKSet:
# Key used to decrypt the keys stored in the database
runtime_key = aes_from_symmetric(config.KEY_PASS)
async with data.get_conn(dsrc) as conn:
encrypted_key_set = await data.key.get_jwk(conn)
key_set_dict = decrypt_dict(runtime_key.private, encrypted_key_set)
key_set: JWKSet = JWKSet.model_validate(key_set_dict)
key_set = JWKSet.model_validate(key_set_dict)
# We re-encrypt as is required when using AES encryption
reencrypted_key_set = encrypt_dict(runtime_key.private, key_set_dict)
await data.key.update_jwk(conn, reencrypted_key_set)

return key_set
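The re-encryption comment is easiest to read with AES-GCM in mind: a nonce must never be repeated under the same key, so after the stored JWK blob has been decrypted it is written back freshly encrypted instead of keeping the old ciphertext around. A rough illustration of that decrypt-and-rewrite step (the blob layout and helper bodies below are assumptions for the sketch, not the actual encrypt_dict/decrypt_dict internals):

import json
import os

from cryptography.hazmat.primitives.ciphers.aead import AESGCM

def decrypt_dict(key: bytes, blob: bytes) -> dict:
    # Assumed layout: the first 12 bytes are the nonce, the rest is ciphertext.
    nonce, ciphertext = blob[:12], blob[12:]
    return json.loads(AESGCM(key).decrypt(nonce, ciphertext, None))

def encrypt_dict(key: bytes, data: dict) -> bytes:
    # A fresh random nonce on every write, which is why the key set is
    # re-encrypted after each read rather than stored unchanged.
    nonce = os.urandom(12)
    return nonce + AESGCM(key).encrypt(nonce, json.dumps(data).encode(), None)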


async def load_keys(dsrc: Source, config: Config):
async def load_keys(dsrc: Source, config: Config) -> None:
key_set = await load_keys_from_jwk(dsrc, config)
key_state = await get_keystate(dsrc)

@@ -197,6 +204,11 @@ async def load_keys(dsrc: Source, config: Config):
public_keys = []
for key in key_set.keys:
if key.alg == "EdDSA":
if key.d is None:
raise DataError(
"Key private bytes not defined for EdDSA key!",
"eddsa_no_private_bytes",
)
key_private_bytes = util.dec_b64url(key.d)
# PyJWT only accepts keys in PEM format, so we convert them from the raw format we store them in
pem_key, pem_private_key = ed448_private_to_pem(key_private_bytes, key.kid)
@@ -205,12 +217,13 @@ async def load_keys(dsrc: Source, config: Config):
# The public keys we will store in raw format, we want to exclude the private key as we want to be able to
# publish these keys
# The 'x' are the public key bytes (as set by the JWK standard)
public_key = JWKPublicEdDSA(
alg=key.alg, kid=key.kid, kty=key.kty, use=key.use, crv=key.crv, x=key.x
)
public_key = JWKPublicEdDSA.model_validate(key.model_dump())
public_keys.append(public_key)
elif key.alg == "A256GCM":
symmetric_key = A256GCMKey(kid=key.kid, symmetric=key.k)
symmetric_key_jwk = JWKSymmetricA256GCM.model_validate(key.model_dump())
symmetric_key = A256GCMKey(
kid=symmetric_key_jwk.kid, symmetric=symmetric_key_jwk.k
)
symmetric_keys.append(symmetric_key)

# In the future we can publish these keys
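Two things happen in this hunk. The Ed448 private bytes from the JWK's d field must be converted to PEM because PyJWT only accepts PEM-encoded keys, while the public x bytes stay in their raw form; and model_validate(key.model_dump()) projects the full JWK onto the public or symmetric model, relying on pydantic dropping fields the target model does not declare (the default behaviour unless a model forbids extra fields). A sketch of what the raw-to-PEM conversion could look like with the cryptography package (the project's ed448_private_to_pem takes a kid and returns more, so treat this as an approximation):

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric.ed448 import Ed448PrivateKey

def raw_ed448_private_to_pem(private_bytes: bytes) -> bytes:
    # Hypothetical helper: rebuild the key object from the raw JWK bytes and
    # serialize it as an unencrypted PKCS#8 PEM, the form PyJWT's EdDSA algorithm accepts.
    private_key = Ed448PrivateKey.from_private_bytes(private_bytes)
    return private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )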