Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ repos:
pass_filenames: false
always_run: false

# Check tenant context ordering in MCP tools
- id: check-tenant-context-order
name: Check tenant context auth before get_current_tenant()
entry: uv run python .pre-commit-hooks/check_tenant_context_order.py
language: system
files: '^src/core/tools/.*\.py$'
pass_filenames: true
always_run: false

# Prevent skipping tests (but allow skip_ci for CI-specific issues)
- id: no-skip-tests
name: No @pytest.mark.skip decorators allowed
Expand Down
128 changes: 128 additions & 0 deletions .pre-commit-hooks/check_tenant_context_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
"""Pre-commit hook to detect tenant context ordering bugs.

This hook checks that all MCP tool implementations call authentication
(get_principal_id_from_context or get_principal_from_context) BEFORE
calling get_current_tenant().

Prevents regression of the bug fixed in update_media_buy where
get_current_tenant() was called before tenant context was established.
"""

import re
import sys
from pathlib import Path


def check_file(file_path: Path) -> list[str]:
"""Check a single file for tenant context ordering issues.

Args:
file_path: Path to file to check

Returns:
List of error messages (empty if no issues)
"""
content = file_path.read_text()
errors = []

# Find all function definitions
# Look for def _*_impl( or async def _*_impl( functions (implementation functions)
impl_pattern = re.compile(r"(?:async\s+)?def\s+_(\w+)_impl\s*\(", re.MULTILINE)

for match in impl_pattern.finditer(content):
func_name = match.group(1)
func_start = match.start()

# Find the end of this function (next function definition or end of file)
next_func = re.search(r"\n(?:async\s+)?def\s+", content[func_start + len(match.group(0)) :])
func_end = func_start + len(match.group(0)) + next_func.start() if next_func else len(content)

func_body = content[func_start:func_end]

# Check if function uses get_current_tenant()
tenant_match = re.search(r"get_current_tenant\s*\(", func_body)
if not tenant_match:
continue # No tenant usage, skip

# Check if function has auth call before tenant call
auth_patterns = [
r"get_principal_id_from_context\s*\(",
r"get_principal_from_context\s*\(",
r"_get_principal_id_from_context\s*\(",
]

auth_pos = None
auth_pattern_used = None
for pattern in auth_patterns:
auth_match = re.search(pattern, func_body)
if auth_match:
if auth_pos is None or auth_match.start() < auth_pos:
auth_pos = auth_match.start()
auth_pattern_used = pattern.replace(r"\s*\(", "")

tenant_pos = tenant_match.start()

if auth_pos is None:
# Uses get_current_tenant() but no auth call - potential bug
line_num = content[:func_start].count("\n") + 1
errors.append(
f"{file_path}:{line_num}: "
f"Function '_{func_name}_impl' calls get_current_tenant() "
f"but does not call get_principal_*_from_context() first. "
f"This will cause 'No tenant context set' errors."
)
elif auth_pos > tenant_pos:
# Auth call comes after tenant call - definitely a bug!
line_num = content[:func_start].count("\n") + 1
errors.append(
f"{file_path}:{line_num}: "
f"BUG: Function '_{func_name}_impl' calls get_current_tenant() "
f"BEFORE {auth_pattern_used}(). This causes 'No tenant context set' errors. "
f"FIX: Move the auth call to before get_current_tenant()."
)

return errors


def main():
"""Main entry point for pre-commit hook."""
# Get list of files to check from arguments
files_to_check = [Path(f) for f in sys.argv[1:]]

# Only check tool implementation files
tool_files = [
f
for f in files_to_check
if f.suffix == ".py" and "/tools/" in str(f) and f.name not in ["__init__.py", "tool_context.py"]
]

if not tool_files:
return 0 # No tool files to check

all_errors = []
for file_path in tool_files:
errors = check_file(file_path)
all_errors.extend(errors)

if all_errors:
print("❌ Tenant context ordering errors detected:")
print()
for error in all_errors:
print(f" {error}")
print()
print("CRITICAL: All tool implementations must call authentication")
print("(get_principal_id_from_context or get_principal_from_context)")
print("BEFORE calling get_current_tenant().")
print()
print("Correct pattern:")
print(" 1. principal_id = get_principal_id_from_context(ctx)")
print(" 2. tenant = get_current_tenant() # Now safe")
print()
return 1

return 0


if __name__ == "__main__":
sys.exit(main())
33 changes: 31 additions & 2 deletions src/core/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,45 @@ def safe_json_loads(value, default=None):


def get_current_tenant() -> dict[str, Any]:
"""Get current tenant from context."""
"""Get current tenant from context.

CRITICAL: This function must only be called AFTER tenant context has been established
via get_principal_id_from_context() or get_principal_from_context() + set_current_tenant().

Common mistake: Calling get_current_tenant() before authenticating the request.
Correct order:
1. principal_id = get_principal_id_from_context(ctx) # Sets tenant context
2. tenant = get_current_tenant() # Now safe to call

Raises:
RuntimeError: If tenant context is not set (indicates authentication/ordering bug)
"""
import inspect

tenant = current_tenant.get()
if not tenant:
# SECURITY: Do NOT fall back to default tenant in production.
# This would cause tenant isolation breach.
# Only CLI/testing scripts should call this without context.

# Get caller information for debugging
frame = inspect.currentframe()
caller_frame = frame.f_back if frame else None
caller_info = ""
if caller_frame:
caller_file = caller_frame.f_code.co_filename
caller_line = caller_frame.f_lineno
caller_func = caller_frame.f_code.co_name
caller_info = f"\n Called from: {caller_file}:{caller_line} in {caller_func}()"

raise RuntimeError(
"No tenant context set. Tenant must be set via set_current_tenant() "
"before calling this function. This is a critical security error - "
"falling back to default tenant would breach tenant isolation."
"falling back to default tenant would breach tenant isolation.\n"
"\n"
"COMMON CAUSE: Calling get_current_tenant() before authenticating the request.\n"
"FIX: Ensure get_principal_id_from_context(ctx) is called BEFORE get_current_tenant()."
f"{caller_info}"
)
return tenant

Expand Down
20 changes: 11 additions & 9 deletions src/core/tools/media_buy_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,22 @@ def _update_media_buy_impl(
if ctx is None:
raise ValueError("Context is required for update_media_buy")

# CRITICAL: Establish tenant context FIRST by extracting principal from auth token
# This must happen before any database queries that need tenant_id
principal_id = get_principal_id_from_context(ctx)
if principal_id is None:
raise ValueError("principal_id is required but was None - authentication required")

# Now tenant context is set, we can safely call get_current_tenant()
tenant = get_current_tenant()

# Resolve media_buy_id from buyer_ref if needed (AdCP oneOf constraint)
media_buy_id_to_use = req.media_buy_id
if not media_buy_id_to_use and req.buyer_ref:
# Look up media_buy_id by buyer_ref
# Look up media_buy_id by buyer_ref (tenant context already set above)
from src.core.database.database_session import get_db_session
from src.core.database.models import MediaBuy as MediaBuyModel

tenant = get_current_tenant()
with get_db_session() as session:
stmt = select(MediaBuyModel).where(
MediaBuyModel.buyer_ref == req.buyer_ref, MediaBuyModel.tenant_id == tenant["tenant_id"]
Expand All @@ -233,14 +241,8 @@ def _update_media_buy_impl(
# Update req.media_buy_id for downstream processing
req.media_buy_id = media_buy_id_to_use

# Verify principal owns this media buy
_verify_principal(media_buy_id_to_use, ctx)
principal_id = get_principal_id_from_context(ctx) # Already verified by _verify_principal

# Verify principal_id is not None (get_principal_id_from_context should raise if None)
if principal_id is None:
raise ValueError("principal_id is required but was None")

tenant = get_current_tenant()

# Create or get persistent context
ctx_manager = get_context_manager()
Expand Down
23 changes: 6 additions & 17 deletions tests/integration/test_update_media_buy_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
"""

from datetime import date, timedelta
from unittest.mock import MagicMock

import pytest

Expand Down Expand Up @@ -177,23 +176,13 @@ def test_update_media_buy_requires_context():


@pytest.mark.requires_db
def test_update_media_buy_requires_media_buy_id():
def test_update_media_buy_requires_media_buy_id(test_tenant_setup):
"""Test update_media_buy raises error when buyer_ref lookup fails."""
# Create minimal mock context
context = MagicMock()
context.headers = {"x-adcp-auth": "test_token", "host": "test-tenant.test.com"}

# Set tenant context (required by get_current_tenant)
from src.core.config_loader import set_current_tenant

set_current_tenant(
{
"tenant_id": "test_tenant_no_mb",
"name": "Test Tenant",
"subdomain": "test-tenant",
"ad_server": "mock",
"is_active": True,
}
# Use valid authentication from fixture (required after auth ordering fix)
context = MockContext(
tenant_id=test_tenant_setup["tenant_id"],
principal_id=test_tenant_setup["principal_id"],
token=test_tenant_setup["token"],
)

# Note: When media_buy_id is None and buyer_ref is provided,
Expand Down
Loading