66import  pytest 
77from  asgi_lifespan  import  LifespanManager 
88from  httpx  import  ASGITransport , AsyncClient 
9- from  sqlalchemy .ext .asyncio  import  AsyncSession 
9+ from  sqlalchemy .ext .asyncio  import  AsyncConnection ,  AsyncSession 
1010
1111from  app  import  ioc 
1212from  app .application  import  build_app 
@@ -33,18 +33,27 @@ def di_container(app: fastapi.FastAPI) -> modern_di.Container:
3333    return  modern_di_fastapi .fetch_di_container (app )
3434
3535
36- @pytest .fixture (autouse = True ) 
37- async  def  db_session (di_container : modern_di .Container ) ->  typing .AsyncIterator [AsyncSession ]:
36+ @pytest .fixture (scope = "session" ) 
37+ async  def  db_connection (di_container : modern_di .Container ) ->  typing .AsyncIterator [AsyncConnection ]:
3838    engine  =  await  ioc .Dependencies .database_engine .async_resolve (di_container )
39-     connection  =  await  engine .connect ()
40-     transaction  =  await  connection .begin ()
41-     await  connection .begin_nested ()
42-     ioc .Dependencies .database_engine .override (connection , di_container )
43- 
39+     connection : typing .Final  =  await  engine .connect ()
4440    try :
45-         yield  AsyncSession ( connection ,  expire_on_commit = False ,  autoflush = False ) 
41+         yield  connection 
4642    finally :
47-         if  connection .in_transaction ():
48-             await  transaction .rollback ()
4943        await  connection .close ()
5044        await  engine .dispose ()
45+ 
46+ 
47+ @pytest .fixture (autouse = True ) 
48+ async  def  db_session (
49+     db_connection : AsyncConnection , di_container : modern_di .Container 
50+ ) ->  typing .AsyncIterator [AsyncSession ]:
51+     transaction  =  await  db_connection .begin ()
52+     await  db_connection .begin_nested ()
53+     ioc .Dependencies .database_engine .override (db_connection , di_container )
54+ 
55+     try :
56+         yield  AsyncSession (db_connection , expire_on_commit = False , autoflush = False )
57+     finally :
58+         if  db_connection .in_transaction ():
59+             await  transaction .rollback ()
0 commit comments