diff --git a/src/admin/blueprints/tenants.py b/src/admin/blueprints/tenants.py index 9cf12be7b..f94970393 100644 --- a/src/admin/blueprints/tenants.py +++ b/src/admin/blueprints/tenants.py @@ -127,7 +127,10 @@ def tenant_settings(tenant_id, section=None): """ try: with get_db_session() as db_session: - tenant = db_session.query(Tenant).filter_by(tenant_id=tenant_id).first() + from sqlalchemy import select + + stmt = select(Tenant).filter_by(tenant_id=tenant_id) + tenant = db_session.scalars(stmt).first() if not tenant: flash("Tenant not found", "error") return redirect(url_for("core.index")) @@ -150,7 +153,8 @@ def tenant_settings(tenant_id, section=None): # Get advertiser data for the advertisers section from src.core.database.models import Principal - principals = db_session.query(Principal).filter_by(tenant_id=tenant_id).all() + stmt = select(Principal).filter_by(tenant_id=tenant_id) + principals = db_session.scalars(stmt).all() advertiser_count = len(principals) active_advertisers = len(principals) # For now, assume all are active @@ -179,7 +183,8 @@ def tenant_settings(tenant_id, section=None): # Get product counts from src.core.database.models import Product - products = db_session.query(Product).filter_by(tenant_id=tenant_id).all() + stmt = select(Product).filter_by(tenant_id=tenant_id) + products = db_session.scalars(stmt).all() product_count = len(products) active_products = len([p for p in products if p.status == "active"]) draft_products = len([p for p in products if p.status == "draft"]) @@ -187,7 +192,8 @@ def tenant_settings(tenant_id, section=None): # Get creative formats from src.core.database.models import CreativeFormat - creative_formats = db_session.query(CreativeFormat).filter_by(tenant_id=tenant_id).all() + stmt = select(CreativeFormat).filter_by(tenant_id=tenant_id) + creative_formats = db_session.scalars(stmt).all() # Get admin port admin_port = int(os.environ.get("ADMIN_UI_PORT", 8001)) diff --git a/src/admin/services/dashboard_service.py b/src/admin/services/dashboard_service.py index 72a522410..a5fae440f 100644 --- a/src/admin/services/dashboard_service.py +++ b/src/admin/services/dashboard_service.py @@ -8,8 +8,6 @@ import logging from datetime import UTC, datetime, timedelta -from sqlalchemy.orm import joinedload - from src.admin.services.business_activity_service import get_business_activities from src.admin.services.media_buy_readiness_service import MediaBuyReadinessService from src.core.database.database_session import get_db_session @@ -35,7 +33,10 @@ def get_tenant(self) -> Tenant | None: """Get tenant object, cached for this service instance.""" if self._tenant is None: with get_db_session() as db_session: - self._tenant = db_session.query(Tenant).filter_by(tenant_id=self.tenant_id).first() + from sqlalchemy import select + + stmt = select(Tenant).filter_by(tenant_id=self.tenant_id) + self._tenant = db_session.scalars(stmt).first() return self._tenant def get_dashboard_metrics(self) -> dict[str, any]: @@ -55,16 +56,22 @@ def get_dashboard_metrics(self) -> dict[str, any]: readiness_summary = MediaBuyReadinessService.get_tenant_readiness_summary(self.tenant_id) # Core business metrics - principals_count = db_session.query(Principal).filter_by(tenant_id=self.tenant_id).count() - products_count = db_session.query(Product).filter_by(tenant_id=self.tenant_id).count() + from sqlalchemy import func, select + + principals_count = db_session.scalar( + select(func.count()).select_from(Principal).where(Principal.tenant_id == self.tenant_id) + ) + products_count = db_session.scalar( + select(func.count()).select_from(Product).where(Product.tenant_id == self.tenant_id) + ) # Calculate total spend from live and completed media buys - total_spend_buys = ( - db_session.query(MediaBuy) + stmt = ( + select(MediaBuy) .filter_by(tenant_id=self.tenant_id) - .filter(MediaBuy.status.in_(["active", "completed"])) - .all() + .where(MediaBuy.status.in_(["active", "completed"])) ) + total_spend_buys = db_session.scalars(stmt).all() total_spend_amount = float(sum(buy.budget or 0 for buy in total_spend_buys)) # Revenue trend data (last 30 days) @@ -124,15 +131,18 @@ def get_recent_media_buys(self, limit: int = 10) -> list[MediaBuy]: """Get recent media buys with relationships loaded and readiness state.""" try: with get_db_session() as db_session: - recent_buys = ( - db_session.query(MediaBuy) - .filter(MediaBuy.tenant_id == self.tenant_id) - .filter(MediaBuy.media_buy_id.isnot(None)) # Defensive: ensure valid ID + from sqlalchemy import select + from sqlalchemy.orm import joinedload + + stmt = ( + select(MediaBuy) + .where(MediaBuy.tenant_id == self.tenant_id) + .where(MediaBuy.media_buy_id.isnot(None)) # Defensive: ensure valid ID .options(joinedload(MediaBuy.principal)) # Eager load to avoid N+1 .order_by(MediaBuy.created_at.desc()) .limit(limit) - .all() ) + recent_buys = db_session.scalars(stmt).all() # Transform for template consumption for media_buy in recent_buys: @@ -171,14 +181,16 @@ def _calculate_revenue_trend(self, db_session, days: int = 30) -> list[dict[str, date = today - timedelta(days=days - 1 - i) # Calculate revenue for this date - daily_buys = ( - db_session.query(MediaBuy) + from sqlalchemy import select + + stmt = ( + select(MediaBuy) .filter_by(tenant_id=self.tenant_id) - .filter(MediaBuy.start_date <= date) - .filter(MediaBuy.end_date >= date) - .filter(MediaBuy.status.in_(["active", "completed"])) - .all() + .where(MediaBuy.start_date <= date) + .where(MediaBuy.end_date >= date) + .where(MediaBuy.status.in_(["active", "completed"])) ) + daily_buys = db_session.scalars(stmt).all() daily_revenue = 0 for buy in daily_buys: diff --git a/src/core/main.py b/src/core/main.py index 2b96c94c9..35d26bf73 100644 --- a/src/core/main.py +++ b/src/core/main.py @@ -12,6 +12,7 @@ from fastmcp.server.context import Context from fastmcp.server.dependencies import get_http_headers from rich.console import Console +from sqlalchemy import select from src.adapters.google_ad_manager import GoogleAdManager from src.adapters.kevel import Kevel @@ -163,12 +164,14 @@ def get_principal_from_token(token: str, tenant_id: str | None = None) -> str | if tenant_id: # If tenant_id specified, ONLY look in that tenant console.print(f"[blue]Searching for principal in tenant '{tenant_id}'[/blue]") - principal = session.query(ModelPrincipal).filter_by(access_token=token, tenant_id=tenant_id).first() + stmt = select(ModelPrincipal).filter_by(access_token=token, tenant_id=tenant_id) + principal = session.scalars(stmt).first() if not principal: console.print(f"[yellow]No principal found in tenant '{tenant_id}', checking admin token[/yellow]") # Also check if it's the admin token for this specific tenant - tenant = session.query(Tenant).filter_by(tenant_id=tenant_id, is_active=True).first() + stmt = select(Tenant).filter_by(tenant_id=tenant_id, is_active=True) + tenant = session.scalars(stmt).first() if tenant and token == tenant.admin_token: console.print(f"[green]Token matches admin token for tenant '{tenant_id}'[/green]") # Set tenant context for admin token @@ -198,7 +201,8 @@ def get_principal_from_token(token: str, tenant_id: str | None = None) -> str | else: # No tenant specified - search globally by token console.print("[blue]No tenant specified - searching globally by token[/blue]") - principal = session.query(ModelPrincipal).filter_by(access_token=token).first() + stmt = select(ModelPrincipal).filter_by(access_token=token) + principal = session.scalars(stmt).first() if not principal: console.print("[red]No principal found with this token globally[/red]") @@ -209,14 +213,16 @@ def get_principal_from_token(token: str, tenant_id: str | None = None) -> str | ) # CRITICAL: Validate the tenant exists and is active before proceeding - tenant_check = session.query(Tenant).filter_by(tenant_id=principal.tenant_id, is_active=True).first() + stmt = select(Tenant).filter_by(tenant_id=principal.tenant_id, is_active=True) + tenant_check = session.scalars(stmt).first() if not tenant_check: console.print(f"[red]Tenant '{principal.tenant_id}' is inactive or deleted[/red]") # Tenant is disabled or deleted - fail securely return None # Get the tenant for this principal and set it as current context - tenant = session.query(Tenant).filter_by(tenant_id=principal.tenant_id, is_active=True).first() + stmt = select(Tenant).filter_by(tenant_id=principal.tenant_id, is_active=True) + tenant = session.scalars(stmt).first() if tenant: tenant_dict = { "tenant_id": tenant.tenant_id, @@ -381,9 +387,8 @@ def get_principal_adapter_mapping(principal_id: str) -> dict[str, Any]: """Get the platform mappings for a principal.""" tenant = get_current_tenant() with get_db_session() as session: - principal = ( - session.query(ModelPrincipal).filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]).first() - ) + stmt = select(ModelPrincipal).filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]) + principal = session.scalars(stmt).first() return principal.platform_mappings if principal else {} @@ -391,9 +396,8 @@ def get_principal_object(principal_id: str) -> Principal | None: """Get a Principal object for the given principal_id.""" tenant = get_current_tenant() with get_db_session() as session: - principal = ( - session.query(ModelPrincipal).filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]).first() - ) + stmt = select(ModelPrincipal).filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]) + principal = session.scalars(stmt).first() if principal: return Principal( @@ -430,7 +434,8 @@ def get_adapter(principal: Principal, dry_run: bool = False, testing_context=Non # Get adapter config from adapter_config table with get_db_session() as session: - config_row = session.query(AdapterConfig).filter_by(tenant_id=tenant["tenant_id"]).first() + stmt = select(AdapterConfig).filter_by(tenant_id=tenant["tenant_id"]) + config_row = session.scalars(stmt).first() adapter_config = {"enabled": True} if config_row: @@ -744,11 +749,8 @@ def log_tool_activity(context: Context, tool_name: str, start_time: float = None if principal_id: with get_db_session() as session: - principal = ( - session.query(ModelPrincipal) - .filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]) - .first() - ) + stmt = select(ModelPrincipal).filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]) + principal = session.scalars(stmt).first() if principal: principal_name = principal.name @@ -1224,14 +1226,10 @@ def _list_creative_formats_impl( from src.core.schemas import AssetRequirement, Format # Get formats for this tenant (or global formats) - db_formats = ( - session.query(CreativeFormat) - .filter( - (CreativeFormat.tenant_id == tenant["tenant_id"]) - | (CreativeFormat.tenant_id.is_(None)) # Global formats - ) - .all() + stmt = select(CreativeFormat).where( + (CreativeFormat.tenant_id == tenant["tenant_id"]) | (CreativeFormat.tenant_id.is_(None)) # Global formats ) + db_formats = session.scalars(stmt).all() for db_format in db_formats: # Convert database model to schema format @@ -1512,11 +1510,10 @@ def _sync_creatives_impl( if creative.get("creative_id"): from src.core.database.models import Creative as DBCreative - existing_creative = ( - session.query(DBCreative) - .filter_by(tenant_id=tenant["tenant_id"], creative_id=creative.get("creative_id")) - .first() + stmt = select(DBCreative).filter_by( + tenant_id=tenant["tenant_id"], creative_id=creative.get("creative_id") ) + existing_creative = session.scalars(stmt).first() if existing_creative: # Update existing creative (respects patch vs full upsert) @@ -1672,7 +1669,8 @@ def _sync_creatives_impl( for package_id in package_ids: # Find which media buy this package belongs to # Packages are stored in media_buy.raw_request["packages"] - media_buys = session.query(MediaBuy).filter_by(tenant_id=tenant["tenant_id"]).all() + stmt = select(MediaBuy).filter_by(tenant_id=tenant["tenant_id"]) + media_buys = session.scalars(stmt).all() media_buy_id = None for mb in media_buys: @@ -1793,11 +1791,10 @@ def _sync_creatives_impl( with get_db_session() as session: from src.core.database.models import Creative as DBCreative - db_creative = ( - session.query(DBCreative) - .filter_by(tenant_id=tenant["tenant_id"], creative_id=creative_dict.get("creative_id")) - .first() + stmt = select(DBCreative).filter_by( + tenant_id=tenant["tenant_id"], creative_id=creative_dict.get("creative_id") ) + db_creative = session.scalars(stmt).first() if db_creative: # Create schema object with populated internal fields # Using aliased field names for construction @@ -1992,47 +1989,49 @@ def _list_creatives_impl( from src.core.database.models import MediaBuy # Build query - query = session.query(DBCreative).filter_by(tenant_id=tenant["tenant_id"]) + stmt = select(DBCreative).filter_by(tenant_id=tenant["tenant_id"]) # Apply filters if req.media_buy_id: # Filter by media buy assignments - query = query.join(DBAssignment, DBCreative.creative_id == DBAssignment.creative_id).filter( + stmt = stmt.join(DBAssignment, DBCreative.creative_id == DBAssignment.creative_id).where( DBAssignment.media_buy_id == req.media_buy_id ) if req.buyer_ref: # Filter by buyer_ref through media buy - query = ( - query.join(DBAssignment, DBCreative.creative_id == DBAssignment.creative_id) + stmt = ( + stmt.join(DBAssignment, DBCreative.creative_id == DBAssignment.creative_id) .join(MediaBuy, DBAssignment.media_buy_id == MediaBuy.media_buy_id) - .filter(MediaBuy.buyer_ref == req.buyer_ref) + .where(MediaBuy.buyer_ref == req.buyer_ref) ) if req.status: - query = query.filter(DBCreative.status == req.status) + stmt = stmt.where(DBCreative.status == req.status) if req.format: - query = query.filter(DBCreative.format == req.format) + stmt = stmt.where(DBCreative.format == req.format) if req.tags: # Simple tag filtering - in production, might use JSON operators for tag in req.tags: - query = query.filter(DBCreative.name.contains(tag)) # Simplified + stmt = stmt.where(DBCreative.name.contains(tag)) # Simplified if req.created_after: - query = query.filter(DBCreative.created_at >= req.created_after) + stmt = stmt.where(DBCreative.created_at >= req.created_after) if req.created_before: - query = query.filter(DBCreative.created_at <= req.created_before) + stmt = stmt.where(DBCreative.created_at <= req.created_before) if req.search: # Search in name and description search_term = f"%{req.search}%" - query = query.filter(DBCreative.name.ilike(search_term)) + stmt = stmt.where(DBCreative.name.ilike(search_term)) # Get total count before pagination - total_count = query.count() + from sqlalchemy import func + + total_count = session.scalar(select(func.count()).select_from(stmt.subquery())) # Apply sorting if req.sort_by == "name": @@ -2043,13 +2042,13 @@ def _list_creatives_impl( sort_column = DBCreative.created_at if req.sort_order == "asc": - query = query.order_by(sort_column.asc()) + stmt = stmt.order_by(sort_column.asc()) else: - query = query.order_by(sort_column.desc()) + stmt = stmt.order_by(sort_column.desc()) # Apply pagination offset = (req.page - 1) * req.limit - db_creatives = query.offset(offset).limit(req.limit).all() + db_creatives = session.scalars(stmt.offset(offset).limit(req.limit)).all() # Convert to schema objects for db_creative in db_creatives: @@ -2485,7 +2484,7 @@ def _list_authorized_properties_impl( try: with get_db_session() as session: # Query authorized properties for this tenant - query = session.query(AuthorizedProperty).filter(AuthorizedProperty.tenant_id == tenant_id) + stmt = select(AuthorizedProperty).where(AuthorizedProperty.tenant_id == tenant_id) # Apply tag filtering if requested if req.tags: @@ -2493,12 +2492,12 @@ def _list_authorized_properties_impl( tag_filters = [] for tag in req.tags: tag_filters.append(AuthorizedProperty.tags.contains([tag])) - query = query.filter(sa.or_(*tag_filters)) + stmt = stmt.where(sa.or_(*tag_filters)) # Only include verified properties - query = query.filter(AuthorizedProperty.verification_status == "verified") + stmt = stmt.where(AuthorizedProperty.verification_status == "verified") - authorized_properties = query.all() + authorized_properties = session.scalars(stmt).all() # Convert database models to Pydantic models properties = [] @@ -2526,11 +2525,8 @@ def _list_authorized_properties_impl( # Get tag metadata for all referenced tags tag_metadata = {} if all_tags: - property_tags = ( - session.query(PropertyTag) - .filter(PropertyTag.tenant_id == tenant_id, PropertyTag.tag_id.in_(all_tags)) - .all() - ) + stmt = select(PropertyTag).where(PropertyTag.tenant_id == tenant_id, PropertyTag.tag_id.in_(all_tags)) + property_tags = session.scalars(stmt).all() for tag in property_tags: tag_metadata[tag.tag_id] = PropertyTagMetadata(name=tag.name, description=tag.description) @@ -2889,7 +2885,8 @@ def _create_media_buy_impl( # Persist the auto-generated config to database with get_db_session() as db_session: - db_product = db_session.query(ModelProduct).filter_by(product_id=product.product_id).first() + stmt = select(ModelProduct).filter_by(product_id=product.product_id) + db_product = db_session.scalars(stmt).first() if db_product: db_product.implementation_config = product.implementation_config db_session.commit() @@ -3073,11 +3070,8 @@ def _create_media_buy_impl( # Verify the creative exists from src.core.database.models import Creative as DBCreative - creative = ( - session.query(DBCreative) - .filter_by(tenant_id=tenant["tenant_id"], creative_id=creative_id) - .first() - ) + stmt = select(DBCreative).filter_by(tenant_id=tenant["tenant_id"], creative_id=creative_id) + creative = session.scalars(stmt).first() if not creative: logger.warning( @@ -3159,11 +3153,8 @@ def _create_media_buy_impl( try: principal_name = "Unknown" with get_db_session() as session: - principal_db = ( - session.query(ModelPrincipal) - .filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]) - .first() - ) + stmt = select(ModelPrincipal).filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]) + principal_db = session.scalars(stmt).first() if principal_db: principal_name = principal_db.name @@ -3202,11 +3193,8 @@ def _create_media_buy_impl( # Get principal name for notification (reuse from activity logging above) principal_name = "Unknown" with get_db_session() as session: - principal_db = ( - session.query(ModelPrincipal) - .filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]) - .first() - ) + stmt = select(ModelPrincipal).filter_by(principal_id=principal_id, tenant_id=tenant["tenant_id"]) + principal_db = session.scalars(stmt).first() if principal_db: principal_name = principal_db.name @@ -4193,7 +4181,8 @@ def complete_task(req, context): ) with get_db_session() as db_session: - db_task = db_session.query(Task).filter_by(task_id=req.task_id, tenant_id=tenant["tenant_id"]).first() + stmt = select(Task).filter_by(task_id=req.task_id, tenant_id=tenant["tenant_id"]) + db_task = db_session.scalars(stmt).first() if not db_task: raise ToolError("NOT_FOUND", f"Task {req.task_id} not found") @@ -4490,7 +4479,8 @@ def mark_task_complete(req, context): raise ToolError("DEPRECATED", "Task system has been replaced with workflow steps.") with get_db_session() as db_session: - db_task = db_session.query(Task).filter_by(task_id=req.task_id, tenant_id=tenant["tenant_id"]).first() + stmt = select(Task).filter_by(task_id=req.task_id, tenant_id=tenant["tenant_id"]) + db_task = db_session.scalars(stmt).first() if not db_task: raise ToolError("NOT_FOUND", f"Task {req.task_id} not found") @@ -4553,7 +4543,8 @@ def get_product_catalog() -> list[Product]: tenant = get_current_tenant() with get_db_session() as session: - products = session.query(ModelProduct).filter_by(tenant_id=tenant["tenant_id"]).all() + stmt = select(ModelProduct).filter_by(tenant_id=tenant["tenant_id"]) + products = session.scalars(stmt).all() loaded_products = [] for product in products: @@ -4769,7 +4760,8 @@ async def debug_root_logic(request: Request): # This is the fallback logic we don't need for test-agent try: with get_db_session() as db_session: - tenant_obj = db_session.query(Tenant).filter_by(subdomain=subdomain, is_active=True).first() + stmt = select(Tenant).filter_by(subdomain=subdomain, is_active=True) + tenant_obj = db_session.scalars(stmt).first() if tenant_obj: debug_info["subdomain_tenant_found"] = True # Build tenant dict... @@ -4880,7 +4872,8 @@ async def handle_landing_page(request: Request): # Look up tenant by subdomain try: with get_db_session() as db_session: - tenant_obj = db_session.query(Tenant).filter_by(subdomain=subdomain, is_active=True).first() + stmt = select(Tenant).filter_by(subdomain=subdomain, is_active=True) + tenant_obj = db_session.scalars(stmt).first() if tenant_obj: tenant = { "tenant_id": tenant_obj.tenant_id, @@ -4951,31 +4944,34 @@ def list_tasks( with get_db_session() as session: # Base query for workflow steps in this tenant - query = session.query(WorkflowStep).join(Context).filter(Context.tenant_id == tenant["tenant_id"]) + stmt = select(WorkflowStep).join(Context).where(Context.tenant_id == tenant["tenant_id"]) # Apply status filter if status: - query = query.filter(WorkflowStep.status == status) + stmt = stmt.where(WorkflowStep.status == status) # Apply object type/ID filters if object_type and object_id: - query = query.join(ObjectWorkflowMapping).filter( + stmt = stmt.join(ObjectWorkflowMapping).where( ObjectWorkflowMapping.object_type == object_type, ObjectWorkflowMapping.object_id == object_id ) elif object_type: - query = query.join(ObjectWorkflowMapping).filter(ObjectWorkflowMapping.object_type == object_type) + stmt = stmt.join(ObjectWorkflowMapping).where(ObjectWorkflowMapping.object_type == object_type) # Get total count before pagination - total = query.count() + from sqlalchemy import func + + total = session.scalar(select(func.count()).select_from(stmt.subquery())) # Apply pagination and ordering - tasks = query.order_by(WorkflowStep.created_at.desc()).offset(offset).limit(limit).all() + tasks = session.scalars(stmt.order_by(WorkflowStep.created_at.desc()).offset(offset).limit(limit)).all() # Format tasks for response formatted_tasks = [] for task in tasks: # Get associated objects - mappings = session.query(ObjectWorkflowMapping).filter_by(step_id=task.step_id).all() + stmt = select(ObjectWorkflowMapping).filter_by(step_id=task.step_id) + mappings = session.scalars(stmt).all() formatted_task = { "task_id": task.step_id, @@ -5036,18 +5032,19 @@ def get_task(task_id: str, context: Context = None) -> dict: with get_db_session() as session: # Find the task in this tenant - task = ( - session.query(WorkflowStep) + stmt = ( + select(WorkflowStep) .join(Context) - .filter(WorkflowStep.step_id == task_id, Context.tenant_id == tenant["tenant_id"]) - .first() + .where(WorkflowStep.step_id == task_id, Context.tenant_id == tenant["tenant_id"]) ) + task = session.scalars(stmt).first() if not task: raise ValueError(f"Task {task_id} not found") # Get associated objects - mappings = session.query(ObjectWorkflowMapping).filter_by(step_id=task_id).all() + stmt = select(ObjectWorkflowMapping).filter_by(step_id=task_id) + mappings = session.scalars(stmt).all() # Build detailed response task_detail = { @@ -5105,12 +5102,12 @@ def complete_task( with get_db_session() as session: # Find the task in this tenant - task = ( - session.query(WorkflowStep) + stmt = ( + select(WorkflowStep) .join(Context) - .filter(WorkflowStep.step_id == task_id, Context.tenant_id == tenant["tenant_id"]) - .first() + .where(WorkflowStep.step_id == task_id, Context.tenant_id == tenant["tenant_id"]) ) + task = session.scalars(stmt).first() if not task: raise ValueError(f"Task {task_id} not found") diff --git a/src/services/ai_creative_format_service.py b/src/services/ai_creative_format_service.py index 66f2a2008..996bb0acf 100644 --- a/src/services/ai_creative_format_service.py +++ b/src/services/ai_creative_format_service.py @@ -18,6 +18,7 @@ import aiohttp import google.generativeai as genai from bs4 import BeautifulSoup +from sqlalchemy import select from src.core.database.database_session import get_db_session from src.core.database.models import CreativeFormat @@ -808,7 +809,8 @@ async def sync_standard_formats(): for fmt in formats: try: # Check if format already exists - existing = session.query(CreativeFormat).filter_by(format_id=fmt.format_id).first() + stmt = select(CreativeFormat).filter_by(format_id=fmt.format_id) + existing = session.scalars(stmt).first() if not existing: # Insert new format diff --git a/src/services/ai_product_service.py b/src/services/ai_product_service.py index 7b4d68b8b..6aa071860 100644 --- a/src/services/ai_product_service.py +++ b/src/services/ai_product_service.py @@ -16,6 +16,7 @@ from typing import Any import google.generativeai as genai +from sqlalchemy import select from src.core.database.database_session import get_db_session from src.core.database.models import Principal as PrincipalModel @@ -255,13 +256,15 @@ async def _fetch_ad_server_inventory(self, tenant_id: str, adapter_type: str) -> # Get adapter configuration and principal with get_db_session() as db_session: # Get tenant ad server - tenant = db_session.query(Tenant).filter_by(tenant_id=tenant_id).first() + stmt = select(Tenant).filter_by(tenant_id=tenant_id) + tenant = db_session.scalars(stmt).first() if not tenant: logger.error(f"Tenant {tenant_id} not found") return AdServerInventory(ad_units=[], targeting_keys=[], formats=[]) # Get a principal for this tenant (use first available) - principal_model = db_session.query(PrincipalModel).filter_by(tenant_id=tenant_id).first() + stmt = select(PrincipalModel).filter_by(tenant_id=tenant_id) + principal_model = db_session.scalars(stmt).first() if not principal_model: # Create a temporary principal for inventory fetching @@ -305,7 +308,8 @@ async def _fetch_ad_server_inventory(self, tenant_id: str, adapter_type: str) -> # Get adapter config from adapter_config table with get_db_session() as db_session: - adapter_config_row = db_session.query(AdapterConfig).filter_by(tenant_id=tenant_id).first() + stmt = select(AdapterConfig).filter_by(tenant_id=tenant_id) + adapter_config_row = db_session.scalars(stmt).first() adapter_config = {} if adapter_config_row: # Build config from individual fields based on adapter type @@ -372,14 +376,14 @@ def _get_available_formats(self, tenant_id: str) -> list[dict[str, Any]]: with get_db_session() as db_session: from sqlalchemy import or_ - formats_query = ( - db_session.query(CreativeFormat) - .filter(or_(CreativeFormat.tenant_id.is_(None), CreativeFormat.tenant_id == tenant_id)) + stmt = ( + select(CreativeFormat) + .where(or_(CreativeFormat.tenant_id.is_(None), CreativeFormat.tenant_id == tenant_id)) .order_by(CreativeFormat.is_standard.desc(), CreativeFormat.type, CreativeFormat.name) ) formats = [] - for format_obj in formats_query: + for format_obj in db_session.scalars(stmt): format_dict = { "format_id": format_obj.format_id, "name": format_obj.name, @@ -644,7 +648,8 @@ async def analyze_product_description( # Get tenant's adapter type with get_db_session() as db_session: - tenant = db_session.query(Tenant).filter_by(tenant_id=tenant_id).first() + stmt = select(Tenant).filter_by(tenant_id=tenant_id) + tenant = db_session.scalars(stmt).first() if not tenant: raise ValueError(f"Tenant {tenant_id} not found") diff --git a/src/services/format_metrics_service.py b/src/services/format_metrics_service.py index aed8b43a6..9569cbebe 100644 --- a/src/services/format_metrics_service.py +++ b/src/services/format_metrics_service.py @@ -12,7 +12,7 @@ from datetime import datetime, timedelta from typing import Any -from sqlalchemy import and_ +from sqlalchemy import and_, select from sqlalchemy.orm import Session from src.adapters.gam_reporting_service import GAMReportingService @@ -275,17 +275,17 @@ def aggregate_all_tenants(period_days: int = 30) -> dict[str, Any]: with get_db_session() as db_session: # Get all active tenants with GAM configured - tenants = ( - db_session.query(Tenant, AdapterConfig) + stmt = ( + select(Tenant, AdapterConfig) .join(AdapterConfig, Tenant.tenant_id == AdapterConfig.tenant_id) - .filter( + .where( Tenant.ad_server == "google_ad_manager", Tenant.is_active, AdapterConfig.gam_network_code.isnot(None), AdapterConfig.gam_refresh_token.isnot(None), ) - .all() ) + tenants = db_session.execute(stmt).all() summary = {"total_tenants": len(tenants), "successful": 0, "failed": 0, "details": []} diff --git a/src/services/gam_inventory_service.py b/src/services/gam_inventory_service.py index 01578ccd1..677d71da8 100644 --- a/src/services/gam_inventory_service.py +++ b/src/services/gam_inventory_service.py @@ -12,7 +12,7 @@ from datetime import datetime, timedelta from typing import Any -from sqlalchemy import String, and_, create_engine, func, or_ +from sqlalchemy import String, and_, create_engine, func, or_, select from sqlalchemy.orm import Session, scoped_session, sessionmaker from src.adapters.gam.client import GAMClientManager @@ -780,7 +780,8 @@ def sync_inventory(tenant_id): from src.adapters.google_ad_manager import GoogleAdManager from src.core.database.models import AdapterConfig, Tenant - tenant = db_session.query(Tenant).filter_by(tenant_id=tenant_id).first() + stmt = select(Tenant).filter_by(tenant_id=tenant_id) + tenant = db_session.scalars(stmt).first() if not tenant: db_session.remove() return jsonify({"error": "Tenant not found"}), 404 @@ -791,7 +792,8 @@ def sync_inventory(tenant_id): return jsonify({"error": "GAM not enabled for tenant"}), 400 # Get adapter config from adapter_config table - adapter_config = db_session.query(AdapterConfig).filter_by(tenant_id=tenant_id).first() + stmt = select(AdapterConfig).filter_by(tenant_id=tenant_id) + adapter_config = db_session.scalars(stmt).first() if not adapter_config: db_session.remove() diff --git a/src/services/property_verification_service.py b/src/services/property_verification_service.py index b78bf39bb..914fff42d 100644 --- a/src/services/property_verification_service.py +++ b/src/services/property_verification_service.py @@ -8,6 +8,7 @@ import requests from requests.exceptions import RequestException +from sqlalchemy import select from src.core.database.database_session import get_db_session from src.core.database.models import AuthorizedProperty @@ -43,14 +44,11 @@ def verify_property(self, tenant_id: str, property_id: str, agent_url: str) -> t logger.info(f"🔍 Starting verification - tenant: {tenant_id}, property: {property_id}, agent: {agent_url}") with get_db_session() as session: - property_obj = ( - session.query(AuthorizedProperty) - .filter( - AuthorizedProperty.tenant_id == tenant_id, - AuthorizedProperty.property_id == property_id, - ) - .first() + stmt = select(AuthorizedProperty).where( + AuthorizedProperty.tenant_id == tenant_id, + AuthorizedProperty.property_id == property_id, ) + property_obj = session.scalars(stmt).first() if not property_obj: logger.error(f"❌ Property not found: {property_id} in tenant {tenant_id}") @@ -324,13 +322,10 @@ def verify_all_properties(self, tenant_id: str, agent_url: str) -> dict[str, Any try: with get_db_session() as session: # Get all pending properties - pending_properties = ( - session.query(AuthorizedProperty) - .filter( - AuthorizedProperty.tenant_id == tenant_id, AuthorizedProperty.verification_status == "pending" - ) - .all() + stmt = select(AuthorizedProperty).where( + AuthorizedProperty.tenant_id == tenant_id, AuthorizedProperty.verification_status == "pending" ) + pending_properties = session.scalars(stmt).all() results["total_checked"] = len(pending_properties) diff --git a/tests/integration/test_dashboard_service_integration.py b/tests/integration/test_dashboard_service_integration.py index 38c852ef3..a0b65c822 100644 --- a/tests/integration/test_dashboard_service_integration.py +++ b/tests/integration/test_dashboard_service_integration.py @@ -7,6 +7,7 @@ from datetime import datetime import pytest +from sqlalchemy import delete from src.admin.services.dashboard_service import DashboardService from src.core.database.database_session import get_db_session @@ -24,10 +25,12 @@ def test_tenant_data(self, integration_db): principal_id = "dashboard_test_principal" with get_db_session() as session: - # Clean up existing test data - session.query(Product).filter_by(tenant_id=tenant_id).delete() - session.query(Principal).filter_by(tenant_id=tenant_id, principal_id=principal_id).delete() - session.query(Tenant).filter_by(tenant_id=tenant_id).delete() + # Clean up existing test data (SQLAlchemy 2.0 pattern) + session.execute(delete(Product).where(Product.tenant_id == tenant_id)) + session.execute( + delete(Principal).where(Principal.tenant_id == tenant_id, Principal.principal_id == principal_id) + ) + session.execute(delete(Tenant).where(Tenant.tenant_id == tenant_id)) # Create test tenant with proper timestamps using helper tenant = create_tenant_with_timestamps( @@ -83,10 +86,12 @@ def test_tenant_data(self, integration_db): yield {"tenant_id": tenant_id, "principal_id": principal_id, "tenant": tenant, "principal": principal} - # Cleanup - session.query(Product).filter_by(tenant_id=tenant_id).delete() - session.query(Principal).filter_by(tenant_id=tenant_id, principal_id=principal_id).delete() - session.query(Tenant).filter_by(tenant_id=tenant_id).delete() + # Cleanup (SQLAlchemy 2.0 pattern) + session.execute(delete(Product).where(Product.tenant_id == tenant_id)) + session.execute( + delete(Principal).where(Principal.tenant_id == tenant_id, Principal.principal_id == principal_id) + ) + session.execute(delete(Tenant).where(Tenant.tenant_id == tenant_id)) session.commit() def test_dashboard_service_init_validation(self): diff --git a/tests/unit/test_dashboard_service.py b/tests/unit/test_dashboard_service.py index a51954d49..ed8f6b5f8 100644 --- a/tests/unit/test_dashboard_service.py +++ b/tests/unit/test_dashboard_service.py @@ -34,10 +34,12 @@ def test_get_tenant_caches_result(self, mock_get_db): mock_session = Mock() mock_get_db.return_value.__enter__.return_value = mock_session - # Mock tenant + # Mock tenant (SQLAlchemy 2.0 pattern) mock_tenant = Mock(spec=Tenant) mock_tenant.tenant_id = "test_tenant" - mock_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant + mock_scalars = Mock() + mock_scalars.first.return_value = mock_tenant + mock_session.scalars.return_value = mock_scalars service = DashboardService("test_tenant") @@ -51,7 +53,7 @@ def test_get_tenant_caches_result(self, mock_get_db): assert result2 == mock_tenant # Should only have called database once - mock_session.query.assert_called_once() + mock_session.scalars.assert_called_once() @patch("src.admin.services.dashboard_service.MediaBuyReadinessService") @patch("src.admin.services.dashboard_service.get_db_session") @@ -66,13 +68,11 @@ def test_get_dashboard_metrics_single_data_source(self, mock_get_activities, moc mock_tenant = Mock(spec=Tenant) mock_tenant.tenant_id = "test_tenant" - # Mock query results - need to set up proper query chain - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.filter.return_value = mock_query - mock_query.count.return_value = 5 - mock_query.all.return_value = [] - mock_session.query.return_value = mock_query + # Mock SQLAlchemy 2.0 query results + mock_scalars = Mock() + mock_scalars.all.return_value = [] + mock_session.scalars.return_value = mock_scalars + mock_session.scalar.return_value = 5 # For count queries # Mock readiness summary mock_readiness_summary = { diff --git a/tests/unit/test_property_verification_service.py b/tests/unit/test_property_verification_service.py index a4657cc7d..f764a9c4a 100644 --- a/tests/unit/test_property_verification_service.py +++ b/tests/unit/test_property_verification_service.py @@ -10,7 +10,7 @@ class MockSetup: @staticmethod def create_mock_db_session_with_property(property_data): - """Create mock database session with property.""" + """Create mock database session with property (SQLAlchemy 2.0 compatible).""" mock_session = Mock() mock_db_session_patcher = patch("src.services.property_verification_service.get_db_session") mock_db_session = mock_db_session_patcher.start() @@ -21,7 +21,11 @@ def create_mock_db_session_with_property(property_data): for key, value in property_data.items(): setattr(mock_property, key, value) - mock_session.query.return_value.filter.return_value.first.return_value = mock_property + # Mock SQLAlchemy 2.0 pattern: session.scalars(stmt).first() + mock_scalars = Mock() + mock_scalars.first.return_value = mock_property + mock_session.scalars.return_value = mock_scalars + return mock_db_session_patcher, mock_session, mock_property @staticmethod @@ -157,8 +161,7 @@ def test_check_agent_authorization(self): # No matching agent URL assert not self.service._check_agent_authorization(agents, "https://other-agent.example.com", property_obj) - @patch("requests.Session.get") - def test_verify_property_success(self, mock_get): + def test_verify_property_success(self): """Test successful property verification.""" # Use centralized mock setup property_data = { @@ -180,22 +183,24 @@ def test_verify_property_success(self, mock_get): # Setup mocks using centralized helper db_patcher, _, _ = MockSetup.create_mock_db_session_with_property(property_data) - mock_get.return_value = MockSetup.create_mock_http_response(response_data) - try: - # Test verification - is_verified, error = self.service.verify_property("tenant1", "prop1", "https://sales-agent.scope3.com") + # Mock the service's session.get method + with patch.object(self.service.session, "get") as mock_get: + mock_get.return_value = MockSetup.create_mock_http_response(response_data) + + try: + # Test verification + is_verified, error = self.service.verify_property("tenant1", "prop1", "https://sales-agent.scope3.com") - assert is_verified - assert error is None + assert is_verified + assert error is None - # Verify HTTP request was made - mock_get.assert_called_once_with("https://example.com/.well-known/adagents.json", timeout=10) - finally: - db_patcher.stop() + # Verify HTTP request was made + mock_get.assert_called_once_with("https://example.com/.well-known/adagents.json", timeout=10) + finally: + db_patcher.stop() - @patch("requests.Session.get") - def test_verify_property_not_authorized(self, mock_get): + def test_verify_property_not_authorized(self): """Test property verification when agent is not authorized.""" # Use centralized mock setup property_data = { @@ -217,19 +222,20 @@ def test_verify_property_not_authorized(self, mock_get): # Setup mocks using centralized helper db_patcher, _, _ = MockSetup.create_mock_db_session_with_property(property_data) - mock_get.return_value = MockSetup.create_mock_http_response(response_data) - try: - # Test verification - is_verified, error = self.service.verify_property("tenant1", "prop1", "https://sales-agent.scope3.com") + with patch.object(self.service.session, "get") as mock_get: + mock_get.return_value = MockSetup.create_mock_http_response(response_data) + + try: + # Test verification + is_verified, error = self.service.verify_property("tenant1", "prop1", "https://sales-agent.scope3.com") - assert not is_verified - assert "not found in authorized agents list" in error - finally: - db_patcher.stop() + assert not is_verified + assert "not found in authorized agents list" in error + finally: + db_patcher.stop() - @patch("requests.Session.get") - def test_verify_property_http_error(self, mock_get): + def test_verify_property_http_error(self): """Test property verification when HTTP request fails.""" from requests.exceptions import RequestException @@ -238,16 +244,18 @@ def test_verify_property_http_error(self, mock_get): # Setup mocks using centralized helper db_patcher, _, _ = MockSetup.create_mock_db_session_with_property(property_data) - mock_get.side_effect = RequestException("Connection failed") - try: - # Test verification - is_verified, error = self.service.verify_property("tenant1", "prop1", "https://sales-agent.scope3.com") + with patch.object(self.service.session, "get") as mock_get: + mock_get.side_effect = RequestException("Connection failed") + + try: + # Test verification + is_verified, error = self.service.verify_property("tenant1", "prop1", "https://sales-agent.scope3.com") - assert not is_verified - assert "Failed to fetch adagents.json" in error - finally: - db_patcher.stop() + assert not is_verified + assert "Failed to fetch adagents.json" in error + finally: + db_patcher.stop() def test_verify_property_not_found(self): """Test property verification when property doesn't exist.""" @@ -255,10 +263,14 @@ def test_verify_property_not_found(self): with patch("src.services.property_verification_service.get_db_session") as mock_db_session: mock_session = Mock() mock_db_session.return_value.__enter__.return_value = mock_session - mock_session.query.return_value.filter.return_value.first.return_value = None - # Test verification - is_verified, error = self.service.verify_property("tenant1", "prop1", "https://sales-agent.scope3.com") + # Mock SQLAlchemy 2.0 pattern: session.scalars(stmt).first() returns None + mock_scalars = Mock() + mock_scalars.first.return_value = None + mock_session.scalars.return_value = mock_scalars + + # Test verification + is_verified, error = self.service.verify_property("tenant1", "prop1", "https://sales-agent.scope3.com") assert not is_verified assert error == "Property not found"