Skip to content

Synchronous Session support. #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion fastapi_users_db_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi_users.models import ID, OAP, UP
from sqlalchemy import Boolean, ForeignKey, Integer, String, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from sqlalchemy.orm import Mapped, Session, declared_attr, mapped_column
from sqlalchemy.sql import Select

from fastapi_users_db_sqlalchemy.generics import GUID
Expand Down Expand Up @@ -185,3 +185,97 @@ async def update_oauth_account(
async def _get_user(self, statement: Select) -> Optional[UP]:
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()


class SQLAlchemySynchronousUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]):
"""
Database adapter for SQLAlchemy with synchronous session support.

:param session: SQLAlchemy session instance.
:param user_table: SQLAlchemy user model.
:param oauth_account_table: Optional SQLAlchemy OAuth accounts model.
"""

session: Session
user_table: Type[UP]
oauth_account_table: Optional[Type[SQLAlchemyBaseOAuthAccountTable]]

def __init__(
self,
session: Session,
user_table: Type[UP],
oauth_account_table: Optional[Type[SQLAlchemyBaseOAuthAccountTable]] = None,
):
self.session = session
self.user_table = user_table
self.oauth_account_table = oauth_account_table

async def get(self, id: ID) -> Optional[UP]:
statement = select(self.user_table).where(self.user_table.id == id)
return await self._get_user(statement)

async def get_by_email(self, email: str) -> Optional[UP]:
statement = select(self.user_table).where(
func.lower(self.user_table.email) == func.lower(email)
)
return await self._get_user(statement)

async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]:
if self.oauth_account_table is None:
raise NotImplementedError()

statement = (
select(self.user_table)
.join(self.oauth_account_table)
.where(self.oauth_account_table.oauth_name == oauth) # type: ignore
.where(self.oauth_account_table.account_id == account_id) # type: ignore
)
return await self._get_user(statement)

async def create(self, create_dict: Dict[str, Any]) -> UP:
user = self.user_table(**create_dict)
self.session.add(user)
self.session.commit()
return user

async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP:
for key, value in update_dict.items():
setattr(user, key, value)
self.session.add(user)
self.session.commit()
return user

async def delete(self, user: UP) -> None:
self.session.delete(user)
self.session.commit()

async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP:
if self.oauth_account_table is None:
raise NotImplementedError()

self.session.refresh(user)
oauth_account = self.oauth_account_table(**create_dict)
self.session.add(oauth_account)
user.oauth_accounts.append(oauth_account) # type: ignore
self.session.add(user)

self.session.commit()

return user

async def update_oauth_account(
self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any]
) -> UP:
if self.oauth_account_table is None:
raise NotImplementedError()

for key, value in update_dict.items():
setattr(oauth_account, key, value)
self.session.add(oauth_account)
self.session.commit()

return user

async def _get_user(self, statement: Select) -> Optional[UP]:
results = self.session.execute(statement)
return results.unique().scalar_one_or_none()
50 changes: 49 additions & 1 deletion fastapi_users_db_sqlalchemy/access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi_users.models import ID
from sqlalchemy import ForeignKey, String, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from sqlalchemy.orm import Mapped, Session, declared_attr, mapped_column

from fastapi_users_db_sqlalchemy.generics import GUID, TIMESTAMPAware, now_utc

Expand Down Expand Up @@ -85,3 +85,51 @@ async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
async def delete(self, access_token: AP) -> None:
await self.session.delete(access_token)
await self.session.commit()


class SQLAlchemySynchronousAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]):
"""
Access token database adapter for SQLAlchemy with synchronous session support.

:param session: SQLAlchemy session instance.
:param access_token_table: SQLAlchemy access token model.
"""

def __init__(
self,
session: Session,
access_token_table: Type[AP],
):
self.session = session
self.access_token_table = access_token_table

async def get_by_token(
self, token: str, max_age: Optional[datetime] = None
) -> Optional[AP]:
statement = select(self.access_token_table).where(
self.access_token_table.token == token # type: ignore
)
if max_age is not None:
statement = statement.where(
self.access_token_table.created_at >= max_age # type: ignore
)

results = self.session.execute(statement)
return results.scalar_one_or_none()

async def create(self, create_dict: Dict[str, Any]) -> AP:
access_token = self.access_token_table(**create_dict)
self.session.add(access_token)
self.session.commit()
return access_token

async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP:
for key, value in update_dict.items():
setattr(access_token, key, value)
self.session.add(access_token)
self.session.commit()
return access_token

async def delete(self, access_token: AP) -> None:
self.session.delete(access_token)
self.session.commit()
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
DATABASE_URL = os.getenv(
"DATABASE_URL", "sqlite+aiosqlite:///./test-sqlalchemy-user.db"
)
SYNC_DATABASE_URL = os.getenv(
"SYNC_DATABASE_URL", "sqlite:///./test-sqlalchemy-user.db"
)


class User(schemas.BaseUser):
Expand Down
98 changes: 61 additions & 37 deletions tests/test_access_token.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,22 @@

import pytest
from pydantic import UUID4
from sqlalchemy import exc
from sqlalchemy import Engine, create_engine, exc
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import DeclarativeBase, sessionmaker

from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import (
SQLAlchemyAccessTokenDatabase,
SQLAlchemyBaseAccessTokenTableUUID,
SQLAlchemySynchronousAccessTokenDatabase,
)
from tests.conftest import DATABASE_URL
from tests.conftest import DATABASE_URL, SYNC_DATABASE_URL


class Base(DeclarativeBase):
Expand All @@ -33,98 +34,121 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
pass


def create_session_maker(engine: Engine):
return sessionmaker(engine)


def create_async_session_maker(engine: AsyncEngine):
return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)


def pytest_generate_tests(metafunc):
if "access_token_db" in metafunc.fixturenames:
metafunc.parametrize("access_token_db", ["sync", "async"], indirect=True)


@pytest.fixture
def user_id() -> UUID4:
return uuid.uuid4()


@pytest.fixture
async def sqlalchemy_access_token_db(
async def access_token_db(
request,
user_id: UUID4,
) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase[AccessToken], None]:
engine = create_async_engine(DATABASE_URL)
sessionmaker = create_async_session_maker(engine)
if request.param == "async":
engine = create_async_engine(DATABASE_URL)
sessionmaker = create_async_session_maker(engine)

async with engine.begin() as connection:
await connection.run_sync(Base.metadata.create_all)

async with sessionmaker() as session:
user = User(
id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere"
)
session.add(user)
await session.commit()

yield SQLAlchemyAccessTokenDatabase(session, AccessToken)

async with engine.begin() as connection:
await connection.run_sync(Base.metadata.drop_all)
elif request.param == "sync":
engine = create_engine(
SYNC_DATABASE_URL, connect_args={"check_same_thread": False}
)
sessionmaker = create_session_maker(engine)

async with engine.begin() as connection:
await connection.run_sync(Base.metadata.create_all)
Base.metadata.create_all(bind=engine)

async with sessionmaker() as session:
user = User(
id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere"
)
session.add(user)
await session.commit()
with sessionmaker() as session:
user = User(
id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere"
)
session.add(user)
session.commit()

yield SQLAlchemyAccessTokenDatabase(session, AccessToken)
yield SQLAlchemySynchronousAccessTokenDatabase(session, AccessToken)

async with engine.begin() as connection:
await connection.run_sync(Base.metadata.drop_all)
Base.metadata.drop_all(bind=engine)
else:
raise ValueError("invalid internal test config")


@pytest.mark.asyncio
async def test_queries(
sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken],
access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken],
user_id: UUID4,
):
access_token_create = {"token": "TOKEN", "user_id": user_id}

# Create
access_token = await sqlalchemy_access_token_db.create(access_token_create)
access_token = await access_token_db.create(access_token_create)
assert access_token.token == "TOKEN"
assert access_token.user_id == user_id

# Update
update_dict = {"created_at": datetime.now(timezone.utc)}
updated_access_token = await sqlalchemy_access_token_db.update(
access_token, update_dict
)
updated_access_token = await access_token_db.update(access_token, update_dict)
assert updated_access_token.created_at.replace(microsecond=0) == update_dict[
"created_at"
].replace(microsecond=0)

# Get by token
access_token_by_token = await sqlalchemy_access_token_db.get_by_token(
access_token.token
)
access_token_by_token = await access_token_db.get_by_token(access_token.token)
assert access_token_by_token is not None

# Get by token expired
access_token_by_token = await sqlalchemy_access_token_db.get_by_token(
access_token_by_token = await access_token_db.get_by_token(
access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1)
)
assert access_token_by_token is None

# Get by token not expired
access_token_by_token = await sqlalchemy_access_token_db.get_by_token(
access_token_by_token = await access_token_db.get_by_token(
access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1)
)
assert access_token_by_token is not None

# Get by token unknown
access_token_by_token = await sqlalchemy_access_token_db.get_by_token(
"NOT_EXISTING_TOKEN"
)
access_token_by_token = await access_token_db.get_by_token("NOT_EXISTING_TOKEN")
assert access_token_by_token is None

# Delete token
await sqlalchemy_access_token_db.delete(access_token)
deleted_access_token = await sqlalchemy_access_token_db.get_by_token(
access_token.token
)
await access_token_db.delete(access_token)
deleted_access_token = await access_token_db.get_by_token(access_token.token)
assert deleted_access_token is None


@pytest.mark.asyncio
async def test_insert_existing_token(
sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken],
access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken],
user_id: UUID4,
):
access_token_create = {"token": "TOKEN", "user_id": user_id}
await sqlalchemy_access_token_db.create(access_token_create)
await access_token_db.create(access_token_create)

with pytest.raises(exc.IntegrityError):
await sqlalchemy_access_token_db.create(access_token_create)
await access_token_db.create(access_token_create)
Loading