66import pytest
77from asgi_lifespan import LifespanManager
88from httpx import ASGITransport , AsyncClient
9- from sqlalchemy .ext .asyncio import AsyncSession
9+ from sqlalchemy .ext .asyncio import AsyncConnection , AsyncSession
1010
1111from app import ioc
1212from app .application import build_app
1313
1414
15- @pytest .fixture
15+ @pytest .fixture ( scope = "session" )
1616async def app () -> typing .AsyncIterator [fastapi .FastAPI ]:
1717 app_ = build_app ()
1818 async with LifespanManager (app_ ):
1919 yield app_
2020
2121
22- @pytest .fixture
22+ @pytest .fixture ( scope = "session" )
2323async def client (app : fastapi .FastAPI ) -> typing .AsyncIterator [AsyncClient ]:
2424 async with AsyncClient (
2525 transport = ASGITransport (app = app ),
@@ -28,23 +28,32 @@ async def client(app: fastapi.FastAPI) -> typing.AsyncIterator[AsyncClient]:
2828 yield client
2929
3030
31- @pytest .fixture
31+ @pytest .fixture ( scope = "session" )
3232def di_container (app : fastapi .FastAPI ) -> modern_di .Container :
3333 return modern_di_fastapi .fetch_di_container (app )
3434
3535
36- @pytest .fixture (autouse = True )
37- async def db_session (di_container : modern_di .Container ) -> typing .AsyncIterator [AsyncSession ]:
36+ @pytest .fixture (scope = "session" )
37+ async def db_connection (di_container : modern_di .Container ) -> typing .AsyncIterator [AsyncConnection ]:
3838 engine = await ioc .Dependencies .database_engine .async_resolve (di_container )
39- connection = await engine .connect ()
40- transaction = await connection .begin ()
41- await connection .begin_nested ()
42- ioc .Dependencies .database_engine .override (connection , di_container )
43-
39+ connection : typing .Final = await engine .connect ()
4440 try :
45- yield AsyncSession ( connection , expire_on_commit = False , autoflush = False )
41+ yield connection
4642 finally :
47- if connection .in_transaction ():
48- await transaction .rollback ()
4943 await connection .close ()
5044 await engine .dispose ()
45+
46+
47+ @pytest .fixture (autouse = True )
48+ async def db_session (
49+ db_connection : AsyncConnection , di_container : modern_di .Container
50+ ) -> typing .AsyncIterator [AsyncSession ]:
51+ transaction = await db_connection .begin ()
52+ await db_connection .begin_nested ()
53+ ioc .Dependencies .database_engine .override (db_connection , di_container )
54+
55+ try :
56+ yield AsyncSession (db_connection , expire_on_commit = False , autoflush = False )
57+ finally :
58+ if db_connection .in_transaction ():
59+ await transaction .rollback ()
0 commit comments