Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 19 additions & 34 deletions backend/core/mock_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

from core.config import Settings, get_settings
from core.identity_service import IdentityService
from rag_solution.core.exceptions import NotFoundError
from rag_solution.schemas.user_schema import UserInput
from rag_solution.services.user_service import UserService

Expand Down Expand Up @@ -107,18 +106,17 @@ def is_bypass_mode_active() -> bool:


def ensure_mock_user_exists(db: Session, settings: Settings, user_key: str = "default") -> UUID: # pylint: disable=unused-argument
"""
Ensure a mock user exists with full initialization.
"""Ensure a mock user exists using standard user creation flow.

This function uses the UserService to properly create the user
with all required components:
- User record
This function uses the UserService.get_or_create_user() method to maintain
consistency with how OIDC and API users are created. The get_or_create_user()
method automatically handles:
- User record creation/retrieval
- Prompt templates (RAG_QUERY, QUESTION_GENERATION, PODCAST_GENERATION)
- LLM provider assignment
- LLM parameters
- Pipeline configuration

Uses settings for configuration to ensure consistency across the application.
- Defensive reinitialization if defaults are missing

Args:
db: Database session
Expand All @@ -127,39 +125,26 @@ def ensure_mock_user_exists(db: Session, settings: Settings, user_key: str = "de

Returns:
UUID: The user's ID
"""
# Get mock user configuration from settings (not os.getenv directly)
# This ensures consistency with create_mock_user_data() and get_current_user()
config = {
"ibm_id": os.getenv("MOCK_USER_IBM_ID", "mock-user-ibm-id"), # Still use env for IBM ID
"email": settings.mock_user_email,
"name": settings.mock_user_name,
"role": os.getenv("MOCK_USER_ROLE", "admin"), # Still use env for role
}

Note:
This method now uses the same code path as OIDC users (get_or_create_user)
instead of having separate logic for mock users. This ensures consistent
behavior across all authentication methods.
"""
try:
user_service = UserService(db, settings)

# Try to get existing user first
try:
existing_user = user_service.user_repository.get_by_ibm_id(str(config["ibm_id"]))
logger.debug("Mock user already exists: %s", existing_user.id)
return existing_user.id
except (NotFoundError, ValueError, AttributeError, TypeError):
# User doesn't exist, proceed to create
logger.debug("Mock user not found, will create new user")

# Create new user with full initialization
# Use standardized user creation flow (same as OIDC/API users)
user_input = UserInput(
ibm_id=str(config["ibm_id"]),
email=str(config["email"]),
name=str(config["name"]),
role=str(config["role"]),
ibm_id=os.getenv("MOCK_USER_IBM_ID", "mock-user-ibm-id"),
email=settings.mock_user_email,
name=settings.mock_user_name,
role=os.getenv("MOCK_USER_ROLE", "admin"),
)

logger.info("Creating mock user: %s", config["email"])
user = user_service.create_user(user_input)
logger.info("Mock user created successfully: %s", user.id)
logger.info("Ensuring mock user exists: %s", user_input.email)
user = user_service.get_or_create_user(user_input)
logger.info("Mock user ready: %s", user.id)

return user.id

Expand Down
60 changes: 56 additions & 4 deletions backend/rag_solution/services/user_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@
from rag_solution.core.exceptions import NotFoundError, ValidationError
from rag_solution.repository.user_repository import UserRepository
from rag_solution.schemas.user_schema import UserInput, UserOutput
from rag_solution.services.prompt_template_service import PromptTemplateService
from rag_solution.services.user_provider_service import UserProviderService

logger = get_logger(__name__)

# Minimum number of prompt templates a user must have to count as initialized.
# One per required template type: RAG_QUERY, QUESTION_GENERATION, PODCAST_GENERATION.
# Used both to validate the defaults created for a new user and to decide
# whether an existing user's defaults have to be reinitialized.
MIN_REQUIRED_TEMPLATES = 3


class UserService:
"""Service for managing user-related operations."""
Expand All @@ -22,6 +27,7 @@ def __init__(self: Any, db: Session, settings: Settings) -> None:
self.settings = settings
self.user_repository = UserRepository(db)
self.user_provider_service = UserProviderService(db, settings)
self.prompt_template_service = PromptTemplateService(db)

def create_user(self, user_input: UserInput) -> UserOutput:
"""Creates a new user with validation.
Expand All @@ -37,7 +43,7 @@ def create_user(self, user_input: UserInput) -> UserOutput:
provider, templates, parameters = self.user_provider_service.initialize_user_defaults(user.id)

# Validate that all required defaults were created (RAG, Question, Podcast)
if not provider or not templates or len(templates) < 3 or not parameters:
if not provider or not templates or len(templates) < MIN_REQUIRED_TEMPLATES or not parameters:
self.db.rollback()
raise ValidationError("Failed to initialize required user configuration")

Expand All @@ -51,11 +57,57 @@ def get_or_create_user_by_fields(self, ibm_id: str, email: EmailStr, name: str,
)

def get_or_create_user(self, user_input: UserInput) -> UserOutput:
    """Gets existing user or creates new one, ensuring all required defaults exist.

    The user is looked up by ``user_input.ibm_id``. When no such user exists,
    creation is delegated to ``create_user``. When the user exists but has fewer
    than ``MIN_REQUIRED_TEMPLATES`` prompt templates (e.g., after database wipes,
    failed initializations, or data migrations), the defaults are reinitialized
    through ``user_provider_service.initialize_user_defaults`` before returning.

    Args:
        user_input: User data for creation or lookup

    Returns:
        UserOutput: The existing user (after any recovery) or the newly created one

    Raises:
        ValidationError: If an existing user is missing defaults and the
            reinitialization attempt raises.

    Note:
        The template check adds one query per call for existing users, but
        prevents silent failures later when defaults are expected to exist.
    """
    try:
        # Guard ONLY the lookup. A NotFoundError raised further down (while
        # inspecting or repairing an existing user's defaults) must not be
        # mistaken for "user does not exist" and trigger a second creation.
        existing_user = self.user_repository.get_by_ibm_id(user_input.ibm_id)
    except NotFoundError:
        # User doesn't exist, create with full initialization
        return self.create_user(user_input)

    # Defensive check: an existing user may have lost its defaults.
    templates = self.prompt_template_service.get_user_templates(existing_user.id)
    if templates and len(templates) >= MIN_REQUIRED_TEMPLATES:
        return existing_user

    logger.warning(
        "User %s exists but missing defaults (has %d/%d templates) - attempting recovery...",
        existing_user.id,
        len(templates) if templates else 0,
        MIN_REQUIRED_TEMPLATES,
    )
    try:
        _, reinit_templates, parameters = self.user_provider_service.initialize_user_defaults(
            existing_user.id
        )
        logger.info(
            "✅ Successfully recovered user %s: %d templates, %s parameters",
            existing_user.id,
            len(reinit_templates),
            "created" if parameters else "failed",
        )
    except Exception as e:
        # Any recovery failure is reported to callers as one error type,
        # with the original exception chained for debugging.
        logger.error("❌ Failed to recover user %s: %s", existing_user.id, str(e))
        raise ValidationError(
            f"User {existing_user.id} missing required defaults and recovery failed: {e}",
            field="user_initialization",
        ) from e

    return existing_user

def get_user_by_id(self, user_id: UUID4) -> UserOutput:
Expand Down
46 changes: 45 additions & 1 deletion backend/tests/integration/test_user_database.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
"""
Simplified version of test_user_database.py
Integration tests for user initialization and recovery after database operations.
"""

import pytest
from sqlalchemy import text

from core.mock_auth import ensure_mock_user_exists
from rag_solution.services.prompt_template_service import PromptTemplateService


@pytest.mark.integration
Expand All @@ -22,3 +26,43 @@ def test_mock_services(self, mock_watsonx_provider):
"""Test mock services."""
assert mock_watsonx_provider is not None
assert hasattr(mock_watsonx_provider, "generate_response")


@pytest.mark.integration
class TestUserInitializationRecovery:
    """Integration coverage for restoring a user's defaults once its templates are gone."""

    def test_mock_user_initialization_after_db_wipe(self, db_session, integration_settings):
        """The mock user gets its prompt templates back after they are deleted from the database."""
        # Step 1: the first call provisions the mock user.
        user_id = ensure_mock_user_exists(db_session, integration_settings)

        # Step 2: the freshly provisioned user must own the default templates.
        template_service = PromptTemplateService(db_session)
        initial = template_service.get_user_templates(user_id)
        assert len(initial) >= 3, f"Expected at least 3 templates, got {len(initial)}"

        # Step 3: remove the templates but leave the user row in place,
        # simulating what happens after scripts/wipe_database.py.
        wipe_statement = text("DELETE FROM prompt_templates WHERE user_id = :uid")
        db_session.execute(wipe_statement, {"uid": str(user_id)})
        db_session.commit()

        # Step 4: confirm the wipe really removed everything.
        wiped = template_service.get_user_templates(user_id)
        assert len(wiped) == 0, "Templates should be deleted after simulated wipe"

        # Step 5: a second call must hand back the same user and repair it.
        recovered_user_id = ensure_mock_user_exists(db_session, integration_settings)
        assert recovered_user_id == user_id, "Should return same user ID"

        # Step 6: the templates have to exist again.
        restored = template_service.get_user_templates(user_id)
        assert len(restored) >= 3, f"Expected at least 3 templates after recovery, got {len(restored)}"

        # Step 7: every required template type must be present.
        template_types = {template.template_type for template in restored}
        for required_type in ("RAG_QUERY", "QUESTION_GENERATION", "PODCAST_GENERATION"):
            assert required_type in template_types, f"Missing {required_type} template"
73 changes: 72 additions & 1 deletion backend/tests/unit/test_user_service_tdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def service(self, mock_db, mock_settings):
with (
patch("rag_solution.services.user_service.UserRepository"),
patch("rag_solution.services.user_service.UserProviderService"),
patch("rag_solution.services.user_service.PromptTemplateService"),
):
service = UserService(mock_db, mock_settings)
service.user_repository = Mock()
service.user_provider_service = Mock()
service.prompt_template_service = Mock()
return service

def test_create_user_success_red_phase(self, service, mock_db):
Expand Down Expand Up @@ -151,7 +153,7 @@ def test_create_user_insufficient_templates_red_phase(self, service, mock_db):
mock_db.rollback.assert_called_once()

def test_get_or_create_user_existing_user_red_phase(self, service):
"""RED: Test get_or_create when user already exists."""
"""RED: Test get_or_create when user already exists with sufficient templates."""
user_input = UserInput(
ibm_id="existing_user",
email="existing@example.com",
Expand All @@ -172,11 +174,15 @@ def test_get_or_create_user_existing_user_red_phase(self, service):
)

service.user_repository.get_by_ibm_id.return_value = existing_user
# Mock that user has 3 templates (sufficient)
service.prompt_template_service.get_user_templates.return_value = [Mock(), Mock(), Mock()]

result = service.get_or_create_user(user_input)

assert result is existing_user
service.user_repository.get_by_ibm_id.assert_called_once_with("existing_user")
service.prompt_template_service.get_user_templates.assert_called_once_with(existing_user.id)
service.user_provider_service.initialize_user_defaults.assert_not_called()
service.user_repository.create.assert_not_called()

def test_get_or_create_user_new_user_red_phase(self, service, mock_db): # noqa: ARG002
Expand Down Expand Up @@ -211,6 +217,69 @@ def test_get_or_create_user_new_user_red_phase(self, service, mock_db): # noqa:
service.user_repository.get_by_ibm_id.assert_called_once_with("new_user")
service.user_repository.create.assert_called_once_with(user_input)

def test_get_or_create_user_missing_templates_reinitializes(self, service):
    """An existing user holding fewer than three templates has its defaults re-created."""
    timestamp = "2024-01-01T00:00:00Z"
    existing_user = UserOutput(
        id=uuid4(),
        ibm_id="user1",
        email="user@test.com",
        name="User",
        role="user",
        preferred_provider_id=None,
        created_at=timestamp,
        updated_at=timestamp,
    )
    user_input = UserInput(
        ibm_id="user1",
        email="user@test.com",
        name="User",
        role="user",
        preferred_provider_id=None,
    )

    # Arrange: the user is found, but only a single template is on record.
    find_templates = service.prompt_template_service.get_user_templates
    reinitialize = service.user_provider_service.initialize_user_defaults
    service.user_repository.get_by_ibm_id.return_value = existing_user
    find_templates.return_value = [Mock()]

    # Arrange: reinitialization yields (provider, three templates, parameters).
    provider, parameters = Mock(), Mock()
    reinitialize.return_value = (provider, [Mock(), Mock(), Mock()], parameters)

    # Act
    result = service.get_or_create_user(user_input)

    # Assert: the same user comes back and recovery ran exactly once for it.
    assert result is existing_user
    find_templates.assert_called_once_with(existing_user.id)
    reinitialize.assert_called_once_with(existing_user.id)

def test_get_or_create_user_with_sufficient_templates_skips_reinit(self, service):
    """An existing user that already owns three templates is returned without recovery."""
    timestamp = "2024-01-01T00:00:00Z"
    existing_user = UserOutput(
        id=uuid4(),
        ibm_id="user1",
        email="user@test.com",
        name="User",
        role="user",
        preferred_provider_id=None,
        created_at=timestamp,
        updated_at=timestamp,
    )
    user_input = UserInput(
        ibm_id="user1",
        email="user@test.com",
        name="User",
        role="user",
        preferred_provider_id=None,
    )

    # Arrange: the lookup succeeds and exactly three templates are on record.
    find_templates = service.prompt_template_service.get_user_templates
    service.user_repository.get_by_ibm_id.return_value = existing_user
    find_templates.return_value = [Mock(), Mock(), Mock()]

    # Act
    result = service.get_or_create_user(user_input)

    # Assert: templates were checked once and no reinitialization happened.
    assert result is existing_user
    find_templates.assert_called_once_with(existing_user.id)
    service.user_provider_service.initialize_user_defaults.assert_not_called()

def test_get_or_create_user_by_fields_red_phase(self, service):
"""RED: Test get_or_create_user_by_fields convenience method."""
existing_user = UserOutput(
Expand All @@ -225,6 +294,8 @@ def test_get_or_create_user_by_fields_red_phase(self, service):
)

service.user_repository.get_by_ibm_id.return_value = existing_user
# Mock that user has 3 templates (sufficient)
service.prompt_template_service.get_user_templates.return_value = [Mock(), Mock(), Mock()]

result = service.get_or_create_user_by_fields(
ibm_id="field_user", email="field@example.com", name="Field User", role="admin"
Expand Down
Loading
Loading