diff --git a/backend/app/core/config/settings.py b/backend/app/core/config/settings.py index 1349a02f..270d7779 100644 --- a/backend/app/core/config/settings.py +++ b/backend/app/core/config/settings.py @@ -19,6 +19,7 @@ class Settings(BaseSettings): # DB configuration supabase_url: str supabase_key: str + database_url: Optional[str] = None # LangSmith Tracing langsmith_tracing: bool = False diff --git a/backend/app/database/core.py b/backend/app/database/core.py new file mode 100644 index 00000000..72275066 --- /dev/null +++ b/backend/app/database/core.py @@ -0,0 +1,54 @@ +from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession +from typing import AsyncGenerator +from app.core.config import settings +import logging + +logger = logging.getLogger(__name__) + +# Database configuration +DATABASE_URL = settings.database_url + +if not DATABASE_URL: + logger.warning("DATABASE_URL is not set. Database connection pooling will not be available.") + # Fallback or strict error depending on requirements. + # For now ensuring tests/code doesn't crash on import if env missing during build. + # But initialization should be guarded. + +# Initialize SQLAlchemy Async Engine with pooling +# If DATABASE_URL is missing, engine will be None and get_db will fail/error out appropriately when called +engine = create_async_engine( + DATABASE_URL, + echo=False, + pool_size=20, # Maintain 20 open connections + max_overflow=10, # Allow 10 extra during spikes + pool_timeout=30, # Wait 30s for a connection before raising timeout + pool_pre_ping=True, # Check connection health before handing it out +) if DATABASE_URL else None + +# Session Factory +async_session_maker = async_sessionmaker( + engine, + class_=AsyncSession, + expire_on_commit=False, + autocommit=False, + autoflush=False, +) if engine else None + + +async def get_db() -> AsyncGenerator[AsyncSession, None]: + """ + Dependency to provide a thread-safe database session. + Ensures that the session is closed after the request is processed. + """ + if not async_session_maker: + raise RuntimeError("Database engine is not initialized. check DATABASE_URL.") + + async with async_session_maker() as session: + try: + yield session + # automatic commit/rollback is often handled by caller or service layer logic + except Exception: + logger.exception("Database session error") + await session.rollback() + raise + # session.close() is handled automatically by the async context manager diff --git a/backend/requirements.txt b/backend/requirements.txt index 59827539..18ebc755 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -12,6 +12,8 @@ attrs==25.3.0 auth0-python==4.9.0 Authlib==1.3.1 autoflake==2.3.1 +asyncpg==0.29.0 +pytest-asyncio==0.23.5 autopep8==2.3.2 backoff==2.2.1 bcrypt==4.3.0 diff --git a/docs/DATABASE_CONNECTION.md b/docs/DATABASE_CONNECTION.md new file mode 100644 index 00000000..7aa4cf06 --- /dev/null +++ b/docs/DATABASE_CONNECTION.md @@ -0,0 +1,49 @@ +# Database Connection Management + +This document describes the thread-safe database connection management implemented for the Devr.AI backend. + +## Overview + +We use **SQLAlchemy** (AsyncIO) with **asyncpg** to manage a pool of connections to the Supabase PostgreSQL database. This allows for high-concurrency operations without the limitations of HTTP-based PostgREST calls (which `supabase-py` wraps). + +## Configuration + +The connection manager reads the `DATABASE_URL` from the application settings (loaded from `.env`). + +```env +DATABASE_URL=postgresql+asyncpg://user:password@host:5432/dbname +``` + +## Key Components + +### 1. Engine & Pooling +Located in `app/database/core.py`. +- **Pool Size**: 20 connections maintained open. +- **Max Overflow**: 10 temporary connections allowed during high load. +- **Pool Timeout**: 30 seconds wait time before raising an error. +- **Pre-Ping**: Checked before checkout to ensure connection health. + +### 2. Dependency Injection +Use `get_db` in FastAPI routes or other async functions to get a session. + +```python +from app.database.core import get_db +from sqlalchemy import text + +@router.get("/items") +async def read_items(db: AsyncSession = Depends(get_db)): + result = await db.execute(text("SELECT * FROM items")) + return result.mappings().all() +``` + +The `get_db` generator ensures: +- A session is created from the pool. +- The session is passed to the function. +- The session is **automatically closed** after the function completes (even on error). +- If an error occurs, the transaction is rolled back. + +## Testing +Unit tests in `tests/test_db_pool.py` verify: +- Pool configuration. +- Concurrent session acquisition (simulating 50+ parallel requests). +- Proper cleanup (rollback and close) on errors. diff --git a/pyproject.toml b/pyproject.toml index 3d225571..821e56d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ dependencies = [ "pygit2 (>=1.18.2,<2.0.0)", "toml (>=0.10.2,<0.11.0)", "websockets (>=15.0.1,<16.0.0)", + "sqlalchemy (>=2.0.25,<3.0.0)", + "asyncpg (>=0.29.0,<1.0.0)", ] [tool.poetry] diff --git a/tests/test_db_pool.py b/tests/test_db_pool.py new file mode 100644 index 00000000..77e27af1 --- /dev/null +++ b/tests/test_db_pool.py @@ -0,0 +1,96 @@ +from sqlalchemy.ext.asyncio import create_async_engine +import pytest +import asyncio +from unittest.mock import MagicMock, patch, AsyncMock +import importlib +from app.core import config + +# We need to reload the module to pick up the patched settings because +# engine is created at module level in app.database.core. + +@pytest.fixture +def mock_db_module(): + """ + Fixture to reload app.database.core with patched settings. + """ + # Patch the settings object where it is defined or imported + with patch("app.core.config.settings") as mock_settings: + mock_settings.database_url = "postgresql+asyncpg://user:password@localhost:5432/testdb" + + # Reload the module so 'engine' is recreated with the new settings + import app.database.core + importlib.reload(app.database.core) + + yield app.database.core + +@pytest.mark.asyncio +async def test_connection_pooling_configuration(mock_db_module): + """ + Verify that the engine is configured with the expected pool size. + """ + engine = mock_db_module.engine + + if engine: + assert engine.pool.size() == 20 + assert engine.pool.timeout() == 30 + else: + pytest.fail("Engine not initialized") + +@pytest.mark.asyncio +async def test_concurrent_session_acquisition(mock_db_module): + """ + Simulate high concurrency to ensure sessions can be acquired without error. + """ + # Mock the session + mock_session = MagicMock() + # Ensure close and rollback return awaitables (Futures) + mock_session.close = MagicMock(return_value=asyncio.Future()) + mock_session.close.return_value.set_result(None) + mock_session.rollback = MagicMock(return_value=asyncio.Future()) + mock_session.rollback.return_value.set_result(None) + + # Mock async context manager: __aenter__ returns session, __aexit__ returns None (awaitable) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + # Patch the async_session_maker in the RELOADED module + with patch.object(mock_db_module, "async_session_maker", return_value=mock_session) as mock_maker: + + async def task(): + # Use the get_db from the RELOADED module + async for session in mock_db_module.get_db(): + # Simulate some work + await asyncio.sleep(0.01) + return True + + # Run 50 concurrent tasks + results = await asyncio.gather(*[task() for _ in range(50)]) + + assert all(results) + # Verify correct number of calls + assert mock_maker.call_count == 50 + + # Automatic closing is handled by the context manager, which calls __aexit__ + assert mock_session.__aexit__.call_count == 50 + +@pytest.mark.asyncio +async def test_session_rollback_on_error(mock_db_module): + """ + Ensure rollback is called if an exception occurs during session usage. + """ + mock_session = MagicMock() + mock_session.rollback = MagicMock(return_value=asyncio.Future()) + mock_session.rollback.return_value.set_result(None) + + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + + with patch.object(mock_db_module, "async_session_maker", return_value=mock_session): + with pytest.raises(ValueError): + async for session in mock_db_module.get_db(): + raise ValueError("Simulated Error") + + # Verify rollback was called once + assert mock_session.rollback.call_count == 1 + # Verify exit was called (which would handle cleanup in a real scenario) + assert mock_session.__aexit__.call_count == 1