Skip to content

Commit

Permalink
fix: always return proof.file_path for proof uploaded by the user (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphael0202 authored Jan 11, 2024
1 parent e3acac0 commit fe14346
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 9 deletions.
62 changes: 55 additions & 7 deletions app/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import time
import uuid
from pathlib import Path
Expand Down Expand Up @@ -79,6 +80,11 @@ def get_db():
# Authentication helpers
# ------------------------------------------------------------------------------
oauth2_scheme = OAuth2PasswordBearerOrAuthCookie(tokenUrl="auth")
# Version of oauth2_scheme that does not raise an error if the token is
# invalid or missing
oauth2_scheme_no_error = OAuth2PasswordBearerOrAuthCookie(
tokenUrl="auth", auto_error=False
)


def create_token(user_id: str):
Expand All @@ -87,11 +93,20 @@ def create_token(user_id: str):

def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)], db: Session = Depends(get_db)
):
) -> schemas.UserBase:
"""Get the current user if authenticated.
This function is used as a dependency in endpoints that require
authentication. It raises an HTTPException if the user is not
authenticated.
:param token: the authentication token
:param db: the database session
:raises HTTPException: if the user is not authenticated
:return: the current user
"""
if token and "__U" in token:
current_user: schemas.UserBase = crud.update_user_last_used_field(
db, token=token
)
current_user = crud.update_user_last_used_field(db, token=token)
if current_user:
return current_user
raise HTTPException(
Expand All @@ -101,6 +116,24 @@ def get_current_user(
)


def get_current_user_optional(
token: Annotated[str, Depends(oauth2_scheme_no_error)],
db: Session = Depends(get_db),
) -> schemas.UserBase | None:
"""Get the current user if authenticated, None otherwise.
This function is used as a dependency in endpoints that require
authentication, but where the user is optional.
:param token: the authentication token
:param db: the database session
:return: the current user if authenticated, None otherwise
"""
if token and "__U" in token:
return crud.update_user_last_used_field(db, token=token)
return None


# Routes
# ------------------------------------------------------------------------------

Expand Down Expand Up @@ -166,14 +199,26 @@ def authentication(
)


def price_transformer(prices: list[Price]) -> list[Price]:
def price_transformer(
prices: list[Price], current_user: schemas.UserBase | None = None
) -> list[Price]:
"""Transformer function used to remove the file_path of private proofs.
If current_user is None, the file_path is removed for all proofs that are
not public. Otherwise, the file_path is removed for all proofs that are not
public and do not belong to the current user.
:param prices: the list of prices to transform
:param current_user: the current user, if authenticated
:return: the transformed list of prices
"""
user_id = current_user.user_id if current_user else None
for price in prices:
if price.proof and price.proof.is_public is False:
if (
price.proof
and price.proof.is_public is False
and price.proof.owner != user_id
):
price.proof.file_path = None
return prices

Expand All @@ -182,9 +227,12 @@ def price_transformer(prices: list[Price]) -> list[Price]:
def get_price(
filters: schemas.PriceFilter = FilterDepends(schemas.PriceFilter),
db: Session = Depends(get_db),
current_user: schemas.UserBase | None = Depends(get_current_user_optional),
):
return paginate(
db, crud.get_prices_query(filters=filters), transformer=price_transformer
db,
crud.get_prices_query(filters=filters),
transformer=functools.partial(price_transformer, current_user=current_user),
)


Expand Down
4 changes: 2 additions & 2 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def create_user(db: Session, user: UserBase):
return db_user


def update_user_last_used_field(db: Session, token: str):
def update_user_last_used_field(db: Session, token: str) -> UserBase | None:
db_user = get_user_by_token(db, token=token)
if db_user:
db.query(User).filter(User.user_id == db_user.user_id).update(
Expand All @@ -57,7 +57,7 @@ def update_user_last_used_field(db: Session, token: str):
db.commit()
db.refresh(db_user)
return db_user
return False
return None


def delete_user(db: Session, user_id: UserBase):
Expand Down

0 comments on commit fe14346

Please sign in to comment.