Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement POST /user, POST /login, JWT-based authorization #31

Merged
merged 14 commits into from
Feb 7, 2022
Merged
1 change: 1 addition & 0 deletions backend/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
env/
__pycache__/
Empty file added backend/__init__.py
Empty file.
69 changes: 69 additions & 0 deletions backend/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from datetime import datetime, timedelta, timezone
from typing import Any
from typing_extensions import Self
from flask import request
import jwt


class Token:
"""
Python-representation of a JWT session token.

Example:

```json
{ "id": 0 }
```
"""

def __init__(self, secret: str, user_id: int) -> None:
"""
Prepare a token for serialization with the given user ID.
"""
self.payload = {"id": user_id}
self._secret = secret

def __repr__(self) -> str:
return str(self.payload)

def to_jwt(self) -> str:
"""
Convert the provided dictionary payload into a Base64-encoded JWT.

Tokens expire in 31 days, and cannot be used before the time they
were issued at.
"""
return jwt.encode(
{
**self.payload,
# Tokens are valid for 31 days.
"exp": datetime.now(tz=timezone.utc) + timedelta(days=31),
# Tokens should not be accepted before the present day.
"nbf": datetime.now(tz=timezone.utc),
},
self._secret,
algorithm="HS256",
).decode("utf-8")


def get_token(secret: str) -> Token:
"""
Get the JSON map encoded in the request's "Authorization" header.
Expiration is automatically checked by PyJWT. If the signature is
expired, an ExpiredSignatureError is thrown.

The "nbf" and "exp" claims are required for each token. If they are
missing, then this function will throw an error.

Similarly, should the JWT be valid but missing a required property
(see Token class doc), it will throw a KeyError.
"""
return Token(
secret,
jwt.decode(
request.headers.get("Authorization"),
secret,
algorithms=["HS256"],
options={"require": ["exp", "nbf"]},
)["id"],
)
101 changes: 76 additions & 25 deletions backend/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,36 @@
#!/bin/python3

from argparse import ArgumentParser
from dataclasses import dataclass

from flask import Flask, g, request
import psycopg as pg
import psycopg.errors as errors
from psycopg.rows import class_row

from argparse import ArgumentParser

from auth import *

CONN_STR = ""
SECRET = ""

app = Flask(__name__)


@dataclass
class Account:
"""
Helper dataclass for psycopg row factory use.
"""

id: int
username: str
password: str
professor: bool
deleted: datetime


# Retrieve the global database connection object.
# Pulled from https://flask.palletsprojects.com/en/2.0.x/appcontext/
def get_db() -> pg.Connection:
Expand All @@ -30,16 +50,43 @@ def teardown_db(exception):
# Log an user into the database, then return a valid JWT for their session.
@app.route("/login", methods=["POST"])
def login():
return {"token": "example"}
json = request.json
username, password = json["username"], json["password"]

conn = get_db()
with conn.cursor(row_factory=class_row(Account)) as cur:
cur.execute(
"""
SELECT * FROM Accounts
WHERE username = %s
AND password = crypt(%s, password)
""",
[username, password],
)
records = cur.fetchall()

length = len(records)
if length == 0:
# The login failed.
return {}, 400
elif len(records) > 1:
# There is a problem with the database.
return {}, 500
account = records[0]

return {
"token": Token(
SECRET,
account.id,
),
}, 200


# Create a user in the database, then return a valid JWT for their session.
@app.route("/user", methods=["POST"])
def create():
return {"token": "example"}
form = request.form
account_type, username, password = form["type"], form["username"], form["password"]
invite_key = None if account_type == "student" else form["inviteKey"]
def create_user():
json = request.json
account_type, username, password = json["type"], json["username"], json["password"]

# Bad request
if account_type not in ["student", "professor"]:
Expand All @@ -48,29 +95,24 @@ def create():
# Create a database transaction to insert our accout into the associated
# course.
conn = get_db()
with conn.cursor() as cur:
cur.execute(
"INSERT INTO Accounts (username, password, professor) VALUES (%s, %s, %s)",
username,
password,
account_type == "professor",
)

# If the account is for a student, then join them to their class.
if account_type == "student":
with conn.cursor(row_factory=class_row(Account)) as cur:
try:
cur.execute(
"""
INSERT INTO ClassMembers (id, class_id) VALUES (id, class_id) \
WHERE id = (SELECT id FROM Accounts WHERE username = %s) AND \
class_id = (SELECT invites_to FROM Invites WHERE id = %s)
INSERT INTO Accounts (username, password, professor)
VALUES (%s, %s, %s)
RETURNING *
""",
username,
invite_key,
[username, password, account_type == "professor"],
)
conn.commit()
conn.commit()
result = cur.fetchone()
except errors.UniqueViolation:
# User already exists.
return {}, 409

# TODO: create and return a JWT for the new session
return {}, 201
# Create and return a JWT for the new session containing the user's ID.
return {"token": Token(SECRET, result.id).to_jwt()}, 201


@app.route("/class/<class_id>/info", methods=["GET"])
Expand Down Expand Up @@ -156,9 +198,18 @@ def join_class(class_id):
parser.add_argument(
"--db-conn",
type=str,
default="port=5432 user=dev password=dev",
default="host=localhost port=5432 dbname=gradebetter user=admin password=admin",
help="connection string for a postgresql database",
)
parser.add_argument(
"-s",
"--secret",
type=str,
default="gradebetter",
help="secret key to use in JWT generation",
)
args = parser.parse_args()

CONN_STR = args.db_conn
SECRET = args.secret
app.run(port=args.port)