From a67c14166541345890b18d552362dada0d64ead3 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 26 Jun 2024 11:08:52 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Allow=20to=20filter=20invitations?= =?UTF-8?q?=20by=20status=20(#98)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow to filter invitations by admin * ♻️ Tweak types to use SQLModel's SelectOfScalar and a TypeVar * ♻️ Tweak types --------- Co-authored-by: Sebastián Ramírez --- backend/app/api/routes/invitations.py | 30 +++++---- .../app/tests/api/routes/test_invitations.py | 67 +++++++++++++++++++ 2 files changed, 85 insertions(+), 12 deletions(-) diff --git a/backend/app/api/routes/invitations.py b/backend/app/api/routes/invitations.py index 5b3bdb9f5f..a387047a6d 100644 --- a/backend/app/api/routes/invitations.py +++ b/backend/app/api/routes/invitations.py @@ -1,9 +1,10 @@ from datetime import timedelta -from typing import Any +from typing import Any, TypeVar from fastapi import APIRouter, Depends, HTTPException from fastapi.responses import HTMLResponse from sqlmodel import col, func, select +from sqlmodel.sql.expression import SelectOfScalar from app.api.deps import CurrentUser, SessionDep, get_first_superuser from app.api.utils.invitations import ( @@ -30,6 +31,8 @@ router = APIRouter() +T = TypeVar("T") + @router.get("/me", response_model=InvitationsPublic) def read_invitations_me( @@ -86,6 +89,7 @@ def read_invitations_team_by_admin( team_slug: str, skip: int = 0, limit: int = 100, + status: InvitationStatus | None = None, ) -> Any: """ Retrieve a list of invitations sent by the current user. @@ -103,18 +107,20 @@ def read_invitations_team_by_admin( status_code=400, detail="Not enough permissions to execute this action" ) - count_statement = ( - select(func.count()) - .select_from(Invitation) - .where(Invitation.team_id == Team.id, col(Team.slug) == team_slug) - ) + def _apply_filters(statement: SelectOfScalar[T]) -> SelectOfScalar[T]: + statement = statement.where( + Invitation.team_id == Team.id, Team.slug == team_slug + ) + + if status: + statement = statement.where(Invitation.status == status) + + return statement + + count_statement = _apply_filters(select(func.count()).select_from(Invitation)) count = session.exec(count_statement).one() - statement = ( - select(Invitation) - .where(Invitation.team_id == Team.id, col(Team.slug) == team_slug) - .offset(skip) - .limit(limit) - ) + + statement = _apply_filters(select(Invitation)).offset(skip).limit(limit) invitations = session.exec(statement).all() diff --git a/backend/app/tests/api/routes/test_invitations.py b/backend/app/tests/api/routes/test_invitations.py index efb77c717b..3fd9ae4521 100644 --- a/backend/app/tests/api/routes/test_invitations.py +++ b/backend/app/tests/api/routes/test_invitations.py @@ -326,6 +326,73 @@ def test_read_invitations_team_success(client: TestClient, db: Session) -> None: assert invitations[1]["id"] +def test_read_invitations_team_filter(client: TestClient, db: Session) -> None: + team = create_random_team(db) + + user1 = create_user( + session=db, + email="test2623filter@fastapi.com", + password="test12345", + full_name="test2623", + is_verified=True, + ) + user2 = create_user( + session=db, + email="test2722filter@example.com", + password="test12345", + full_name="test2722", + is_verified=True, + ) + user3 = create_user( + session=db, + email="test2833filter@example.com", + password="test12345", + full_name="test2833", + is_verified=True, + ) + + add_user_to_team(session=db, user=user1, team=team, role=Role.admin) + add_user_to_team(session=db, user=user2, team=team, role=Role.admin) + + invitation_to_create = InvitationCreate( + team_slug=team.slug, email=user2.email, role=Role.member + ) + create_invitation( + session=db, + invitation_in=invitation_to_create, + invited_by=user1, + invitation_status=InvitationStatus.accepted, + ) + + invitation_to_create = InvitationCreate( + team_slug=team.slug, email=user3.email, role=Role.member + ) + pending_invitation = create_invitation( + session=db, + invitation_in=invitation_to_create, + invited_by=user1, + invitation_status=InvitationStatus.pending, + ) + + user_auth_headers = user_authentication_headers( + client=client, email=user1.email, password="test12345" + ) + + response = client.get( + f"{settings.API_V1_STR}/invitations/team/{team.slug}", + params={"status": "pending"}, + headers=user_auth_headers, + ) + + assert response.status_code == 200 + data = response.json() + invitations = data["data"] + count = data["count"] + assert len(invitations) == 1 + assert count == 1 + assert invitations[0]["id"] == pending_invitation.id + + def test_read_invitations_team_empty(client: TestClient, db: Session) -> None: team = create_random_team(db) team2 = create_random_team(db)