Skip to content

Commit

Permalink
✨ Allow to filter invitations by status (#98)
Browse files Browse the repository at this point in the history
* Allow to filter invitations by admin

* ♻️ Tweak types to use SQLModel's SelectOfScalar and a TypeVar

* ♻️ Tweak types

---------

Co-authored-by: Sebastián Ramírez <tiangolo@gmail.com>
  • Loading branch information
patrick91 and tiangolo authored Jun 26, 2024
1 parent 5c8153e commit a67c141
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 12 deletions.
30 changes: 18 additions & 12 deletions backend/app/api/routes/invitations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import timedelta
from typing import Any
from typing import Any, TypeVar

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import HTMLResponse
from sqlmodel import col, func, select
from sqlmodel.sql.expression import SelectOfScalar

from app.api.deps import CurrentUser, SessionDep, get_first_superuser
from app.api.utils.invitations import (
Expand All @@ -30,6 +31,8 @@

router = APIRouter()

T = TypeVar("T")


@router.get("/me", response_model=InvitationsPublic)
def read_invitations_me(
Expand Down Expand Up @@ -86,6 +89,7 @@ def read_invitations_team_by_admin(
team_slug: str,
skip: int = 0,
limit: int = 100,
status: InvitationStatus | None = None,
) -> Any:
"""
Retrieve a list of invitations sent by the current user.
Expand All @@ -103,18 +107,20 @@ def read_invitations_team_by_admin(
status_code=400, detail="Not enough permissions to execute this action"
)

count_statement = (
select(func.count())
.select_from(Invitation)
.where(Invitation.team_id == Team.id, col(Team.slug) == team_slug)
)
def _apply_filters(statement: SelectOfScalar[T]) -> SelectOfScalar[T]:
statement = statement.where(
Invitation.team_id == Team.id, Team.slug == team_slug
)

if status:
statement = statement.where(Invitation.status == status)

return statement

count_statement = _apply_filters(select(func.count()).select_from(Invitation))
count = session.exec(count_statement).one()
statement = (
select(Invitation)
.where(Invitation.team_id == Team.id, col(Team.slug) == team_slug)
.offset(skip)
.limit(limit)
)

statement = _apply_filters(select(Invitation)).offset(skip).limit(limit)

invitations = session.exec(statement).all()

Expand Down
67 changes: 67 additions & 0 deletions backend/app/tests/api/routes/test_invitations.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,73 @@ def test_read_invitations_team_success(client: TestClient, db: Session) -> None:
assert invitations[1]["id"]


def test_read_invitations_team_filter(client: TestClient, db: Session) -> None:
team = create_random_team(db)

user1 = create_user(
session=db,
email="test2623filter@fastapi.com",
password="test12345",
full_name="test2623",
is_verified=True,
)
user2 = create_user(
session=db,
email="test2722filter@example.com",
password="test12345",
full_name="test2722",
is_verified=True,
)
user3 = create_user(
session=db,
email="test2833filter@example.com",
password="test12345",
full_name="test2833",
is_verified=True,
)

add_user_to_team(session=db, user=user1, team=team, role=Role.admin)
add_user_to_team(session=db, user=user2, team=team, role=Role.admin)

invitation_to_create = InvitationCreate(
team_slug=team.slug, email=user2.email, role=Role.member
)
create_invitation(
session=db,
invitation_in=invitation_to_create,
invited_by=user1,
invitation_status=InvitationStatus.accepted,
)

invitation_to_create = InvitationCreate(
team_slug=team.slug, email=user3.email, role=Role.member
)
pending_invitation = create_invitation(
session=db,
invitation_in=invitation_to_create,
invited_by=user1,
invitation_status=InvitationStatus.pending,
)

user_auth_headers = user_authentication_headers(
client=client, email=user1.email, password="test12345"
)

response = client.get(
f"{settings.API_V1_STR}/invitations/team/{team.slug}",
params={"status": "pending"},
headers=user_auth_headers,
)

assert response.status_code == 200
data = response.json()
invitations = data["data"]
count = data["count"]
assert len(invitations) == 1
assert count == 1
assert invitations[0]["id"] == pending_invitation.id


def test_read_invitations_team_empty(client: TestClient, db: Session) -> None:
team = create_random_team(db)
team2 = create_random_team(db)
Expand Down

0 comments on commit a67c141

Please sign in to comment.