Skip to content

Commit 7495941

Browse files
GitMarco27DeanChensj
authored andcommitted
feat: full async implementation of DatabaseSessionService
Merge #2889 # Implement Full async DatabaseSessionService **Target Issue:** #1005 ## Overview This PR introduces an asynchronous implementation of the `DatabaseSessionService` with minimal breaking changes. The primary goal is to enable effective use of ADK in fully async environments and API endpoints while avoiding event loop blocking during database I/O operations. ## Changes - Converted `DatabaseSessionService` to use async/await patterns throughout ## Testing Plan The implementation has been tested following the project's contribution guidelines: ### Unit Tests - All existing unit tests pass successfully - Minor update to test requirements added to support `aiosqlite` ### Manual End-to-End Testing - E2E tests performed using: - **LLM Provider:** LiteLLM - **Database:** PostgreSQL with `asyncpg` driver ```python from google.adk.sessions.database_session_service import DatabaseSessionService connection_string: str = ( "postgresql+asyncpg://PG_USER:PG_PSWD@PG_HOST:5432/PG_DB" ) session_service: DatabaseSessionService = DatabaseSessionService( db_url=connection_string ) session = await session_service.create_session( app_name="test_app", session_id="test_session", user_id="test_user" ) assert session is not None sessions = await session_service.list_sessions(app_name="test_app", user_id="test_user") assert len(sessions.sessions) > 0 session = await session_service.get_session( app_name="test_app", session_id="test_session", user_id="test_user" ) assert session is not None await session_service.delete_session( app_name="test_app", session_id="test_session", user_id="test_user" ) assert ( await session_service.get_session( app_name="test_app", session_id="test_session", user_id="test_user" ) is None ) ``` The implementation have been also tested using the following configurations for llm provider and Runner: ```python def get_azure_openai_model(deployment_id: str | None = None) -> LiteLlm: ... if not deployment_id: deployment_id = os.getenv("AZURE_OPENAI_DEPLOYMENT_ID") logger.info(f"Using Azure OpenAI deployment ID: {deployment_id}") return LiteLlm( model=f"azure/{os.getenv('AZURE_OPENAI_DEPLOYMENT_ID')}", stream=True, ) ... @staticmethod def _get_runner(agent: Agent) -> Runner: storage=DatabaseSessionService(db_url=get_pg_connection_string()) return Runner( agent=agent, app_name=APP_NAME, session_service=storage, ) ... async for event in self.runner.run_async( user_id=user_id, session_id=session_id, new_message=content, run_config=( RunConfig( streaming_mode=StreamingMode.SSE, response_modalities=["TEXT"] ) if stream else RunConfig() ), ): last_event = event if stream: yield event ... ``` ## Breaking Changes - Database connection string format may need updates for async drivers Co-authored-by: Shangjie Chen <deanchen@google.com> COPYBARA_INTEGRATE_REVIEW=#2889 from GitMarco27:feature/async_database_session_service e1b1b14 PiperOrigin-RevId: 830525148
1 parent 2443a1b commit 7495941

File tree

2 files changed

+94
-66
lines changed

2 files changed

+94
-66
lines changed

src/google/adk/sessions/database_session_service.py

Lines changed: 93 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import asyncio
1617
import copy
1718
from datetime import datetime
1819
from datetime import timezone
@@ -30,20 +31,21 @@
3031
from sqlalchemy import event
3132
from sqlalchemy import ForeignKeyConstraint
3233
from sqlalchemy import func
34+
from sqlalchemy import select
3335
from sqlalchemy import Text
3436
from sqlalchemy.dialects import mysql
3537
from sqlalchemy.dialects import postgresql
36-
from sqlalchemy.engine import create_engine
37-
from sqlalchemy.engine import Engine
3838
from sqlalchemy.exc import ArgumentError
39+
from sqlalchemy.ext.asyncio import async_sessionmaker
40+
from sqlalchemy.ext.asyncio import AsyncEngine
41+
from sqlalchemy.ext.asyncio import AsyncSession as DatabaseSessionFactory
42+
from sqlalchemy.ext.asyncio import create_async_engine
3943
from sqlalchemy.ext.mutable import MutableDict
4044
from sqlalchemy.inspection import inspect
4145
from sqlalchemy.orm import DeclarativeBase
4246
from sqlalchemy.orm import Mapped
4347
from sqlalchemy.orm import mapped_column
4448
from sqlalchemy.orm import relationship
45-
from sqlalchemy.orm import Session as DatabaseSessionFactory
46-
from sqlalchemy.orm import sessionmaker
4749
from sqlalchemy.schema import MetaData
4850
from sqlalchemy.types import DateTime
4951
from sqlalchemy.types import PickleType
@@ -417,11 +419,11 @@ def __init__(self, db_url: str, **kwargs: Any):
417419
# 2. Create all tables based on schema
418420
# 3. Initialize all properties
419421
try:
420-
db_engine = create_engine(db_url, **kwargs)
422+
db_engine = create_async_engine(db_url, **kwargs)
421423

422424
if db_engine.dialect.name == "sqlite":
423425
# Set sqlite pragma to enable foreign keys constraints
424-
event.listen(db_engine, "connect", set_sqlite_pragma)
426+
event.listen(db_engine.sync_engine, "connect", set_sqlite_pragma)
425427

426428
except Exception as e:
427429
if isinstance(e, ArgumentError):
@@ -440,18 +442,32 @@ def __init__(self, db_url: str, **kwargs: Any):
440442
local_timezone = get_localzone()
441443
logger.info("Local timezone: %s", local_timezone)
442444

443-
self.db_engine: Engine = db_engine
445+
self.db_engine: AsyncEngine = db_engine
444446
self.metadata: MetaData = MetaData()
445-
self.inspector = inspect(self.db_engine)
446447

447448
# DB session factory method
448-
self.database_session_factory: sessionmaker[DatabaseSessionFactory] = (
449-
sessionmaker(bind=self.db_engine)
450-
)
451-
452-
# Uncomment to recreate DB every time
453-
# Base.metadata.drop_all(self.db_engine)
454-
Base.metadata.create_all(self.db_engine)
449+
self.database_session_factory: async_sessionmaker[
450+
DatabaseSessionFactory
451+
] = async_sessionmaker(bind=self.db_engine, expire_on_commit=False)
452+
453+
# Flag to indicate if tables are created
454+
self._tables_created = False
455+
# Lock to ensure thread-safe table creation
456+
self._table_creation_lock = asyncio.Lock()
457+
458+
async def _ensure_tables_created(self):
459+
"""Ensure database tables are created. This is called lazily."""
460+
if self._tables_created:
461+
return
462+
463+
async with self._table_creation_lock:
464+
# Double-check after acquiring the lock
465+
if not self._tables_created:
466+
async with self.db_engine.begin() as conn:
467+
# Uncomment to recreate DB every time
468+
# await conn.run_sync(Base.metadata.drop_all)
469+
await conn.run_sync(Base.metadata.create_all)
470+
self._tables_created = True
455471

456472
@override
457473
async def create_session(
@@ -467,22 +483,25 @@ async def create_session(
467483
# 3. Add the object to the table
468484
# 4. Build the session object with generated id
469485
# 5. Return the session
486+
await self._ensure_tables_created()
487+
async with self.database_session_factory() as sql_session:
470488

471-
with self.database_session_factory() as sql_session:
472-
if session_id and sql_session.get(
489+
if session_id and await sql_session.get(
473490
StorageSession, (app_name, user_id, session_id)
474491
):
475492
raise AlreadyExistsError(
476493
f"Session with id {session_id} already exists."
477494
)
478495
# Fetch app and user states from storage
479-
storage_app_state = sql_session.get(StorageAppState, (app_name))
496+
storage_app_state = await sql_session.get(StorageAppState, (app_name))
497+
storage_user_state = await sql_session.get(
498+
StorageUserState, (app_name, user_id)
499+
)
500+
501+
# Create state tables if not exist
480502
if not storage_app_state:
481503
storage_app_state = StorageAppState(app_name=app_name, state={})
482504
sql_session.add(storage_app_state)
483-
storage_user_state = sql_session.get(
484-
StorageUserState, (app_name, user_id)
485-
)
486505
if not storage_user_state:
487506
storage_user_state = StorageUserState(
488507
app_name=app_name, user_id=user_id, state={}
@@ -509,9 +528,9 @@ async def create_session(
509528
state=session_state,
510529
)
511530
sql_session.add(storage_session)
512-
sql_session.commit()
531+
await sql_session.commit()
513532

514-
sql_session.refresh(storage_session)
533+
await sql_session.refresh(storage_session)
515534

516535
# Merge states for response
517536
merged_state = _merge_state(
@@ -529,39 +548,39 @@ async def get_session(
529548
session_id: str,
530549
config: Optional[GetSessionConfig] = None,
531550
) -> Optional[Session]:
551+
await self._ensure_tables_created()
532552
# 1. Get the storage session entry from session table
533553
# 2. Get all the events based on session id and filtering config
534554
# 3. Convert and return the session
535-
with self.database_session_factory() as sql_session:
536-
storage_session = sql_session.get(
555+
async with self.database_session_factory() as sql_session:
556+
storage_session = await sql_session.get(
537557
StorageSession, (app_name, user_id, session_id)
538558
)
539559
if storage_session is None:
540560
return None
541561

542-
query = sql_session.query(StorageEvent).filter(
543-
StorageEvent.app_name == app_name,
544-
StorageEvent.user_id == user_id,
545-
StorageEvent.session_id == storage_session.id,
562+
stmt = (
563+
select(StorageEvent)
564+
.filter(StorageEvent.app_name == app_name)
565+
.filter(StorageEvent.session_id == storage_session.id)
566+
.filter(StorageEvent.user_id == user_id)
546567
)
547568

548569
if config and config.after_timestamp:
549570
after_dt = datetime.fromtimestamp(config.after_timestamp)
550-
query = query.filter(StorageEvent.timestamp >= after_dt)
551-
552-
storage_events = (
553-
query.order_by(StorageEvent.timestamp.desc())
554-
.limit(
555-
config.num_recent_events
556-
if config and config.num_recent_events
557-
else None
558-
)
559-
.all()
560-
)
571+
stmt = stmt.filter(StorageEvent.timestamp >= after_dt)
572+
573+
stmt = stmt.order_by(StorageEvent.timestamp.desc())
574+
575+
if config and config.num_recent_events:
576+
stmt = stmt.limit(config.num_recent_events)
577+
578+
result = await sql_session.execute(stmt)
579+
storage_events = result.scalars().all()
561580

562581
# Fetch states from storage
563-
storage_app_state = sql_session.get(StorageAppState, (app_name))
564-
storage_user_state = sql_session.get(
582+
storage_app_state = await sql_session.get(StorageAppState, (app_name))
583+
storage_user_state = await sql_session.get(
565584
StorageUserState, (app_name, user_id)
566585
)
567586

@@ -581,32 +600,33 @@ async def get_session(
581600
async def list_sessions(
582601
self, *, app_name: str, user_id: Optional[str] = None
583602
) -> ListSessionsResponse:
584-
with self.database_session_factory() as sql_session:
585-
query = sql_session.query(StorageSession).filter(
586-
StorageSession.app_name == app_name
587-
)
603+
await self._ensure_tables_created()
604+
async with self.database_session_factory() as sql_session:
605+
stmt = select(StorageSession).filter(StorageSession.app_name == app_name)
588606
if user_id is not None:
589-
query = query.filter(StorageSession.user_id == user_id)
590-
results = query.all()
607+
stmt = stmt.filter(StorageSession.user_id == user_id)
608+
609+
result = await sql_session.execute(stmt)
610+
results = result.scalars().all()
591611

592612
# Fetch app state from storage
593-
storage_app_state = sql_session.get(StorageAppState, (app_name))
613+
storage_app_state = await sql_session.get(StorageAppState, (app_name))
594614
app_state = storage_app_state.state if storage_app_state else {}
595615

596616
# Fetch user state(s) from storage
597617
user_states_map = {}
598618
if user_id is not None:
599-
storage_user_state = sql_session.get(
619+
storage_user_state = await sql_session.get(
600620
StorageUserState, (app_name, user_id)
601621
)
602622
if storage_user_state:
603623
user_states_map[user_id] = storage_user_state.state
604624
else:
605-
all_user_states_for_app = (
606-
sql_session.query(StorageUserState)
607-
.filter(StorageUserState.app_name == app_name)
608-
.all()
625+
user_state_stmt = select(StorageUserState).filter(
626+
StorageUserState.app_name == app_name
609627
)
628+
user_state_result = await sql_session.execute(user_state_stmt)
629+
all_user_states_for_app = user_state_result.scalars().all()
610630
for storage_user_state in all_user_states_for_app:
611631
user_states_map[storage_user_state.user_id] = storage_user_state.state
612632

@@ -622,17 +642,19 @@ async def list_sessions(
622642
async def delete_session(
623643
self, app_name: str, user_id: str, session_id: str
624644
) -> None:
625-
with self.database_session_factory() as sql_session:
645+
await self._ensure_tables_created()
646+
async with self.database_session_factory() as sql_session:
626647
stmt = delete(StorageSession).where(
627648
StorageSession.app_name == app_name,
628649
StorageSession.user_id == user_id,
629650
StorageSession.id == session_id,
630651
)
631-
sql_session.execute(stmt)
632-
sql_session.commit()
652+
await sql_session.execute(stmt)
653+
await sql_session.commit()
633654

634655
@override
635656
async def append_event(self, session: Session, event: Event) -> Event:
657+
await self._ensure_tables_created()
636658
if event.partial:
637659
return event
638660

@@ -642,8 +664,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
642664
# 1. Check if timestamp is stale
643665
# 2. Update session attributes based on event config
644666
# 3. Store event to table
645-
with self.database_session_factory() as sql_session:
646-
storage_session = sql_session.get(
667+
async with self.database_session_factory() as sql_session:
668+
storage_session = await sql_session.get(
647669
StorageSession, (session.app_name, session.user_id, session.id)
648670
)
649671

@@ -657,8 +679,10 @@ async def append_event(self, session: Session, event: Event) -> Event:
657679
)
658680

659681
# Fetch states from storage
660-
storage_app_state = sql_session.get(StorageAppState, (session.app_name))
661-
storage_user_state = sql_session.get(
682+
storage_app_state = await sql_session.get(
683+
StorageAppState, (session.app_name)
684+
)
685+
storage_user_state = await sql_session.get(
662686
StorageUserState, (session.app_name, session.user_id)
663687
)
664688

@@ -680,8 +704,8 @@ async def append_event(self, session: Session, event: Event) -> Event:
680704

681705
sql_session.add(StorageEvent.from_event(session, event))
682706

683-
sql_session.commit()
684-
sql_session.refresh(storage_session)
707+
await sql_session.commit()
708+
await sql_session.refresh(storage_session)
685709

686710
# Update timestamp with commit time
687711
session.last_update_time = storage_session.update_timestamp_tz
@@ -691,8 +715,12 @@ async def append_event(self, session: Session, event: Event) -> Event:
691715
return event
692716

693717

694-
def _merge_state(app_state, user_state, session_state):
695-
# Merge states for response
718+
def _merge_state(
719+
app_state: dict[str, Any],
720+
user_state: dict[str, Any],
721+
session_state: dict[str, Any],
722+
) -> dict[str, Any]:
723+
"""Merge app, user, and session states into a single state dictionary."""
696724
merged_state = copy.deepcopy(session_state)
697725
for key in app_state.keys():
698726
merged_state[State.APP_PREFIX + key] = app_state[key]

tests/unittests/sessions/test_session_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def get_session_service(
3939
):
4040
"""Creates a session service for testing."""
4141
if service_type == SessionServiceType.DATABASE:
42-
return DatabaseSessionService('sqlite:///:memory:')
42+
return DatabaseSessionService('sqlite+aiosqlite:///:memory:')
4343
if service_type == SessionServiceType.SQLITE:
4444
return SqliteSessionService(str(tmp_path / 'sqlite.db'))
4545
return InMemorySessionService()

0 commit comments

Comments
 (0)