diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index b2c7c857e..f100cb230 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -50,6 +50,15 @@ repos:
         pass_filenames: false
         always_run: false
 
+      # Check tenant context ordering in MCP tools
+      - id: check-tenant-context-order
+        name: Check tenant context auth before get_current_tenant()
+        entry: uv run python .pre-commit-hooks/check_tenant_context_order.py
+        language: system
+        files: '^src/core/tools/.*\.py$'
+        pass_filenames: true
+        always_run: false
+
       # Prevent skipping tests (but allow skip_ci for CI-specific issues)
       - id: no-skip-tests
         name: No @pytest.mark.skip decorators allowed
diff --git a/.pre-commit-hooks/check_tenant_context_order.py b/.pre-commit-hooks/check_tenant_context_order.py
new file mode 100755
index 000000000..630eafd11
--- /dev/null
+++ b/.pre-commit-hooks/check_tenant_context_order.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+"""Pre-commit hook to detect tenant context ordering bugs.
+
+This hook checks that all MCP tool implementations call authentication
+(get_principal_id_from_context or get_principal_from_context) BEFORE
+calling get_current_tenant().
+
+Prevents regression of the bug fixed in update_media_buy where
+get_current_tenant() was called before tenant context was established.
+
+Files are parsed with ``ast`` so that only real call expressions are compared;
+comments, docstrings and string literals that merely mention these helpers
+are ignored.
+"""
+
+import ast
+import re
+import sys
+from pathlib import Path
+
+# Helpers that establish tenant context while authenticating the request.
+AUTH_CALLS = frozenset(
+    {
+        "get_principal_id_from_context",
+        "get_principal_from_context",
+        "_get_principal_id_from_context",
+    }
+)
+TENANT_CALL = "get_current_tenant"
+
+# Implementation functions follow the ``_<tool>_impl`` naming convention.
+IMPL_NAME = re.compile(r"_\w+_impl")
+
+# Files under a tools/ directory that never hold tool implementations.
+EXCLUDED_FILES = frozenset({"__init__.py", "tool_context.py"})
+
+
+def _called_name(call: ast.Call) -> str | None:
+    """Return the bare name a call targets (``f()`` or ``obj.f()``), else None."""
+    target = call.func
+    if isinstance(target, ast.Name):
+        return target.id
+    if isinstance(target, ast.Attribute):
+        return target.attr
+    return None
+
+
+def _check_function(file_path: Path, func: ast.FunctionDef | ast.AsyncFunctionDef) -> str | None:
+    """Return an error message if ``func`` reads tenant context before authenticating.
+
+    Calls are compared by the (line, column) of their first occurrence in the source.
+    """
+    tenant_pos = None
+    auth_pos = None
+    auth_name = None
+    # ast.walk() yields nodes breadth-first, so track the earliest position explicitly.
+    for node in ast.walk(func):
+        if not isinstance(node, ast.Call):
+            continue
+        name = _called_name(node)
+        pos = (node.lineno, node.col_offset)
+        if name == TENANT_CALL:
+            if tenant_pos is None or pos < tenant_pos:
+                tenant_pos = pos
+        elif name in AUTH_CALLS:
+            if auth_pos is None or pos < auth_pos:
+                auth_pos, auth_name = pos, name
+
+    if tenant_pos is None:
+        return None  # No tenant usage, nothing to check
+
+    if auth_pos is None:
+        # Uses get_current_tenant() but no auth call - potential bug
+        return (
+            f"{file_path}:{tenant_pos[0]}: "
+            f"Function '{func.name}' calls get_current_tenant() "
+            f"but does not call get_principal_*_from_context() first. "
+            f"This will cause 'No tenant context set' errors."
+        )
+
+    if auth_pos > tenant_pos:
+        # Auth call comes after tenant call - definitely a bug!
+        return (
+            f"{file_path}:{tenant_pos[0]}: "
+            f"BUG: Function '{func.name}' calls get_current_tenant() "
+            f"BEFORE {auth_name}(). This causes 'No tenant context set' errors. "
+            f"FIX: Move the auth call to before get_current_tenant()."
+        )
+
+    return None
+
+
+def check_file(file_path: Path) -> list[str]:
+    """Check a single file for tenant context ordering issues.
+
+    Args:
+        file_path: Path to file to check
+
+    Returns:
+        List of error messages (empty if no issues)
+    """
+    try:
+        tree = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path))
+    except SyntaxError as exc:
+        # An unparseable file cannot be verified; report it instead of crashing the hook.
+        return [f"{file_path}:{exc.lineno or 1}: Could not parse file: {exc.msg}"]
+
+    errors = []
+    for node in ast.walk(tree):
+        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and IMPL_NAME.fullmatch(node.name):
+            error = _check_function(file_path, node)
+            if error:
+                errors.append(error)
+    return errors
+
+
+def main(argv: list[str] | None = None) -> int:
+    """Main entry point for pre-commit hook.
+
+    Args:
+        argv: File names to check; defaults to ``sys.argv[1:]`` as passed by pre-commit
+
+    Returns:
+        Exit code: 0 if every checked file is clean, 1 if any error was reported
+    """
+    args = sys.argv[1:] if argv is None else argv
+
+    # Only check tool implementation files. Compare path components rather than
+    # the "/tools/" substring so Windows-style separators are handled as well.
+    tool_files = [
+        path
+        for path in map(Path, args)
+        if path.suffix == ".py" and "tools" in path.parent.parts and path.name not in EXCLUDED_FILES
+    ]
+
+    all_errors = [error for path in tool_files for error in check_file(path)]
+    if not all_errors:
+        return 0
+
+    print("❌ Tenant context ordering errors detected:")
+    print()
+    for error in all_errors:
+        print(f" {error}")
+    print()
+    print("CRITICAL: All tool implementations must call authentication")
+    print("(get_principal_id_from_context or get_principal_from_context)")
+    print("BEFORE calling get_current_tenant().")
+    print()
+    print("Correct pattern:")
+    print(" 1. principal_id = get_principal_id_from_context(ctx)")
+    print(" 2. tenant = get_current_tenant() # Now safe")
+    print()
+    return 1
+
+
+if __name__ == "__main__":
+    sys.exit(main())
diff --git a/src/core/config_loader.py b/src/core/config_loader.py
index 7900b0947..7666b87c5 100644
--- a/src/core/config_loader.py
+++ b/src/core/config_loader.py
@@ -32,16 +32,45 @@ def safe_json_loads(value, default=None):
 
 
 def get_current_tenant() -> dict[str, Any]:
-    """Get current tenant from context."""
+    """Get current tenant from context.
+
+    CRITICAL: This function must only be called AFTER tenant context has been established
+    via get_principal_id_from_context() or get_principal_from_context() + set_current_tenant().
+
+    Common mistake: Calling get_current_tenant() before authenticating the request.
+    Correct order:
+        1. principal_id = get_principal_id_from_context(ctx)  # Sets tenant context
+        2. tenant = get_current_tenant()  # Now safe to call
+
+    Raises:
+        RuntimeError: If tenant context is not set (indicates authentication/ordering bug)
+    """
     tenant = current_tenant.get()
     if not tenant:
         # SECURITY: Do NOT fall back to default tenant in production.
         # This would cause tenant isolation breach.
         # Only CLI/testing scripts should call this without context.
+
+        # Get caller information for debugging (imported here so the success path never pays for it)
+        import inspect
+
+        frame = inspect.currentframe()
+        caller_frame = frame.f_back if frame else None
+        caller_info = ""
+        if caller_frame:
+            caller_file = caller_frame.f_code.co_filename
+            caller_line = caller_frame.f_lineno
+            caller_func = caller_frame.f_code.co_name
+            caller_info = f"\n Called from: {caller_file}:{caller_line} in {caller_func}()"
+
         raise RuntimeError(
             "No tenant context set. Tenant must be set via set_current_tenant() "
             "before calling this function. This is a critical security error - "
-            "falling back to default tenant would breach tenant isolation."
+            "falling back to default tenant would breach tenant isolation.\n"
+            "\n"
+            "COMMON CAUSE: Calling get_current_tenant() before authenticating the request.\n"
+            "FIX: Ensure get_principal_id_from_context(ctx) is called BEFORE get_current_tenant()."
+            f"{caller_info}"
         )
     return tenant
 
diff --git a/src/core/tools/media_buy_update.py b/src/core/tools/media_buy_update.py
index f2d048e31..758f14961 100644
--- a/src/core/tools/media_buy_update.py
+++ b/src/core/tools/media_buy_update.py
@@ -207,14 +207,22 @@ def _update_media_buy_impl(
     if ctx is None:
         raise ValueError("Context is required for update_media_buy")
 
+    # CRITICAL: Establish tenant context FIRST by extracting principal from auth token
+    # This must happen before any database queries that need tenant_id
+    principal_id = get_principal_id_from_context(ctx)
+    if principal_id is None:
+        raise ValueError("principal_id is required but was None - authentication required")
+
+    # Now tenant context is set, we can safely call get_current_tenant()
+    tenant = get_current_tenant()
+
     # Resolve media_buy_id from buyer_ref if needed (AdCP oneOf constraint)
     media_buy_id_to_use = req.media_buy_id
     if not media_buy_id_to_use and req.buyer_ref:
-        # Look up media_buy_id by buyer_ref
+        # Look up media_buy_id by buyer_ref (tenant context already set above)
         from src.core.database.database_session import get_db_session
         from src.core.database.models import MediaBuy as MediaBuyModel
 
-        tenant = get_current_tenant()
         with get_db_session() as session:
             stmt = select(MediaBuyModel).where(
                 MediaBuyModel.buyer_ref == req.buyer_ref, MediaBuyModel.tenant_id == tenant["tenant_id"]
@@ -233,14 +241,8 @@ def _update_media_buy_impl(
         # Update req.media_buy_id for downstream processing
         req.media_buy_id = media_buy_id_to_use
 
+    # Verify principal owns this media buy
     _verify_principal(media_buy_id_to_use, ctx)
-    principal_id = get_principal_id_from_context(ctx)  # Already verified by _verify_principal
-
-    # Verify principal_id is not None (get_principal_id_from_context should raise if None)
-    if principal_id is None:
-        raise ValueError("principal_id is required but was None")
-
-    tenant = get_current_tenant()
 
     # Create or get persistent context
     ctx_manager = get_context_manager()
diff --git a/tests/integration/test_update_media_buy_persistence.py b/tests/integration/test_update_media_buy_persistence.py
index 5581d4a23..04d6667d1 100644
--- a/tests/integration/test_update_media_buy_persistence.py
+++ b/tests/integration/test_update_media_buy_persistence.py
@@ -8,7 +8,6 @@
 """
 
 from datetime import date, timedelta
-from unittest.mock import MagicMock
 
 import pytest
 
@@ -177,23 +176,13 @@ def test_update_media_buy_requires_context():
 
 
 @pytest.mark.requires_db
-def test_update_media_buy_requires_media_buy_id():
+def test_update_media_buy_requires_media_buy_id(test_tenant_setup):
     """Test update_media_buy raises error when buyer_ref lookup fails."""
-    # Create minimal mock context
-    context = MagicMock()
-    context.headers = {"x-adcp-auth": "test_token", "host": "test-tenant.test.com"}
-
-    # Set tenant context (required by get_current_tenant)
-    from src.core.config_loader import set_current_tenant
-
-    set_current_tenant(
-        {
-            "tenant_id": "test_tenant_no_mb",
-            "name": "Test Tenant",
-            "subdomain": "test-tenant",
-            "ad_server": "mock",
-            "is_active": True,
-        }
+    # Use valid authentication from fixture (required after auth ordering fix)
+    context = MockContext(
+        tenant_id=test_tenant_setup["tenant_id"],
+        principal_id=test_tenant_setup["principal_id"],
+        token=test_tenant_setup["token"],
     )
 
     # Note: When media_buy_id is None and buyer_ref is provided,
diff --git a/tests/unit/test_tenant_context_ordering.py b/tests/unit/test_tenant_context_ordering.py
new file mode 100644
index 000000000..a8c7c9ee6
--- /dev/null
+++ b/tests/unit/test_tenant_context_ordering.py
@@ -0,0 +1,216 @@
+"""Tests to prevent tenant context ordering regressions.
+
+This test suite ensures that all MCP tools follow the correct pattern:
+1. Call get_principal_id_from_context() or get_principal_from_context() FIRST
+2. Only then call get_current_tenant()
+
+The bug fixed in update_media_buy (calling get_current_tenant() before auth)
+must never happen again.
+"""
+
+import re
+from datetime import UTC, datetime
+from pathlib import Path
+from unittest.mock import Mock, patch
+
+import pytest
+
+from src.core.config_loader import current_tenant, get_current_tenant, set_current_tenant
+
+TOOLS_DIR = Path(__file__).parent.parent.parent / "src" / "core" / "tools"
+
+
+@pytest.fixture(autouse=True)
+def isolated_tenant_context():
+    """Run every test without tenant context, then restore the previous value.
+
+    Keeps a tenant set (or cleared) by one test from leaking into the next.
+    """
+    previous = current_tenant.get()
+    current_tenant.set(None)
+    try:
+        yield
+    finally:
+        current_tenant.set(previous)
+
+
+def test_get_current_tenant_raises_if_not_set():
+    """Test that get_current_tenant() raises RuntimeError if context not set."""
+    # Should raise RuntimeError with helpful message
+    with pytest.raises(RuntimeError) as exc_info:
+        get_current_tenant()
+
+    error_msg = str(exc_info.value)
+    assert "No tenant context set" in error_msg
+    assert "get_principal_id_from_context(ctx)" in error_msg
+    assert "BEFORE get_current_tenant()" in error_msg
+
+
+def test_get_current_tenant_includes_caller_info():
+    """Test that error message includes caller information for debugging."""
+    with pytest.raises(RuntimeError) as exc_info:
+        get_current_tenant()
+
+    error_msg = str(exc_info.value)
+    # Should include file, line, and function name
+    assert "Called from:" in error_msg
+    assert "test_tenant_context_ordering.py" in error_msg
+    assert "test_get_current_tenant_includes_caller_info" in error_msg
+
+
+def test_get_current_tenant_succeeds_after_set_current_tenant():
+    """Test that get_current_tenant() works after set_current_tenant()."""
+    test_tenant = {"tenant_id": "test_tenant", "name": "Test Tenant"}
+
+    set_current_tenant(test_tenant)
+    tenant = get_current_tenant()
+
+    assert tenant == test_tenant
+    assert tenant["tenant_id"] == "test_tenant"
+
+
+def test_update_media_buy_calls_auth_before_tenant():
+    """Regression test: update_media_buy must call auth before get_current_tenant()."""
+    from src.core.tool_context import ToolContext
+    from src.core.tools.media_buy_update import _update_media_buy_impl
+
+    # Create mock context with auth
+    ctx = ToolContext(
+        context_id="test_ctx",
+        principal_id="test_principal",
+        tenant_id="test_tenant",
+        tool_name="update_media_buy",
+        request_timestamp=datetime.now(UTC),
+    )
+
+    # Mock dependencies
+    with (
+        patch("src.core.tools.media_buy_update.get_principal_id_from_context") as mock_auth,
+        patch("src.core.tools.media_buy_update.get_current_tenant") as mock_tenant,
+        patch("src.core.tools.media_buy_update._verify_principal"),
+        patch("src.core.tools.media_buy_update.get_db_session"),
+    ):
+        mock_auth.return_value = "test_principal"
+        mock_tenant.return_value = {"tenant_id": "test_tenant"}
+
+        # Attach both mocks to one parent so their relative call order is recorded
+        recorder = Mock()
+        recorder.attach_mock(mock_auth, "auth")
+        recorder.attach_mock(mock_tenant, "tenant")
+
+        # Try to call (will fail on DB access, but we're testing call order)
+        try:
+            _update_media_buy_impl(media_buy_id="mb_123", ctx=ctx)
+        except Exception:
+            pass  # Expected to fail, we're just checking call order
+
+        call_order = [recorded[0] for recorded in recorder.mock_calls]
+
+    # CRITICAL: auth must be called, and it must be called before tenant
+    assert "auth" in call_order, "get_principal_id_from_context() must be called to set tenant context"
+    assert "tenant" in call_order, "get_current_tenant() was never reached, so call order was not exercised"
+    assert call_order.index("auth") < call_order.index("tenant"), (
+        "BUG: get_current_tenant() called before get_principal_id_from_context() in update_media_buy\n"
+        f" Observed call order: {call_order}"
+    )
+
+
+def test_create_media_buy_has_correct_pattern_in_source():
+    """Verify create_media_buy source code follows correct pattern."""
+    # Read the create_media_buy_impl source
+    source = (TOOLS_DIR / "media_buy_create.py").read_text(encoding="utf-8")
+
+    # Find the _create_media_buy_impl function
+    impl_start = source.find("async def _create_media_buy_impl(")
+    assert impl_start != -1, "_create_media_buy_impl function not found"
+
+    # Extract just the implementation function. It ends at the next top-level
+    # definition, whichever of ``def`` / ``async def`` comes first.
+    next_def = re.search(r"\n(?:async\s+)?def\s", source[impl_start + 1 :])
+    impl_end = impl_start + 1 + next_def.start() if next_def else len(source)
+    impl_source = source[impl_start:impl_end]
+
+    # Find first occurrence of get_principal_id_from_context
+    auth_pos = impl_source.find("get_principal_id_from_context(")
+    # Find first occurrence of get_current_tenant
+    tenant_pos = impl_source.find("get_current_tenant()")
+
+    # Both should be present
+    assert auth_pos != -1, "get_principal_id_from_context() not found in _create_media_buy_impl"
+    assert tenant_pos != -1, "get_current_tenant() not found in _create_media_buy_impl"
+
+    # Auth must come before tenant
+    assert auth_pos < tenant_pos, (
+        f"BUG: get_current_tenant() called before get_principal_id_from_context() in create_media_buy\n"
+        f" Auth call at position {auth_pos}\n"
+        f" Tenant call at position {tenant_pos}\n"
+        f" This is the bug we fixed in update_media_buy!"
+    )
+
+
+def test_all_tools_have_auth_before_tenant_pattern():
+    """Documentation test: Verify pattern is documented in all tool files."""
+    tool_files = [
+        "products.py",
+        "creative_formats.py",
+        "creatives.py",
+        "media_buy_create.py",
+        "media_buy_update.py",
+        "media_buy_delivery.py",
+        "performance.py",
+        "properties.py",
+        "signals.py",
+    ]
+
+    issues = []
+    for tool_file in tool_files:
+        file_path = TOOLS_DIR / tool_file
+        if not file_path.exists():
+            issues.append(f"{tool_file}: File not found")
+            continue
+
+        content = file_path.read_text(encoding="utf-8")
+
+        # Check for authentication calls
+        has_auth = any(
+            pattern in content
+            for pattern in [
+                "get_principal_id_from_context",
+                "get_principal_from_context",
+                "_get_principal_id_from_context",
+            ]
+        )
+
+        # Check for tenant usage
+        has_tenant = "get_current_tenant" in content
+
+        # If tool uses tenant context, it MUST call auth first
+        if has_tenant and not has_auth:
+            issues.append(f"{tool_file}: Uses get_current_tenant() but missing auth call")
+
+    if issues:
+        pytest.fail("Tool files with tenant context issues:\n" + "\n".join(f" - {issue}" for issue in issues))
+
+
+def test_helper_function_sets_tenant_context():
+    """Test that get_principal_id_from_context() actually sets tenant context."""
+    from src.core.helpers.context_helpers import get_principal_id_from_context
+    from src.core.tool_context import ToolContext
+
+    # Create context with tenant
+    ctx = ToolContext(
+        context_id="test_ctx",
+        principal_id="test_principal",
+        tenant_id="test_tenant",
+        tool_name="test",
+        request_timestamp=datetime.now(UTC),
+    )
+
+    # Call helper
+    principal_id = get_principal_id_from_context(ctx)
+
+    assert principal_id == "test_principal"
+
+    # Verify tenant context was set
+    tenant = get_current_tenant()
+    assert tenant["tenant_id"] == "test_tenant"