From 2d5564ac4c686574dc5c15b108822273739dd101 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 27 Dec 2023 19:03:45 +0100 Subject: [PATCH 1/4] =?UTF-8?q?=E2=9C=A8=20Update=20models=20for=20login?= =?UTF-8?q?=20API?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/app/models.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/backend/app/app/models.py b/src/backend/app/app/models.py index 590c3e92aa..ca860c2021 100644 --- a/src/backend/app/app/models.py +++ b/src/backend/app/app/models.py @@ -73,16 +73,20 @@ class ItemOut(ItemBase): # Generic message -class Msg(BaseModel): - msg: str +class Message(BaseModel): + message: str # JSON payload containing access token class Token(BaseModel): access_token: str - token_type: str + token_type: str = "bearer" # Contents of JWT token class TokenPayload(BaseModel): sub: Union[int, None] = None + +class NewPassword(BaseModel): + token: str + new_password: str From ff5b9c3104781a30c5f58d349902d5d54c4765db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 27 Dec 2023 19:03:59 +0100 Subject: [PATCH 2/4] =?UTF-8?q?=E2=9C=A8=20Add=20authenticate=20simplified?= =?UTF-8?q?=20CRUD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/app/crud/__init__.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/backend/app/app/crud/__init__.py b/src/backend/app/app/crud/__init__.py index b2b325069c..d452138321 100644 --- a/src/backend/app/app/crud/__init__.py +++ b/src/backend/app/app/crud/__init__.py @@ -9,7 +9,7 @@ # item = CRUDBase[Item, ItemCreate, ItemUpdate](Item) from sqlmodel import Session, select -from app.core.security import get_password_hash +from app.core.security import get_password_hash, verify_password from app.models import UserCreate, User @@ -27,3 +27,12 @@ def get_user_by_email(*, session: Session, email: str) -> User | None: statement = select(User).where(User.email == email) session_user = session.exec(statement).first() return session_user + + +def authenticate(*, session: Session, email: str, password: str) -> User | None: + user = get_user_by_email(session=session, email=email) + if not user: + return None + if not verify_password(password, user.hashed_password): + return None + return user From 5394ba0f8a1810c865853ed80b321e81e9395cad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 27 Dec 2023 19:04:37 +0100 Subject: [PATCH 3/4] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Refactor=20get=5Fcurre?= =?UTF-8?q?nt=5Fuser=20dependency,=20integrate=20is=5Factive?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/backend/app/app/api/deps.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/backend/app/app/api/deps.py b/src/backend/app/app/api/deps.py index 552f0a2818..abe067b55f 100644 --- a/src/backend/app/app/api/deps.py +++ b/src/backend/app/app/api/deps.py @@ -39,18 +39,12 @@ def get_current_user(session: SessionDep, token: TokenDep) -> User: user = session.get(User, token_data.sub) if not user: raise HTTPException(status_code=404, detail="User not found") - return user - - -def get_current_active_user( - current_user: Annotated[User, Depends(get_current_user)] -) -> User: - if not current_user.is_active: + if not user.is_active: raise HTTPException(status_code=400, detail="Inactive user") - return current_user + return user -CurrentUser = Annotated[User, Depends(get_current_active_user)] +CurrentUser = Annotated[User, Depends(get_current_user)] def get_current_active_superuser(current_user: CurrentUser) -> User: From aff61c897c32b378ef6fbf6288f3a307713abc54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebasti=C3=A1n=20Ram=C3=ADrez?= Date: Wed, 27 Dec 2023 19:05:24 +0100 Subject: [PATCH 4/4] =?UTF-8?q?=E2=9C=A8=20Refactor=20and=20upgrade=20logi?= =?UTF-8?q?n=20API=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../app/app/api/api_v1/endpoints/login.py | 66 +++++++++---------- 1 file changed, 32 insertions(+), 34 deletions(-) diff --git a/src/backend/app/app/api/api_v1/endpoints/login.py b/src/backend/app/app/api/api_v1/endpoints/login.py index 4dc3a9b248..15f0f3f095 100644 --- a/src/backend/app/app/api/api_v1/endpoints/login.py +++ b/src/backend/app/app/api/api_v1/endpoints/login.py @@ -1,15 +1,15 @@ from datetime import timedelta -from typing import Any +from typing import Annotated, Any -from fastapi import APIRouter, Body, Depends, HTTPException +from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm -from sqlalchemy.orm import Session -from app import crud, models, schemas -from app.api import deps +from app import crud +from app.api.deps import CurrentUser, SessionDep from app.core import security from app.core.config import settings from app.core.security import get_password_hash +from app.models import Message, NewPassword, Token, UserOut from app.utils import ( generate_password_reset_token, send_reset_password_email, @@ -19,43 +19,42 @@ router = APIRouter() -@router.post("/login/access-token", response_model=schemas.Token) +@router.post("/login/access-token") def login_access_token( - db: Session = Depends(deps.get_db), form_data: OAuth2PasswordRequestForm = Depends() -) -> Any: + session: SessionDep, form_data: Annotated[OAuth2PasswordRequestForm, Depends()] +) -> Token: """ OAuth2 compatible token login, get an access token for future requests """ - user = crud.user.authenticate( - db, email=form_data.username, password=form_data.password + user = crud.authenticate( + session=session, email=form_data.username, password=form_data.password ) if not user: raise HTTPException(status_code=400, detail="Incorrect email or password") - elif not crud.user.is_active(user): + elif not user.is_active: raise HTTPException(status_code=400, detail="Inactive user") access_token_expires = timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES) - return { - "access_token": security.create_access_token( + return Token( + access_token=security.create_access_token( user.id, expires_delta=access_token_expires - ), - "token_type": "bearer", - } + ) + ) -@router.post("/login/test-token", response_model=schemas.User) -def test_token(current_user: models.User = Depends(deps.get_current_user)) -> Any: +@router.post("/login/test-token", response_model=UserOut) +def test_token(current_user: CurrentUser) -> Any: """ Test access token """ return current_user -@router.post("/password-recovery/{email}", response_model=schemas.Msg) -def recover_password(email: str, db: Session = Depends(deps.get_db)) -> Any: +@router.post("/password-recovery/{email}") +def recover_password(email: str, session: SessionDep) -> Message: """ Password Recovery """ - user = crud.user.get_by_email(db, email=email) + user = crud.get_user_by_email(session=session, email=email) if not user: raise HTTPException( @@ -66,31 +65,30 @@ def recover_password(email: str, db: Session = Depends(deps.get_db)) -> Any: send_reset_password_email( email_to=user.email, email=email, token=password_reset_token ) - return {"msg": "Password recovery email sent"} + return Message(message="Password recovery email sent") -@router.post("/reset-password/", response_model=schemas.Msg) +@router.post("/reset-password/") def reset_password( - token: str = Body(...), - new_password: str = Body(...), - db: Session = Depends(deps.get_db), -) -> Any: + session: SessionDep, + body: NewPassword, +) -> Message: """ Reset password """ - email = verify_password_reset_token(token) + email = verify_password_reset_token(token=body.token) if not email: raise HTTPException(status_code=400, detail="Invalid token") - user = crud.user.get_by_email(db, email=email) + user = crud.get_user_by_email(session=session, email=email) if not user: raise HTTPException( status_code=404, detail="The user with this username does not exist in the system.", ) - elif not crud.user.is_active(user): + elif not user.is_active: raise HTTPException(status_code=400, detail="Inactive user") - hashed_password = get_password_hash(new_password) + hashed_password = get_password_hash(password=body.new_password) user.hashed_password = hashed_password - db.add(user) - db.commit() - return {"msg": "Password updated successfully"} + session.add(user) + session.commit() + return Message(message="Password updated successfully")