From 3a05a69adfdd1a09aba384e53d69e3eb3ed6c09f Mon Sep 17 00:00:00 2001 From: Mike Fotinakis Date: Thu, 20 Jul 2023 13:07:20 -0400 Subject: [PATCH] Add synchronous Session support. --- fastapi_users_db_sqlalchemy/__init__.py | 96 +++++++++++- fastapi_users_db_sqlalchemy/access_token.py | 50 ++++++- tests/conftest.py | 3 + tests/test_access_token.py | 98 ++++++++----- tests/test_users.py | 155 +++++++++++++------- 5 files changed, 308 insertions(+), 94 deletions(-) diff --git a/fastapi_users_db_sqlalchemy/__init__.py b/fastapi_users_db_sqlalchemy/__init__.py index 4373702..b7db6ca 100644 --- a/fastapi_users_db_sqlalchemy/__init__.py +++ b/fastapi_users_db_sqlalchemy/__init__.py @@ -6,7 +6,7 @@ from fastapi_users.models import ID, OAP, UP from sqlalchemy import Boolean, ForeignKey, Integer, String, func, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, declared_attr, mapped_column +from sqlalchemy.orm import Mapped, Session, declared_attr, mapped_column from sqlalchemy.sql import Select from fastapi_users_db_sqlalchemy.generics import GUID @@ -185,3 +185,97 @@ async def update_oauth_account( async def _get_user(self, statement: Select) -> Optional[UP]: results = await self.session.execute(statement) return results.unique().scalar_one_or_none() + + +class SQLAlchemySynchronousUserDatabase(Generic[UP, ID], BaseUserDatabase[UP, ID]): + """ + Database adapter for SQLAlchemy with synchronous session support. + + :param session: SQLAlchemy session instance. + :param user_table: SQLAlchemy user model. + :param oauth_account_table: Optional SQLAlchemy OAuth accounts model. + """ + + session: Session + user_table: Type[UP] + oauth_account_table: Optional[Type[SQLAlchemyBaseOAuthAccountTable]] + + def __init__( + self, + session: Session, + user_table: Type[UP], + oauth_account_table: Optional[Type[SQLAlchemyBaseOAuthAccountTable]] = None, + ): + self.session = session + self.user_table = user_table + self.oauth_account_table = oauth_account_table + + async def get(self, id: ID) -> Optional[UP]: + statement = select(self.user_table).where(self.user_table.id == id) + return await self._get_user(statement) + + async def get_by_email(self, email: str) -> Optional[UP]: + statement = select(self.user_table).where( + func.lower(self.user_table.email) == func.lower(email) + ) + return await self._get_user(statement) + + async def get_by_oauth_account(self, oauth: str, account_id: str) -> Optional[UP]: + if self.oauth_account_table is None: + raise NotImplementedError() + + statement = ( + select(self.user_table) + .join(self.oauth_account_table) + .where(self.oauth_account_table.oauth_name == oauth) # type: ignore + .where(self.oauth_account_table.account_id == account_id) # type: ignore + ) + return await self._get_user(statement) + + async def create(self, create_dict: Dict[str, Any]) -> UP: + user = self.user_table(**create_dict) + self.session.add(user) + self.session.commit() + return user + + async def update(self, user: UP, update_dict: Dict[str, Any]) -> UP: + for key, value in update_dict.items(): + setattr(user, key, value) + self.session.add(user) + self.session.commit() + return user + + async def delete(self, user: UP) -> None: + self.session.delete(user) + self.session.commit() + + async def add_oauth_account(self, user: UP, create_dict: Dict[str, Any]) -> UP: + if self.oauth_account_table is None: + raise NotImplementedError() + + self.session.refresh(user) + oauth_account = self.oauth_account_table(**create_dict) + self.session.add(oauth_account) + user.oauth_accounts.append(oauth_account) # type: ignore + self.session.add(user) + + self.session.commit() + + return user + + async def update_oauth_account( + self, user: UP, oauth_account: OAP, update_dict: Dict[str, Any] + ) -> UP: + if self.oauth_account_table is None: + raise NotImplementedError() + + for key, value in update_dict.items(): + setattr(oauth_account, key, value) + self.session.add(oauth_account) + self.session.commit() + + return user + + async def _get_user(self, statement: Select) -> Optional[UP]: + results = self.session.execute(statement) + return results.unique().scalar_one_or_none() diff --git a/fastapi_users_db_sqlalchemy/access_token.py b/fastapi_users_db_sqlalchemy/access_token.py index 5878818..6956613 100644 --- a/fastapi_users_db_sqlalchemy/access_token.py +++ b/fastapi_users_db_sqlalchemy/access_token.py @@ -6,7 +6,7 @@ from fastapi_users.models import ID from sqlalchemy import ForeignKey, String, select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import Mapped, declared_attr, mapped_column +from sqlalchemy.orm import Mapped, Session, declared_attr, mapped_column from fastapi_users_db_sqlalchemy.generics import GUID, TIMESTAMPAware, now_utc @@ -85,3 +85,51 @@ async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: async def delete(self, access_token: AP) -> None: await self.session.delete(access_token) await self.session.commit() + + +class SQLAlchemySynchronousAccessTokenDatabase(Generic[AP], AccessTokenDatabase[AP]): + """ + Access token database adapter for SQLAlchemy with synchronous session support. + + :param session: SQLAlchemy session instance. + :param access_token_table: SQLAlchemy access token model. + """ + + def __init__( + self, + session: Session, + access_token_table: Type[AP], + ): + self.session = session + self.access_token_table = access_token_table + + async def get_by_token( + self, token: str, max_age: Optional[datetime] = None + ) -> Optional[AP]: + statement = select(self.access_token_table).where( + self.access_token_table.token == token # type: ignore + ) + if max_age is not None: + statement = statement.where( + self.access_token_table.created_at >= max_age # type: ignore + ) + + results = self.session.execute(statement) + return results.scalar_one_or_none() + + async def create(self, create_dict: Dict[str, Any]) -> AP: + access_token = self.access_token_table(**create_dict) + self.session.add(access_token) + self.session.commit() + return access_token + + async def update(self, access_token: AP, update_dict: Dict[str, Any]) -> AP: + for key, value in update_dict.items(): + setattr(access_token, key, value) + self.session.add(access_token) + self.session.commit() + return access_token + + async def delete(self, access_token: AP) -> None: + self.session.delete(access_token) + self.session.commit() diff --git a/tests/conftest.py b/tests/conftest.py index 17d3c79..afb6c98 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,6 +8,9 @@ DATABASE_URL = os.getenv( "DATABASE_URL", "sqlite+aiosqlite:///./test-sqlalchemy-user.db" ) +SYNC_DATABASE_URL = os.getenv( + "SYNC_DATABASE_URL", "sqlite:///./test-sqlalchemy-user.db" +) class User(schemas.BaseUser): diff --git a/tests/test_access_token.py b/tests/test_access_token.py index df149ff..c81e10a 100644 --- a/tests/test_access_token.py +++ b/tests/test_access_token.py @@ -4,21 +4,22 @@ import pytest from pydantic import UUID4 -from sqlalchemy import exc +from sqlalchemy import Engine, create_engine, exc from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine, ) -from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import DeclarativeBase, sessionmaker from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID from fastapi_users_db_sqlalchemy.access_token import ( SQLAlchemyAccessTokenDatabase, SQLAlchemyBaseAccessTokenTableUUID, + SQLAlchemySynchronousAccessTokenDatabase, ) -from tests.conftest import DATABASE_URL +from tests.conftest import DATABASE_URL, SYNC_DATABASE_URL class Base(DeclarativeBase): @@ -33,98 +34,121 @@ class User(SQLAlchemyBaseUserTableUUID, Base): pass +def create_session_maker(engine: Engine): + return sessionmaker(engine) + + def create_async_session_maker(engine: AsyncEngine): return async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) +def pytest_generate_tests(metafunc): + if "access_token_db" in metafunc.fixturenames: + metafunc.parametrize("access_token_db", ["sync", "async"], indirect=True) + + @pytest.fixture def user_id() -> UUID4: return uuid.uuid4() @pytest.fixture -async def sqlalchemy_access_token_db( +async def access_token_db( + request, user_id: UUID4, ) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase[AccessToken], None]: - engine = create_async_engine(DATABASE_URL) - sessionmaker = create_async_session_maker(engine) + if request.param == "async": + engine = create_async_engine(DATABASE_URL) + sessionmaker = create_async_session_maker(engine) + + async with engine.begin() as connection: + await connection.run_sync(Base.metadata.create_all) + + async with sessionmaker() as session: + user = User( + id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere" + ) + session.add(user) + await session.commit() + + yield SQLAlchemyAccessTokenDatabase(session, AccessToken) + + async with engine.begin() as connection: + await connection.run_sync(Base.metadata.drop_all) + elif request.param == "sync": + engine = create_engine( + SYNC_DATABASE_URL, connect_args={"check_same_thread": False} + ) + sessionmaker = create_session_maker(engine) - async with engine.begin() as connection: - await connection.run_sync(Base.metadata.create_all) + Base.metadata.create_all(bind=engine) - async with sessionmaker() as session: - user = User( - id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere" - ) - session.add(user) - await session.commit() + with sessionmaker() as session: + user = User( + id=user_id, email="lancelot@camelot.bt", hashed_password="guinevere" + ) + session.add(user) + session.commit() - yield SQLAlchemyAccessTokenDatabase(session, AccessToken) + yield SQLAlchemySynchronousAccessTokenDatabase(session, AccessToken) - async with engine.begin() as connection: - await connection.run_sync(Base.metadata.drop_all) + Base.metadata.drop_all(bind=engine) + else: + raise ValueError("invalid internal test config") @pytest.mark.asyncio async def test_queries( - sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], + access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], user_id: UUID4, ): access_token_create = {"token": "TOKEN", "user_id": user_id} # Create - access_token = await sqlalchemy_access_token_db.create(access_token_create) + access_token = await access_token_db.create(access_token_create) assert access_token.token == "TOKEN" assert access_token.user_id == user_id # Update update_dict = {"created_at": datetime.now(timezone.utc)} - updated_access_token = await sqlalchemy_access_token_db.update( - access_token, update_dict - ) + updated_access_token = await access_token_db.update(access_token, update_dict) assert updated_access_token.created_at.replace(microsecond=0) == update_dict[ "created_at" ].replace(microsecond=0) # Get by token - access_token_by_token = await sqlalchemy_access_token_db.get_by_token( - access_token.token - ) + access_token_by_token = await access_token_db.get_by_token(access_token.token) assert access_token_by_token is not None # Get by token expired - access_token_by_token = await sqlalchemy_access_token_db.get_by_token( + access_token_by_token = await access_token_db.get_by_token( access_token.token, max_age=datetime.now(timezone.utc) + timedelta(hours=1) ) assert access_token_by_token is None # Get by token not expired - access_token_by_token = await sqlalchemy_access_token_db.get_by_token( + access_token_by_token = await access_token_db.get_by_token( access_token.token, max_age=datetime.now(timezone.utc) - timedelta(hours=1) ) assert access_token_by_token is not None # Get by token unknown - access_token_by_token = await sqlalchemy_access_token_db.get_by_token( - "NOT_EXISTING_TOKEN" - ) + access_token_by_token = await access_token_db.get_by_token("NOT_EXISTING_TOKEN") assert access_token_by_token is None # Delete token - await sqlalchemy_access_token_db.delete(access_token) - deleted_access_token = await sqlalchemy_access_token_db.get_by_token( - access_token.token - ) + await access_token_db.delete(access_token) + deleted_access_token = await access_token_db.get_by_token(access_token.token) assert deleted_access_token is None @pytest.mark.asyncio async def test_insert_existing_token( - sqlalchemy_access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], + access_token_db: SQLAlchemyAccessTokenDatabase[AccessToken], user_id: UUID4, ): access_token_create = {"token": "TOKEN", "user_id": user_id} - await sqlalchemy_access_token_db.create(access_token_create) + await access_token_db.create(access_token_create) with pytest.raises(exc.IntegrityError): - await sqlalchemy_access_token_db.create(access_token_create) + await access_token_db.create(access_token_create) diff --git a/tests/test_users.py b/tests/test_users.py index 141a93f..8d5ba64 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -1,26 +1,28 @@ -from typing import Any, AsyncGenerator, Dict, List +from typing import Any, AsyncGenerator, Dict, Generator, List import pytest -from sqlalchemy import String, exc -from sqlalchemy.ext.asyncio import ( - AsyncEngine, - async_sessionmaker, - create_async_engine, -) +from sqlalchemy import Engine, String, create_engine, exc +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlalchemy.orm import ( DeclarativeBase, Mapped, mapped_column, relationship, + sessionmaker, ) from fastapi_users_db_sqlalchemy import ( UUID_ID, SQLAlchemyBaseOAuthAccountTableUUID, SQLAlchemyBaseUserTableUUID, + SQLAlchemySynchronousUserDatabase, SQLAlchemyUserDatabase, ) -from tests.conftest import DATABASE_URL +from tests.conftest import DATABASE_URL, SYNC_DATABASE_URL + + +def create_session_maker(engine: Engine): + return sessionmaker(engine) def create_async_session_maker(engine: AsyncEngine): @@ -50,106 +52,151 @@ class UserOAuth(SQLAlchemyBaseUserTableUUID, OAuthBase): ) +def pytest_generate_tests(metafunc): + if "user_db" in metafunc.fixturenames: + metafunc.parametrize("user_db", ["sync", "async"], indirect=True) + if "user_db_oauth" in metafunc.fixturenames: + metafunc.parametrize("user_db_oauth", ["sync", "async"], indirect=True) + + @pytest.fixture -async def sqlalchemy_user_db() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: - engine = create_async_engine(DATABASE_URL) - sessionmaker = create_async_session_maker(engine) +async def user_db( + request, +) -> AsyncGenerator[SQLAlchemyUserDatabase, None] | Generator[ + SQLAlchemySynchronousUserDatabase, None, None +]: + if request.param == "async": + engine = create_async_engine(DATABASE_URL) + sessionmaker = create_async_session_maker(engine) + + async with engine.begin() as connection: + await connection.run_sync(Base.metadata.create_all) - async with engine.begin() as connection: - await connection.run_sync(Base.metadata.create_all) + async with sessionmaker() as session: + yield SQLAlchemyUserDatabase(session, User) - async with sessionmaker() as session: - yield SQLAlchemyUserDatabase(session, User) + async with engine.begin() as connection: + await connection.run_sync(Base.metadata.drop_all) + elif request.param == "sync": + engine = create_engine( + SYNC_DATABASE_URL, connect_args={"check_same_thread": False} + ) + sessionmaker = create_session_maker(engine) - async with engine.begin() as connection: - await connection.run_sync(Base.metadata.drop_all) + Base.metadata.create_all(bind=engine) + + with sessionmaker() as session: + yield SQLAlchemySynchronousUserDatabase(session, User) + + Base.metadata.drop_all(bind=engine) + else: + raise ValueError("invalid internal test config") @pytest.fixture -async def sqlalchemy_user_db_oauth() -> AsyncGenerator[SQLAlchemyUserDatabase, None]: - engine = create_async_engine(DATABASE_URL) - sessionmaker = create_async_session_maker(engine) +async def user_db_oauth( + request, +) -> AsyncGenerator[SQLAlchemyUserDatabase, None] | Generator[ + SQLAlchemySynchronousUserDatabase, None, None +]: + if request.param == "async": + engine = create_async_engine(DATABASE_URL) + sessionmaker = create_async_session_maker(engine) + + async with engine.begin() as connection: + await connection.run_sync(OAuthBase.metadata.create_all) - async with engine.begin() as connection: - await connection.run_sync(OAuthBase.metadata.create_all) + async with sessionmaker() as session: + yield SQLAlchemyUserDatabase(session, UserOAuth, OAuthAccount) - async with sessionmaker() as session: - yield SQLAlchemyUserDatabase(session, UserOAuth, OAuthAccount) + async with engine.begin() as connection: + await connection.run_sync(OAuthBase.metadata.drop_all) + elif request.param == "sync": + engine = create_engine( + SYNC_DATABASE_URL, connect_args={"check_same_thread": False} + ) + sessionmaker = create_session_maker(engine) - async with engine.begin() as connection: - await connection.run_sync(OAuthBase.metadata.drop_all) + OAuthBase.metadata.create_all(bind=engine) + + with sessionmaker() as session: + yield SQLAlchemySynchronousUserDatabase(session, UserOAuth, OAuthAccount) + + OAuthBase.metadata.drop_all(bind=engine) + else: + raise ValueError("invalid internal test config") @pytest.mark.asyncio -async def test_queries(sqlalchemy_user_db: SQLAlchemyUserDatabase[User, UUID_ID]): +async def test_queries(user_db: SQLAlchemyUserDatabase[User, UUID_ID]): user_create = { "email": "lancelot@camelot.bt", "hashed_password": "guinevere", } # Create - user = await sqlalchemy_user_db.create(user_create) + user = await user_db.create(user_create) assert user.id is not None assert user.is_active is True assert user.is_superuser is False assert user.email == user_create["email"] # Update - updated_user = await sqlalchemy_user_db.update(user, {"is_superuser": True}) + updated_user = await user_db.update(user, {"is_superuser": True}) assert updated_user.is_superuser is True # Get by id - id_user = await sqlalchemy_user_db.get(user.id) + id_user = await user_db.get(user.id) assert id_user is not None assert id_user.id == user.id assert id_user.is_superuser is True # Get by email - email_user = await sqlalchemy_user_db.get_by_email(str(user_create["email"])) + email_user = await user_db.get_by_email(str(user_create["email"])) assert email_user is not None assert email_user.id == user.id # Get by uppercased email - email_user = await sqlalchemy_user_db.get_by_email("Lancelot@camelot.bt") + email_user = await user_db.get_by_email("Lancelot@camelot.bt") assert email_user is not None assert email_user.id == user.id # Unknown user - unknown_user = await sqlalchemy_user_db.get_by_email("galahad@camelot.bt") + unknown_user = await user_db.get_by_email("galahad@camelot.bt") assert unknown_user is None # Delete user - await sqlalchemy_user_db.delete(user) - deleted_user = await sqlalchemy_user_db.get(user.id) + await user_db.delete(user) + deleted_user = await user_db.get(user.id) assert deleted_user is None # OAuth without defined table with pytest.raises(NotImplementedError): - await sqlalchemy_user_db.get_by_oauth_account("foo", "bar") + await user_db.get_by_oauth_account("foo", "bar") with pytest.raises(NotImplementedError): - await sqlalchemy_user_db.add_oauth_account(user, {}) + await user_db.add_oauth_account(user, {}) with pytest.raises(NotImplementedError): oauth_account = OAuthAccount() - await sqlalchemy_user_db.update_oauth_account(user, oauth_account, {}) + await user_db.update_oauth_account(user, oauth_account, {}) @pytest.mark.asyncio async def test_insert_existing_email( - sqlalchemy_user_db: SQLAlchemyUserDatabase[User, UUID_ID], + user_db: SQLAlchemyUserDatabase[User, UUID_ID], ): user_create = { "email": "lancelot@camelot.bt", "hashed_password": "guinevere", } - await sqlalchemy_user_db.create(user_create) + await user_db.create(user_create) with pytest.raises(exc.IntegrityError): - await sqlalchemy_user_db.create(user_create) + await user_db.create(user_create) @pytest.mark.asyncio async def test_queries_custom_fields( - sqlalchemy_user_db: SQLAlchemyUserDatabase[User, UUID_ID], + user_db: SQLAlchemyUserDatabase[User, UUID_ID], ): """It should output custom fields in query result.""" user_create = { @@ -157,9 +204,9 @@ async def test_queries_custom_fields( "hashed_password": "guinevere", "first_name": "Lancelot", } - user = await sqlalchemy_user_db.create(user_create) + user = await user_db.create(user_create) - id_user = await sqlalchemy_user_db.get(user.id) + id_user = await user_db.get(user.id) assert id_user is not None assert id_user.id == user.id assert id_user.first_name == user.first_name @@ -167,7 +214,7 @@ async def test_queries_custom_fields( @pytest.mark.asyncio async def test_queries_oauth( - sqlalchemy_user_db_oauth: SQLAlchemyUserDatabase[UserOAuth, UUID_ID], + user_db_oauth: SQLAlchemyUserDatabase[UserOAuth, UUID_ID], oauth_account1: Dict[str, Any], oauth_account2: Dict[str, Any], ): @@ -177,43 +224,41 @@ async def test_queries_oauth( } # Create - user = await sqlalchemy_user_db_oauth.create(user_create) + user = await user_db_oauth.create(user_create) assert user.id is not None # Add OAuth account - user = await sqlalchemy_user_db_oauth.add_oauth_account(user, oauth_account1) - user = await sqlalchemy_user_db_oauth.add_oauth_account(user, oauth_account2) + user = await user_db_oauth.add_oauth_account(user, oauth_account1) + user = await user_db_oauth.add_oauth_account(user, oauth_account2) assert len(user.oauth_accounts) == 2 assert user.oauth_accounts[1].account_id == oauth_account2["account_id"] assert user.oauth_accounts[0].account_id == oauth_account1["account_id"] # Update - user = await sqlalchemy_user_db_oauth.update_oauth_account( + user = await user_db_oauth.update_oauth_account( user, user.oauth_accounts[0], {"access_token": "NEW_TOKEN"} ) assert user.oauth_accounts[0].access_token == "NEW_TOKEN" # Get by id - id_user = await sqlalchemy_user_db_oauth.get(user.id) + id_user = await user_db_oauth.get(user.id) assert id_user is not None assert id_user.id == user.id assert id_user.oauth_accounts[0].access_token == "NEW_TOKEN" # Get by email - email_user = await sqlalchemy_user_db_oauth.get_by_email(user_create["email"]) + email_user = await user_db_oauth.get_by_email(user_create["email"]) assert email_user is not None assert email_user.id == user.id assert len(email_user.oauth_accounts) == 2 # Get by OAuth account - oauth_user = await sqlalchemy_user_db_oauth.get_by_oauth_account( + oauth_user = await user_db_oauth.get_by_oauth_account( oauth_account1["oauth_name"], oauth_account1["account_id"] ) assert oauth_user is not None assert oauth_user.id == user.id # Unknown OAuth account - unknown_oauth_user = await sqlalchemy_user_db_oauth.get_by_oauth_account( - "foo", "bar" - ) + unknown_oauth_user = await user_db_oauth.get_by_oauth_account("foo", "bar") assert unknown_oauth_user is None