Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add user sessions #149

Merged
merged 1 commit into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions alembic/versions/20240115_1350_3c2015517236_add_session_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""add session table

Revision ID: 3c2015517236
Revises: 868640c5012e
Create Date: 2024-01-15 13:50:25.409039

"""
from typing import Sequence, Union

import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "3c2015517236"
down_revision: Union[str, None] = "868640c5012e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"sessions",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.String(), nullable=False),
sa.Column("token", sa.String(), nullable=False),
sa.Column(
"created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=True,
),
sa.Column("last_used", sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("id"),
)
# Migrate tokens from users to sessions
op.execute(
"""
INSERT INTO sessions (user_id, token, created, last_used)
SELECT user_id, token, created, last_used FROM users
WHERE token IS NOT NULL
"""
)
op.create_index(op.f("ix_sessions_id"), "sessions", ["id"], unique=False)
op.create_index(op.f("ix_sessions_token"), "sessions", ["token"], unique=True)
op.create_index(op.f("ix_sessions_user_id"), "sessions", ["user_id"], unique=False)
op.drop_index("ix_users_token", table_name="users")
op.drop_column("users", "last_used")
op.drop_column("users", "token")
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"users", sa.Column("token", sa.VARCHAR(), autoincrement=False, nullable=True)
)
op.add_column(
"users",
sa.Column(
"last_used",
postgresql.TIMESTAMP(timezone=True),
autoincrement=False,
nullable=True,
),
)
# Migrate tokens from sessions to users
op.execute(
"""UPDATE users
SET token = sessions.token, last_used = sessions.last_used
FROM sessions
WHERE sessions.user_id = users.user_id"""
)
op.create_index("ix_users_token", "users", ["token"], unique=False)
op.drop_index(op.f("ix_sessions_user_id"), table_name="sessions")
op.drop_index(op.f("ix_sessions_token"), table_name="sessions")
op.drop_index(op.f("ix_sessions_id"), table_name="sessions")
op.drop_table("sessions")
# ### end Alembic commands ###
24 changes: 12 additions & 12 deletions app/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,9 +106,9 @@ def get_current_user(
:return: the current user
"""
if token and "__U" in token:
db_user = crud.get_user_by_token(db, token=token)
if db_user:
return crud.update_user_last_used_field(db, user=db_user)
session = crud.get_session_by_token(db, token=token)
if session:
return crud.update_session_last_used_field(db, session=session).user
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid authentication credentials",
Expand All @@ -130,9 +130,9 @@ def get_current_user_optional(
:return: the current user if authenticated, None otherwise
"""
if token and "__U" in token:
db_user = crud.get_user_by_token(db, token=token)
if db_user:
return crud.update_user_last_used_field(db, user=db_user)
session = crud.get_session_by_token(db, token=token)
if session:
return crud.update_session_last_used_field(db, session=session).user
return None


Expand Down Expand Up @@ -179,20 +179,20 @@ def authentication(
detail="OAUTH2_SERVER_URL environment variable missing",
)

data = {"user_id": form_data.username, "password": form_data.password}
user_id = form_data.username
data = {"user_id": user_id, "password": form_data.password}
r = requests.post(settings.oauth2_server_url, data=data) # type: ignore
if r.status_code == 200:
token = create_token(form_data.username)
user = schemas.UserCreate(user_id=form_data.username, token=token)
db_user, created = crud.get_or_create_user(db, user=user)
user = crud.update_user_last_used_field(db, user=db_user)
session, *_ = crud.create_session(db, user_id=user_id, token=token)
session = crud.update_session_last_used_field(db, session=session)
# set the cookie if requested
if set_cookie:
# Don't add httponly=True or secure=True as it's still in
# development phase, but it should be added once the front-end
# is ready
response.set_cookie(key="session", value=user.token)
return {"access_token": user.token, "token_type": "bearer"}
response.set_cookie(key="session", value=token)
return {"access_token": token, "token_type": "bearer"}
elif r.status_code == 403:
time.sleep(2) # prevents brute-force
raise HTTPException(
Expand Down
94 changes: 66 additions & 28 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from app import config
from app.enums import LocationOSMEnum, ProofTypeEnum
from app.models import Location, Price, Product, Proof, User
from app.models import Location, Price, Product, Proof
from app.models import Session as SessionModel
from app.models import User
from app.schemas import (
LocationCreate,
LocationFilter,
Expand All @@ -35,55 +37,91 @@ def get_users_query(filters: ProductFilter | None = None):
return query


def get_users(db: Session, filters: ProductFilter | None = None):
def get_users(db: Session, filters: ProductFilter | None = None) -> list[User]:
"""Return a list of users from the database.

:param db: the database session
:param filters: the filters to apply to the query, defaults to None
:return: the list of users
"""
return db.execute(get_users_query(filters=filters)).all()


def get_user(db: Session, user_id: str):
def get_user_by_user_id(db: Session, user_id: str) -> User:
return db.query(User).filter(User.user_id == user_id).first()


def get_user_by_user_id(db: Session, user_id: str):
return db.query(User).filter(User.user_id == user_id).first()

def get_session_by_token(db: Session, token: str) -> SessionModel:
"""Return the session linked to the token.

def get_user_by_token(db: Session, token: str):
return db.query(User).filter(User.token == token).first()
:param db: the database session
:param token: the session token
:return: the session
"""
return db.query(SessionModel).join(User).filter(SessionModel.token == token).first()


def create_user(db: Session, user: UserCreate) -> User:
def create_user(db: Session, user_id: str) -> User:
"""Create a user in the database.

:param db: the database session
:param product: the user to create
:param user_id: the Open Food Facts user ID
:param token: the session token
:return: the created user
"""
db_user = User(user_id=user.user_id, token=user.token)
db.add(db_user)
user = User(user_id=user_id)
db.add(user)
db.commit()
db.refresh(db_user)
return db_user
db.refresh(user)
return user


def get_or_create_user(db: Session, user: UserCreate):
def _create_session(db: Session, user: User, token: str) -> SessionModel:
"""Create a session in the database.

:param db: the database session
:param user: the user linked to the session
:param token: the session token
:return: the created session
"""
session = SessionModel(token=token, user=user)
db.add(session)
db.commit()
db.refresh(session)
return session


def create_session(
db: Session, user_id: str, token: str
) -> tuple[SessionModel, User, bool]:
"""Create a new session (and optionally the user if it doesn't exist) in
DB.

:param db: the database session
:param user_id: the Open Food Facts user ID
:param token: the session token
:return: the created session, the user and a boolean indicating whether the
user was created or not
"""
created = False
db_user = get_user_by_user_id(db, user_id=user.user_id)
if not db_user:
db_user = create_user(db, user=user)
user = get_user_by_user_id(db, user_id=user_id)
if not user:
user = create_user(db, user_id=user_id)
session = _create_session(db, user=user, token=token)
created = True
return db_user, created
else:
session = get_session_by_token(db, token=token)
if not session:
session = _create_session(db, user=user, token=token)
return session, user, created


def update_user(db: Session, user: UserCreate, update_dict: dict):
for key, value in update_dict.items():
setattr(user, key, value)
def update_session_last_used_field(db: Session, session: SessionModel) -> SessionModel:
"""Update the last_used field of a session to the current time."""
session.last_used = func.now()
db.commit()
db.refresh(user)
return user


def update_user_last_used_field(db: Session, user: UserCreate) -> UserCreate | None:
return update_user(db, user, {"last_used": func.now()})
db.refresh(session)
return session


def increment_user_price_count(db: Session, user: UserCreate):
Expand Down
18 changes: 14 additions & 4 deletions app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,26 @@

class User(Base):
user_id = Column(String, primary_key=True, index=True)
token = Column(String, unique=True, index=True)

last_used = Column(DateTime(timezone=True))
price_count = Column(Integer, nullable=False, server_default="0", index=True)

created = Column(DateTime(timezone=True), server_default=func.now())
sessions: Mapped[list["Session"]] = relationship(back_populates="user")

__tablename__ = "users"


class Session(Base):
id = Column(Integer, primary_key=True, index=True)

user_id = Column(String, ForeignKey("users.user_id"), index=True, nullable=False)
user: Mapped[User] = relationship("User")

token = Column(String, unique=True, index=True, nullable=False)
created = Column(DateTime(timezone=True), server_default=func.now())
last_used = Column(DateTime(timezone=True), nullable=True)

__tablename__ = "sessions"


class Product(Base):
id = Column(Integer, primary_key=True, index=True)

Expand Down
Loading
Loading