diff --git a/state-manager/app/controller/executed_state.py b/state-manager/app/controller/executed_state.py index c11febd3..66d2ba46 100644 --- a/state-manager/app/controller/executed_state.py +++ b/state-manager/app/controller/executed_state.py @@ -1,5 +1,4 @@ from beanie import PydanticObjectId -from beanie.operators import In from app.models.executed_models import ExecutedRequestModel, ExecutedResponseModel from fastapi import HTTPException, status, BackgroundTasks @@ -7,7 +6,7 @@ from app.models.db.state import State from app.models.state_status_enum import StateStatusEnum from app.singletons.logs_manager import LogsManager -from app.tasks.create_next_state import create_next_state +from app.tasks.create_next_states import create_next_states logger = LogsManager().get_logger() @@ -23,19 +22,20 @@ async def executed_state(namespace_name: str, state_id: PydanticObjectId, body: if state.status != StateStatusEnum.QUEUED: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not queued") + next_state_ids = [] if len(body.outputs) == 0: state.status = StateStatusEnum.EXECUTED state.outputs = {} await state.save() - background_tasks.add_task(create_next_state, state) + next_state_ids.append(state.id) else: state.outputs = body.outputs[0] state.status = StateStatusEnum.EXECUTED await state.save() - background_tasks.add_task(create_next_state, state) + next_state_ids.append(state.id) new_states = [] for output in body.outputs[1:]: @@ -54,16 +54,9 @@ async def executed_state(namespace_name: str, state_id: PydanticObjectId, body: if len(new_states) > 0: inserted_ids = (await State.insert_many(new_states)).inserted_ids + next_state_ids.extend(inserted_ids) - inserted_states = await State.find( - In(State.id, inserted_ids) - ).to_list() - - if len(inserted_states) != len(new_states): - raise RuntimeError(f"Failed to insert all new states. Expected {len(new_states)} states, but only {len(inserted_states)} were inserted") - - for inserted_state in inserted_states: - background_tasks.add_task(create_next_state, inserted_state) + background_tasks.add_task(create_next_states, next_state_ids, state.identifier, state.namespace_name, state.graph_name, state.parents) return ExecutedResponseModel(status=StateStatusEnum.EXECUTED) diff --git a/state-manager/app/models/db/graph_template_model.py b/state-manager/app/models/db/graph_template_model.py index ee3d8d35..6037f0bb 100644 --- a/state-manager/app/models/db/graph_template_model.py +++ b/state-manager/app/models/db/graph_template_model.py @@ -1,7 +1,9 @@ import base64 +import time +import asyncio from .base import BaseDatabaseModel -from pydantic import Field, field_validator +from pydantic import Field, field_validator, PrivateAttr from typing import Optional, List from ..graph_template_validation_status import GraphTemplateValidationStatus from ..node_template_model import NodeTemplate @@ -17,6 +19,7 @@ class GraphTemplate(BaseDatabaseModel): validation_status: GraphTemplateValidationStatus = Field(..., description="Validation status of the graph") validation_errors: Optional[List[str]] = Field(None, description="Validation errors of the graph") secrets: Dict[str, str] = Field(default_factory=dict, description="Secrets of the graph") + _node_by_identifier: Dict[str, NodeTemplate] | None = PrivateAttr(default=None) class Settings: indexes = [ @@ -27,12 +30,18 @@ class Settings: ) ] + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _build_node_by_identifier(self) -> None: + self._node_by_identifier = {node.identifier: node for node in self.nodes} + def get_node_by_identifier(self, identifier: str) -> NodeTemplate | None: """Get a node by its identifier using O(1) dictionary lookup.""" - for node in self.nodes: - if node.identifier == identifier: - return node - return None + if self._node_by_identifier is None: + self._build_node_by_identifier() + + return self._node_by_identifier.get(identifier) # type: ignore @field_validator('secrets') @classmethod @@ -78,4 +87,40 @@ def get_secret(self, secret_name: str) -> str | None: return None if secret_name not in self.secrets: return None - return get_encrypter().decrypt(self.secrets[secret_name]) \ No newline at end of file + return get_encrypter().decrypt(self.secrets[secret_name]) + + def is_valid(self) -> bool: + return self.validation_status == GraphTemplateValidationStatus.VALID + + def is_validating(self) -> bool: + return self.validation_status in (GraphTemplateValidationStatus.ONGOING, GraphTemplateValidationStatus.PENDING) + + @staticmethod + async def get(namespace: str, graph_name: str) -> "GraphTemplate": + graph_template = await GraphTemplate.find_one(GraphTemplate.namespace == namespace, GraphTemplate.name == graph_name) + if not graph_template: + raise ValueError(f"Graph template not found for namespace: {namespace} and graph name: {graph_name}") + return graph_template + + @staticmethod + async def get_valid(namespace: str, graph_name: str, polling_interval: float = 1.0, timeout: float = 300.0) -> "GraphTemplate": + # Validate polling_interval and timeout + if polling_interval <= 0: + raise ValueError("polling_interval must be positive") + if timeout <= 0: + raise ValueError("timeout must be positive") + + # Coerce polling_interval to a sensible minimum + if polling_interval < 0.1: + polling_interval = 0.1 + + start_time = time.monotonic() + while time.monotonic() - start_time < timeout: + graph_template = await GraphTemplate.get(namespace, graph_name) + if graph_template.is_valid(): + return graph_template + if graph_template.is_validating(): + await asyncio.sleep(polling_interval) + else: + raise ValueError(f"Graph template is in a non-validating state: {graph_template.validation_status.value} for namespace: {namespace} and graph name: {graph_name}") + raise ValueError(f"Graph template is not valid for namespace: {namespace} and graph name: {graph_name} after {timeout} seconds") \ No newline at end of file diff --git a/state-manager/app/models/state_status_enum.py b/state-manager/app/models/state_status_enum.py index a536b880..8da97002 100644 --- a/state-manager/app/models/state_status_enum.py +++ b/state-manager/app/models/state_status_enum.py @@ -11,4 +11,5 @@ class StateStatusEnum(str, Enum): TIMEDOUT = 'TIMEDOUT' ERRORED = 'ERRORED' CANCELLED = 'CANCELLED' - SUCCESS = 'SUCCESS' \ No newline at end of file + SUCCESS = 'SUCCESS' + NEXT_CREATED_ERROR = 'NEXT_CREATED_ERROR' diff --git a/state-manager/app/tasks/create_next_state.py b/state-manager/app/tasks/create_next_state.py deleted file mode 100644 index bf03679e..00000000 --- a/state-manager/app/tasks/create_next_state.py +++ /dev/null @@ -1,157 +0,0 @@ -import asyncio -import time - -from app.models.db.state import State -from app.models.db.graph_template_model import GraphTemplate -from app.models.graph_template_validation_status import GraphTemplateValidationStatus -from app.models.db.registered_node import RegisteredNode -from app.models.state_status_enum import StateStatusEnum -from beanie.operators import NE -from app.singletons.logs_manager import LogsManager - -from json_schema_to_pydantic import create_model - -logger = LogsManager().get_logger() - -async def create_next_state(state: State): - logger.info(f"Creating next state for {state.identifier}") - graph_template = None - - if state is None or state.id is None: - raise ValueError("State is not valid") - try: - start_time = time.time() - timeout_seconds = 300 # 5 minutes - - while True: - graph_template = await GraphTemplate.find_one(GraphTemplate.name == state.graph_name, GraphTemplate.namespace == state.namespace_name) - if not graph_template: - raise Exception(f"Graph template {state.graph_name} not found") - if graph_template.validation_status == GraphTemplateValidationStatus.VALID: - break - - # Check if we've exceeded the timeout - if time.time() - start_time > timeout_seconds: - raise Exception(f"Timeout waiting for graph template {state.graph_name} to become valid after {timeout_seconds} seconds") - - await asyncio.sleep(1) - - node_template = graph_template.get_node_by_identifier(state.identifier) - if not node_template: - raise Exception(f"Node template {state.identifier} not found") - - next_node_identifier = node_template.next_nodes - if not next_node_identifier: - state.status = StateStatusEnum.SUCCESS - await state.save() - return - - cache_states = {} - - parents = state.parents | {state.identifier: state.id} - - for identifier in next_node_identifier: - next_node_template = graph_template.get_node_by_identifier(identifier) - if not next_node_template: - continue - - depends_satisfied = True - if next_node_template.unites is not None and len(next_node_template.unites) > 0: - pending_count = 0 - for depend in next_node_template.unites: - if depend.identifier == state.identifier: - continue - else: - root_parent = state.parents.get(depend.identifier) - if root_parent is None: - raise Exception(f"Root parent of {depend.identifier} not found") - - pending_count = await State.find( - State.identifier == depend.identifier, - State.namespace_name == state.namespace_name, - State.graph_name == state.graph_name, - NE(State.status, StateStatusEnum.SUCCESS), - {f"parents.{depend.identifier}": parents[depend.identifier]} - ).count() - if pending_count > 0: - logger.info(f"Node {next_node_template.identifier} depends on {depend.identifier} but it is not satisfied") - depends_satisfied = False - break - - if not depends_satisfied: - continue - - registered_node = await RegisteredNode.find_one(RegisteredNode.name == next_node_template.node_name, RegisteredNode.namespace == next_node_template.namespace) - - if not registered_node: - raise Exception(f"Registered node {next_node_template.node_name} not found") - - next_node_input_model = create_model(registered_node.inputs_schema) - next_node_input_data = {} - - for field_name, _ in next_node_input_model.model_fields.items(): - temporary_input = next_node_template.inputs[field_name] - splits = temporary_input.split("${{") - - if len(splits) == 0: - next_node_input_data[field_name] = temporary_input - continue - - constructed_string = "" - for split in splits: - if "}}" in split: - placeholder_content = split.split("}}")[0] - parts = [p.strip() for p in placeholder_content.split('.')] - - if len(parts) != 3 or parts[1] != 'outputs': - raise Exception(f"Invalid input placeholder format: '{placeholder_content}' for field {field_name}") - - input_identifier = parts[0] - input_field = parts[2] - - parent_id = parents.get(input_identifier) - - if not parent_id: - raise Exception(f"Parent identifier '{input_identifier}' not found in state parents.") - - if parent_id not in cache_states: - dependent_state = await State.get(parent_id) - if not dependent_state: - raise Exception(f"Dependent state {input_identifier} not found") - cache_states[parent_id] = dependent_state - else: - dependent_state = cache_states[parent_id] - - if input_field not in dependent_state.outputs: - raise Exception(f"Input field {input_field} not found in dependent state {input_identifier}") - - constructed_string += dependent_state.outputs[input_field] + split.split("}}")[1] - - else: - constructed_string += split - - next_node_input_data[field_name] = constructed_string - - new_state = State( - node_name=next_node_template.node_name, - namespace_name=next_node_template.namespace, - identifier=next_node_template.identifier, - graph_name=state.graph_name, - run_id=state.run_id, - status=StateStatusEnum.CREATED, - inputs=next_node_input_data, - outputs={}, - error=None, - parents=parents - ) - - await new_state.save() - - state.status = StateStatusEnum.SUCCESS - await state.save() - - except Exception as e: - state.status = StateStatusEnum.ERRORED - state.error = str(e) - await state.save() - return diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py new file mode 100644 index 00000000..bbf6b182 --- /dev/null +++ b/state-manager/app/tasks/create_next_states.py @@ -0,0 +1,220 @@ +from beanie import PydanticObjectId +from beanie.operators import In, NE +from app.singletons.logs_manager import LogsManager +from app.models.db.graph_template_model import GraphTemplate +from app.models.db.state import State +from app.models.state_status_enum import StateStatusEnum +from app.models.node_template_model import NodeTemplate +from app.models.db.registered_node import RegisteredNode +from json_schema_to_pydantic import create_model +from pydantic import BaseModel +from typing import Type + +logger = LogsManager().get_logger() + +class Dependent(BaseModel): + identifier: str + field: str + tail: str + value: str | None = None + +class DependentString(BaseModel): + head: str + dependents: dict[int, Dependent] + + def generate_string(self) -> str: + base = self.head + for key in sorted(self.dependents.keys()): + dependent = self.dependents[key] + if dependent.value is None: + raise ValueError(f"Dependent value is not set for: {dependent}") + base += dependent.value + dependent.tail + return base + +async def mark_success_states(state_ids: list[PydanticObjectId]): + await State.find( + In(State.id, state_ids) + ).set({ + "status": StateStatusEnum.SUCCESS + }) # type: ignore + + +async def check_unites_satisfied(namespace: str, graph_name: str, node_template: NodeTemplate, parents: dict[str, PydanticObjectId]) -> bool: + if node_template.unites is None or len(node_template.unites) == 0: + return True + + for unit in node_template.unites: + unites_id = parents.get(unit.identifier) + if not unites_id: + raise ValueError(f"Unit identifier not found in parents: {unit.identifier}") + else: + pending_count = await State.find( + State.identifier == unit.identifier, + State.namespace_name == namespace, + State.graph_name == graph_name, + NE(State.status, StateStatusEnum.SUCCESS), + { + f"parents.{unit.identifier}": unites_id + } + ).count() + if pending_count > 0: + return False + return True + +def get_dependents(syntax_string: str) -> DependentString: + splits = syntax_string.split("${{") + if len(splits) <= 1: + return DependentString(head=syntax_string, dependents={}) + + dependent_string = DependentString(head=splits[0], dependents={}) + order = 0 + + for split in splits[1:]: + if "}}" not in split: + raise ValueError(f"Invalid syntax string placeholder {split} for: {syntax_string} '${{' not closed") + placeholder_content, tail = split.split("}}") + + parts = [p.strip() for p in placeholder_content.split(".")] + if len(parts) != 3 or parts[1] != "outputs": + raise ValueError(f"Invalid syntax string placeholder {placeholder_content} for: {syntax_string}") + + dependent_string.dependents[order] = Dependent(identifier=parts[0], field=parts[2], tail=tail) + order += 1 + + return dependent_string + +def validate_dependencies(next_state_node_template: NodeTemplate, next_state_input_model: Type[BaseModel], identifier: str, parents: dict[str, State]) -> None: + """Validate that all dependencies exist before processing them.""" + # 1) Confirm each model field is present in next_state_node_template.inputs + for field_name in next_state_input_model.model_fields.keys(): + if field_name not in next_state_node_template.inputs: + raise ValueError(f"Field '{field_name}' not found in inputs for template '{next_state_node_template.identifier}'") + + dependency_string = get_dependents(next_state_node_template.inputs[field_name]) + + for dependent in dependency_string.dependents.values(): + # 2) For each placeholder, verify the identifier is either current or present in parents + if dependent.identifier != identifier and dependent.identifier not in parents: + raise KeyError(f"Identifier '{dependent.identifier}' not found in parents for template '{next_state_node_template.identifier}'") + + # 3) For each dependent, verify the target output field exists on the resolved state + if dependent.identifier == identifier: + # This will be resolved to current_state later, skip validation here + continue + else: + parent_state = parents[dependent.identifier] + if dependent.field not in parent_state.outputs: + raise AttributeError(f"Output field '{dependent.field}' not found on state '{dependent.identifier}' for template '{next_state_node_template.identifier}'") + + +async def create_next_states(state_ids: list[PydanticObjectId], identifier: str, namespace: str, graph_name: str, parents_ids: dict[str, PydanticObjectId]): + + try: + if len(state_ids) == 0: + raise ValueError("State ids is empty") + + graph_template = await GraphTemplate.get_valid(namespace, graph_name) + + current_state_node_template = graph_template.get_node_by_identifier(identifier) + if not current_state_node_template: + raise ValueError(f"Current state node template not found for identifier: {identifier}") + + next_state_identifiers = current_state_node_template.next_nodes + if not next_state_identifiers or len(next_state_identifiers) == 0: + await mark_success_states(state_ids) + return + + cached_registered_nodes = {} + cached_input_models = {} + new_states = [] + + async def get_registered_node(node_template: NodeTemplate) -> RegisteredNode: + if node_template.node_name not in cached_registered_nodes: + registered_node = await RegisteredNode.find_one( + RegisteredNode.name == node_template.node_name, + RegisteredNode.namespace == node_template.namespace, + ) + if not registered_node: + raise ValueError(f"Registered node not found for node name: {node_template.node_name} and namespace: {node_template.namespace}") + cached_registered_nodes[node_template.node_name] = registered_node + return cached_registered_nodes[node_template.node_name] + + async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: + if node_template.node_name not in cached_input_models: + cached_input_models[node_template.node_name] = create_model((await get_registered_node(node_template)).inputs_schema) + return cached_input_models[node_template.node_name] + + current_states = await State.find( + In(State.id, state_ids) + ).to_list() + + if not parents_ids: + parent_states = [] + else: + parent_states = await State.find( + In(State.id, list(parents_ids.values())) + ).to_list() + + parents = {} + for parent_state in parent_states: + parents[parent_state.identifier] = parent_state + + + for next_state_identifier in next_state_identifiers: + next_state_node_template = graph_template.get_node_by_identifier(next_state_identifier) + if not next_state_node_template: + raise ValueError(f"Next state node template not found for identifier: {next_state_identifier}") + + if not await check_unites_satisfied(namespace, graph_name, next_state_node_template, parents_ids): + continue + + next_state_input_model = await get_input_model(next_state_node_template) + validate_dependencies(next_state_node_template, next_state_input_model, identifier, parents) + + for current_state in current_states: + next_state_input_data = {} + + for field_name, _ in next_state_input_model.model_fields.items(): + dependency_string = get_dependents(next_state_node_template.inputs[field_name]) + + for key in sorted(dependency_string.dependents.keys()): + if dependency_string.dependents[key].identifier == identifier: + if dependency_string.dependents[key].field not in current_state.outputs: + raise AttributeError(f"Output field '{dependency_string.dependents[key].field}' not found on current state '{identifier}' for template '{next_state_node_template.identifier}'") + dependency_string.dependents[key].value = current_state.outputs[dependency_string.dependents[key].field] + else: + dependency_string.dependents[key].value = parents[dependency_string.dependents[key].identifier].outputs[dependency_string.dependents[key].field] + + next_state_input_data[field_name] = dependency_string.generate_string() + + new_parents = { + **parents_ids, + identifier: current_state.id + } + + new_states.append( + State( + node_name=next_state_node_template.node_name, + identifier=next_state_node_template.identifier, + namespace_name=next_state_node_template.namespace, + graph_name=graph_name, + status=StateStatusEnum.CREATED, + parents=new_parents, + inputs=next_state_input_data, + outputs={}, + run_id=current_state.run_id, + error=None + ) + ) + + await State.insert_many(new_states) + await mark_success_states(state_ids) + + except Exception as e: + await State.find( + In(State.id, state_ids) + ).set({ + "status": StateStatusEnum.NEXT_CREATED_ERROR, + "error": str(e) + }) # type: ignore + raise \ No newline at end of file diff --git a/state-manager/tests/unit/controller/test_executed_state.py b/state-manager/tests/unit/controller/test_executed_state.py index 14936d78..6b234355 100644 --- a/state-manager/tests/unit/controller/test_executed_state.py +++ b/state-manager/tests/unit/controller/test_executed_state.py @@ -47,10 +47,10 @@ def mock_executed_request(self): ) @patch('app.controller.executed_state.State') - @patch('app.controller.executed_state.create_next_state') + @patch('app.controller.executed_state.create_next_states') async def test_executed_state_success_single_output( self, - mock_create_next_state, + mock_create_next_states, mock_state_class, mock_namespace, mock_state_id, @@ -84,13 +84,13 @@ async def test_executed_state_success_single_output( # Assert assert result.status == StateStatusEnum.EXECUTED assert mock_state_class.find_one.call_count == 1 # Called once for finding - mock_background_tasks.add_task.assert_called_once_with(mock_create_next_state, mock_state) + mock_background_tasks.add_task.assert_called_once_with(mock_create_next_states, [mock_state.id], mock_state.identifier, mock_state.namespace_name, mock_state.graph_name, mock_state.parents) @patch('app.controller.executed_state.State') - @patch('app.controller.executed_state.create_next_state') + @patch('app.controller.executed_state.create_next_states') async def test_executed_state_success_multiple_outputs( self, - mock_create_next_state, + mock_create_next_states, mock_state_class, mock_namespace, mock_state_id, @@ -140,9 +140,9 @@ async def test_executed_state_success_multiple_outputs( assert result.status == StateStatusEnum.EXECUTED # Should create 2 additional states (3 outputs total, 1 for main state, 2 new states) assert mock_state_class.call_count == 2 - # Should add 3 background tasks (1 for main state + 2 for new states) - assert mock_background_tasks.add_task.call_count == 3 - # State.find_one should be called multiple times: once for finding, once for updating main state, and twice in the loop + # Should add 1 background task with all state IDs + assert mock_background_tasks.add_task.call_count == 1 + # State.find_one should be called once for finding the state assert mock_state_class.find_one.call_count == 1 @patch('app.controller.executed_state.State') @@ -202,10 +202,10 @@ async def test_executed_state_not_queued( assert exc_info.value.detail == "State is not queued" @patch('app.controller.executed_state.State') - @patch('app.controller.executed_state.create_next_state') + @patch('app.controller.executed_state.create_next_states') async def test_executed_state_empty_outputs( self, - mock_create_next_state, + mock_create_next_states, mock_state_class, mock_namespace, mock_state_id, @@ -239,7 +239,7 @@ async def test_executed_state_empty_outputs( # Assert assert result.status == StateStatusEnum.EXECUTED assert mock_state.outputs == {} - mock_background_tasks.add_task.assert_called_once_with(mock_create_next_state, mock_state) + mock_background_tasks.add_task.assert_called_once_with(mock_create_next_states, [mock_state.id], mock_state.identifier, mock_state.namespace_name, mock_state.graph_name, mock_state.parents) @patch('app.controller.executed_state.State') async def test_executed_state_database_error( @@ -268,10 +268,10 @@ async def test_executed_state_database_error( assert str(exc_info.value) == "Database error" @patch('app.controller.executed_state.State') - @patch('app.controller.executed_state.create_next_state') + @patch('app.controller.executed_state.create_next_states') async def test_executed_state_general_exception_handling( self, - mock_create_next_state, + mock_create_next_states, mock_state_class, mock_namespace, mock_state_id, @@ -297,3 +297,280 @@ async def test_executed_state_general_exception_handling( assert str(exc_info.value) == "Save error" + @patch('app.controller.executed_state.State') + @patch('app.controller.executed_state.create_next_states') + async def test_executed_state_state_id_none( + self, + mock_create_next_states, + mock_state_class, + mock_namespace, + mock_state_id, + mock_executed_request, + mock_background_tasks, + mock_request_id + ): + """Test when state is found but has None ID""" + # Arrange + mock_state = MagicMock() + mock_state.id = None + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await executed_state( + mock_namespace, + mock_state_id, + mock_executed_request, + mock_request_id, + mock_background_tasks + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert exc_info.value.detail == "State not found" + + @patch('app.controller.executed_state.State') + @patch('app.controller.executed_state.create_next_states') + async def test_executed_state_insert_many_partial_failure( + self, + mock_create_next_states, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state, + mock_background_tasks, + mock_request_id + ): + """Test when insert_many returns partial results (this is valid behavior)""" + # Arrange + executed_request = ExecutedRequestModel( + outputs=[ + {"result": "success1"}, + {"result": "success2"}, + {"result": "success3"} + ] + ) + + mock_state_class.find_one = AsyncMock(return_value=mock_state) + mock_state.save = AsyncMock() + + # Mock partial insert - only 1 state inserted instead of 2 (this is valid) + new_ids = [PydanticObjectId()] + mock_state_class.insert_many = AsyncMock(return_value=MagicMock(inserted_ids=new_ids)) + mock_state_class.find = MagicMock(return_value=AsyncMock(to_list=AsyncMock(return_value=[mock_state]))) + + # Act + result = await executed_state( + mock_namespace, + mock_state_id, + executed_request, + mock_request_id, + mock_background_tasks + ) + + # Assert - Should complete successfully with partial results + assert result.status == StateStatusEnum.EXECUTED + + @patch('app.controller.executed_state.State') + @patch('app.controller.executed_state.create_next_states') + async def test_executed_state_insert_many_complete_failure( + self, + mock_create_next_states, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state, + mock_background_tasks, + mock_request_id + ): + """Test when insert_many returns no inserted states (this is valid behavior)""" + # Arrange + executed_request = ExecutedRequestModel( + outputs=[ + {"result": "success1"}, + {"result": "success2"} + ] + ) + + mock_state_class.find_one = AsyncMock(return_value=mock_state) + mock_state.save = AsyncMock() + + # Mock complete insert failure - no states inserted (this is valid) + mock_state_class.insert_many = AsyncMock(return_value=MagicMock(inserted_ids=[])) + mock_state_class.find = MagicMock(return_value=AsyncMock(to_list=AsyncMock(return_value=[]))) + + # Act + result = await executed_state( + mock_namespace, + mock_state_id, + executed_request, + mock_request_id, + mock_background_tasks + ) + + # Assert - Should complete successfully even with no new states + assert result.status == StateStatusEnum.EXECUTED + + @patch('app.controller.executed_state.State') + @patch('app.controller.executed_state.create_next_states') + @patch('app.controller.executed_state.logger') + async def test_executed_state_logging_info_and_error( + self, + mock_logger, + mock_create_next_states, + mock_state_class, + mock_namespace, + mock_state_id, + mock_executed_request, + mock_background_tasks, + mock_request_id + ): + """Test that proper logging occurs during success and error scenarios""" + # Arrange - Success scenario + mock_state = MagicMock() + mock_state.id = PydanticObjectId() + mock_state.status = StateStatusEnum.QUEUED + mock_state.save = AsyncMock() + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + # Act - Success scenario + await executed_state( + mock_namespace, + mock_state_id, + mock_executed_request, + mock_request_id, + mock_background_tasks + ) + + # Assert - Success logging + mock_logger.info.assert_called_once_with( + f"Executed state {mock_state_id} for namespace {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + # Arrange - Error scenario + mock_logger.reset_mock() + mock_state_class.find_one = AsyncMock(side_effect=Exception("Test error")) + + # Act - Error scenario + with pytest.raises(Exception): + await executed_state( + mock_namespace, + mock_state_id, + mock_executed_request, + mock_request_id, + mock_background_tasks + ) + + # Assert - Error logging + mock_logger.error.assert_called_once() + call_args = mock_logger.error.call_args + assert f"Error executing state {mock_state_id} for namespace {mock_namespace}" in str(call_args) + + @patch('app.controller.executed_state.State') + @patch('app.controller.executed_state.create_next_states') + async def test_executed_state_preserves_state_attributes_for_new_states( + self, + mock_create_next_states, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state, + mock_background_tasks, + mock_request_id + ): + """Test that new states preserve all necessary attributes from the original state""" + # Arrange + executed_request = ExecutedRequestModel( + outputs=[ + {"result": "success1"}, + {"result": "success2"} + ] + ) + + # Set up specific state attributes + mock_state.node_name = "test_node" + mock_state.namespace_name = "test_namespace" + mock_state.identifier = "test_identifier" + mock_state.graph_name = "test_graph" + mock_state.run_id = "test_run_id" + mock_state.inputs = {"key": "value"} + mock_state.parents = {"parent1": PydanticObjectId()} + + mock_state_class.find_one = AsyncMock(return_value=mock_state) + mock_state.save = AsyncMock() + + new_ids = [PydanticObjectId()] + mock_state_class.insert_many = AsyncMock(return_value=MagicMock(inserted_ids=new_ids)) + mock_state_class.find = MagicMock(return_value=AsyncMock(to_list=AsyncMock(return_value=[mock_state]))) + + # Act + await executed_state( + mock_namespace, + mock_state_id, + executed_request, + mock_request_id, + mock_background_tasks + ) + + # Assert that State was called with correct parameters for new state creation + state_call = mock_state_class.call_args + assert state_call[1]['node_name'] == mock_state.node_name + assert state_call[1]['namespace_name'] == mock_state.namespace_name + assert state_call[1]['identifier'] == mock_state.identifier + assert state_call[1]['graph_name'] == mock_state.graph_name + assert state_call[1]['run_id'] == mock_state.run_id + assert state_call[1]['inputs'] == mock_state.inputs + assert state_call[1]['parents'] == mock_state.parents + assert state_call[1]['status'] == StateStatusEnum.EXECUTED + assert state_call[1]['outputs'] == {"result": "success2"} + assert state_call[1]['error'] is None + + @patch('app.controller.executed_state.State') + @patch('app.controller.executed_state.create_next_states') + async def test_executed_state_all_status_transitions( + self, + mock_create_next_states, + mock_state_class, + mock_namespace, + mock_state_id, + mock_state, + mock_background_tasks, + mock_request_id + ): + """Test all valid status transitions in executed_state""" + # Test with QUEUED status (valid) + mock_state.status = StateStatusEnum.QUEUED + mock_state_class.find_one = AsyncMock(return_value=mock_state) + mock_state.save = AsyncMock() + + executed_request = ExecutedRequestModel(outputs=[{"result": "success"}]) + + result = await executed_state( + mock_namespace, + mock_state_id, + executed_request, + mock_request_id, + mock_background_tasks + ) + + assert result.status == StateStatusEnum.EXECUTED + assert mock_state.status == StateStatusEnum.EXECUTED + + # Test with invalid statuses + for invalid_status in [StateStatusEnum.CREATED, StateStatusEnum.EXECUTED, + StateStatusEnum.SUCCESS, StateStatusEnum.ERRORED]: + mock_state.status = invalid_status + mock_state_class.find_one = AsyncMock(return_value=mock_state) + + with pytest.raises(HTTPException) as exc_info: + await executed_state( + mock_namespace, + mock_state_id, + executed_request, + mock_request_id, + mock_background_tasks + ) + + assert exc_info.value.status_code == status.HTTP_400_BAD_REQUEST + assert exc_info.value.detail == "State is not queued" + diff --git a/state-manager/tests/unit/controller/test_get_current_states.py b/state-manager/tests/unit/controller/test_get_current_states.py new file mode 100644 index 00000000..b7f7cb81 --- /dev/null +++ b/state-manager/tests/unit/controller/test_get_current_states.py @@ -0,0 +1,255 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.controller.get_current_states import get_current_states +from app.models.db.state import State +from app.models.state_status_enum import StateStatusEnum + + +class TestGetCurrentStates: + """Test cases for get_current_states function""" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_states(self): + """Create mock states for testing""" + states = [] + for i in range(3): + state = MagicMock(spec=State) + state.id = f"state_id_{i}" + state.namespace_name = "test_namespace" + state.status = StateStatusEnum.CREATED + state.identifier = f"node_{i}" + state.graph_name = "test_graph" + state.run_id = f"run_{i}" + states.append(state) + return states + + @patch('app.controller.get_current_states.State') + @patch('app.controller.get_current_states.LogsManager') + async def test_get_current_states_success( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_request_id, + mock_states + ): + """Test successful retrieval of current states""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_states) + mock_state_class.find.return_value = mock_query + + # Act + result = await get_current_states(mock_namespace, mock_request_id) + + # Assert + assert result == mock_states + assert len(result) == 3 + mock_state_class.find.assert_called_once() + mock_query.to_list.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Fetching current states for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Found {len(mock_states)} states for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.get_current_states.State') + @patch('app.controller.get_current_states.LogsManager') + async def test_get_current_states_empty_result( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_request_id + ): + """Test when no states are found""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_state_class.find.return_value = mock_query + + # Act + result = await get_current_states(mock_namespace, mock_request_id) + + # Assert + assert result == [] + assert len(result) == 0 + mock_state_class.find.assert_called_once() + mock_query.to_list.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Fetching current states for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Found 0 states for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.get_current_states.State') + @patch('app.controller.get_current_states.LogsManager') + async def test_get_current_states_database_error( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_request_id + ): + """Test handling of database errors""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(side_effect=Exception("Database connection error")) + mock_state_class.find.return_value = mock_query + + # Act & Assert + with pytest.raises(Exception, match="Database connection error"): + await get_current_states(mock_namespace, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + error_call = mock_logger.error.call_args + assert "Error fetching current states for namespace test_namespace" in str(error_call) + + @patch('app.controller.get_current_states.State') + @patch('app.controller.get_current_states.LogsManager') + async def test_get_current_states_find_error( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_request_id + ): + """Test error during State.find operation""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_state_class.find.side_effect = Exception("Find operation failed") + + # Act & Assert + with pytest.raises(Exception, match="Find operation failed"): + await get_current_states(mock_namespace, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + + @patch('app.controller.get_current_states.State') + @patch('app.controller.get_current_states.LogsManager') + async def test_get_current_states_filter_criteria( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_request_id, + mock_states + ): + """Test that the correct filter criteria are used""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_states) + mock_state_class.find.return_value = mock_query + + # Act + await get_current_states(mock_namespace, mock_request_id) + + # Assert that State.find was called with the correct namespace filter + mock_state_class.find.assert_called_once() + call_args = mock_state_class.find.call_args[0] + # The filter should match the namespace_name + assert len(call_args) == 1 # Should have one filter condition + + @patch('app.controller.get_current_states.State') + @patch('app.controller.get_current_states.LogsManager') + async def test_get_current_states_different_namespaces( + self, + mock_logs_manager, + mock_state_class, + mock_request_id + ): + """Test with different namespace values""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_state_class.find.return_value = mock_query + + namespaces = ["prod", "staging", "dev", "test-123", ""] + + # Act & Assert + for namespace in namespaces: + mock_state_class.reset_mock() + mock_logger.reset_mock() + + result = await get_current_states(namespace, mock_request_id) + + assert result == [] + mock_state_class.find.assert_called_once() + mock_logger.info.assert_any_call( + f"Fetching current states for namespace: {namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.get_current_states.State') + @patch('app.controller.get_current_states.LogsManager') + async def test_get_current_states_large_result_set( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_request_id + ): + """Test with large number of states""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + # Create large number of mock states + large_states_list = [] + for i in range(1000): + state = MagicMock(spec=State) + state.id = f"state_{i}" + state.namespace_name = mock_namespace + large_states_list.append(state) + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=large_states_list) + mock_state_class.find.return_value = mock_query + + # Act + result = await get_current_states(mock_namespace, mock_request_id) + + # Assert + assert len(result) == 1000 + mock_logger.info.assert_any_call( + f"Found 1000 states for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) \ No newline at end of file diff --git a/state-manager/tests/unit/controller/test_get_states_by_run_id.py b/state-manager/tests/unit/controller/test_get_states_by_run_id.py new file mode 100644 index 00000000..c108e6be --- /dev/null +++ b/state-manager/tests/unit/controller/test_get_states_by_run_id.py @@ -0,0 +1,373 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.controller.get_states_by_run_id import get_states_by_run_id +from app.models.db.state import State +from app.models.state_status_enum import StateStatusEnum + + +class TestGetStatesByRunId: + """Test cases for get_states_by_run_id function""" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_run_id(self): + return "test-run-id-123" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_states(self, mock_namespace, mock_run_id): + """Create mock states for testing""" + states = [] + for i in range(4): + state = MagicMock(spec=State) + state.id = f"state_id_{i}" + state.namespace_name = mock_namespace + state.run_id = mock_run_id + state.status = StateStatusEnum.CREATED if i % 2 == 0 else StateStatusEnum.EXECUTED + state.identifier = f"node_{i}" + state.graph_name = "test_graph" + states.append(state) + return states + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_success( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_run_id, + mock_request_id, + mock_states + ): + """Test successful retrieval of states by run ID""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_states) + mock_state_class.find.return_value = mock_query + + # Act + result = await get_states_by_run_id(mock_namespace, mock_run_id, mock_request_id) + + # Assert + assert result == mock_states + assert len(result) == 4 + mock_state_class.find.assert_called_once() + mock_query.to_list.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Fetching states for run ID: {mock_run_id} in namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Found {len(mock_states)} states for run ID: {mock_run_id}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_empty_result( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_run_id, + mock_request_id + ): + """Test when no states are found for the run ID""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_state_class.find.return_value = mock_query + + # Act + result = await get_states_by_run_id(mock_namespace, mock_run_id, mock_request_id) + + # Assert + assert result == [] + assert len(result) == 0 + mock_state_class.find.assert_called_once() + mock_query.to_list.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Fetching states for run ID: {mock_run_id} in namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Found 0 states for run ID: {mock_run_id}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_database_error( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_run_id, + mock_request_id + ): + """Test handling of database errors""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(side_effect=Exception("Database connection error")) + mock_state_class.find.return_value = mock_query + + # Act & Assert + with pytest.raises(Exception, match="Database connection error"): + await get_states_by_run_id(mock_namespace, mock_run_id, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + error_call = mock_logger.error.call_args + assert f"Error fetching states for run ID {mock_run_id} in namespace {mock_namespace}" in str(error_call) + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_find_error( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_run_id, + mock_request_id + ): + """Test error during State.find operation""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_state_class.find.side_effect = Exception("Find operation failed") + + # Act & Assert + with pytest.raises(Exception, match="Find operation failed"): + await get_states_by_run_id(mock_namespace, mock_run_id, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_filter_criteria( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_run_id, + mock_request_id, + mock_states + ): + """Test that the correct filter criteria are used""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_states) + mock_state_class.find.return_value = mock_query + + # Act + await get_states_by_run_id(mock_namespace, mock_run_id, mock_request_id) + + # Assert that State.find was called with correct filters + mock_state_class.find.assert_called_once() + call_args = mock_state_class.find.call_args[0] + # Should have two filter conditions: run_id and namespace_name + assert len(call_args) == 2 + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_different_run_ids( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_request_id + ): + """Test with different run ID values""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_state_class.find.return_value = mock_query + + run_ids = ["run-123", "run-abc-456", "test_run_789", "run-with-special-chars-!@#", ""] + + # Act & Assert + for run_id in run_ids: + mock_state_class.reset_mock() + mock_logger.reset_mock() + + result = await get_states_by_run_id(mock_namespace, run_id, mock_request_id) + + assert result == [] + mock_state_class.find.assert_called_once() + mock_logger.info.assert_any_call( + f"Fetching states for run ID: {run_id} in namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_different_namespaces( + self, + mock_logs_manager, + mock_state_class, + mock_run_id, + mock_request_id + ): + """Test with different namespace values""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_state_class.find.return_value = mock_query + + namespaces = ["prod", "staging", "dev", "test-123", ""] + + # Act & Assert + for namespace in namespaces: + mock_state_class.reset_mock() + mock_logger.reset_mock() + + result = await get_states_by_run_id(namespace, mock_run_id, mock_request_id) + + assert result == [] + mock_state_class.find.assert_called_once() + mock_logger.info.assert_any_call( + f"Fetching states for run ID: {mock_run_id} in namespace: {namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_large_result_set( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_run_id, + mock_request_id + ): + """Test with large number of states""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + # Create large number of mock states + large_states_list = [] + for i in range(1500): + state = MagicMock(spec=State) + state.id = f"state_{i}" + state.namespace_name = mock_namespace + state.run_id = mock_run_id + large_states_list.append(state) + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=large_states_list) + mock_state_class.find.return_value = mock_query + + # Act + result = await get_states_by_run_id(mock_namespace, mock_run_id, mock_request_id) + + # Assert + assert len(result) == 1500 + mock_logger.info.assert_any_call( + f"Found 1500 states for run ID: {mock_run_id}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_return_type( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_run_id, + mock_request_id, + mock_states + ): + """Test that the function returns the correct type""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_states) + mock_state_class.find.return_value = mock_query + + # Act + result = await get_states_by_run_id(mock_namespace, mock_run_id, mock_request_id) + + # Assert + assert isinstance(result, list) + for state in result: + assert isinstance(state, MagicMock) # Since we're using mocks + + # Verify each state has expected attributes + for state in result: + assert hasattr(state, 'id') + assert hasattr(state, 'namespace_name') + assert hasattr(state, 'run_id') + assert state.namespace_name == mock_namespace + assert state.run_id == mock_run_id + + @patch('app.controller.get_states_by_run_id.State') + @patch('app.controller.get_states_by_run_id.LogsManager') + async def test_get_states_by_run_id_single_state( + self, + mock_logs_manager, + mock_state_class, + mock_namespace, + mock_run_id, + mock_request_id + ): + """Test with single state result""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + single_state = MagicMock(spec=State) + single_state.id = "single_state_id" + single_state.namespace_name = mock_namespace + single_state.run_id = mock_run_id + single_state.status = StateStatusEnum.SUCCESS + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[single_state]) + mock_state_class.find.return_value = mock_query + + # Act + result = await get_states_by_run_id(mock_namespace, mock_run_id, mock_request_id) + + # Assert + assert len(result) == 1 + assert result[0] == single_state + mock_logger.info.assert_any_call( + f"Found 1 states for run ID: {mock_run_id}", + x_exosphere_request_id=mock_request_id + ) \ No newline at end of file diff --git a/state-manager/tests/unit/controller/test_list_graph_templates.py b/state-manager/tests/unit/controller/test_list_graph_templates.py new file mode 100644 index 00000000..8a6bdd1d --- /dev/null +++ b/state-manager/tests/unit/controller/test_list_graph_templates.py @@ -0,0 +1,437 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.controller.list_graph_templates import list_graph_templates +from app.models.db.graph_template_model import GraphTemplate +from app.models.graph_template_validation_status import GraphTemplateValidationStatus + + +class TestListGraphTemplates: + """Test cases for list_graph_templates function""" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_graph_templates(self): + """Create mock graph templates for testing""" + templates = [] + for i in range(3): + template = MagicMock(spec=GraphTemplate) + template.id = f"template_id_{i}" + template.name = f"test_template_{i}" + template.namespace = "test_namespace" + template.validation_status = GraphTemplateValidationStatus.VALID if i % 2 == 0 else GraphTemplateValidationStatus.INVALID + template.validation_errors = [] if i % 2 == 0 else [f"Error {i}"] + template.nodes = [] + template.secrets = {} + templates.append(template) + return templates + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_success( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id, + mock_graph_templates + ): + """Test successful retrieval of graph templates""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_graph_templates) + mock_graph_template_class.find.return_value = mock_query + + # Act + result = await list_graph_templates(mock_namespace, mock_request_id) + + # Assert + assert result == mock_graph_templates + assert len(result) == 3 + mock_graph_template_class.find.assert_called_once() + mock_query.to_list.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Listing graph templates for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Found {len(mock_graph_templates)} graph templates for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_empty_result( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id + ): + """Test when no graph templates are found""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_graph_template_class.find.return_value = mock_query + + # Act + result = await list_graph_templates(mock_namespace, mock_request_id) + + # Assert + assert result == [] + assert len(result) == 0 + mock_graph_template_class.find.assert_called_once() + mock_query.to_list.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Listing graph templates for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Found 0 graph templates for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_database_error( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id + ): + """Test handling of database errors""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(side_effect=Exception("Database connection error")) + mock_graph_template_class.find.return_value = mock_query + + # Act & Assert + with pytest.raises(Exception, match="Database connection error"): + await list_graph_templates(mock_namespace, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + error_call = mock_logger.error.call_args + assert f"Error listing graph templates for namespace {mock_namespace}" in str(error_call) + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_find_error( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id + ): + """Test error during GraphTemplate.find operation""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_graph_template_class.find.side_effect = Exception("Find operation failed") + + # Act & Assert + with pytest.raises(Exception, match="Find operation failed"): + await list_graph_templates(mock_namespace, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_filter_criteria( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id, + mock_graph_templates + ): + """Test that the correct filter criteria are used""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_graph_templates) + mock_graph_template_class.find.return_value = mock_query + + # Act + await list_graph_templates(mock_namespace, mock_request_id) + + # Assert that GraphTemplate.find was called with the correct namespace filter + mock_graph_template_class.find.assert_called_once() + call_args = mock_graph_template_class.find.call_args[0] + # The filter should match the namespace + assert len(call_args) == 1 # Should have one filter condition + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_different_namespaces( + self, + mock_logs_manager, + mock_graph_template_class, + mock_request_id + ): + """Test with different namespace values""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_graph_template_class.find.return_value = mock_query + + namespaces = ["prod", "staging", "dev", "test-123", ""] + + # Act & Assert + for namespace in namespaces: + mock_graph_template_class.reset_mock() + mock_logger.reset_mock() + + result = await list_graph_templates(namespace, mock_request_id) + + assert result == [] + mock_graph_template_class.find.assert_called_once() + mock_logger.info.assert_any_call( + f"Listing graph templates for namespace: {namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_large_result_set( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id + ): + """Test with large number of graph templates""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + # Create large number of mock templates + large_templates_list = [] + for i in range(200): + template = MagicMock(spec=GraphTemplate) + template.id = f"template_{i}" + template.name = f"template_{i}" + template.namespace = mock_namespace + large_templates_list.append(template) + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=large_templates_list) + mock_graph_template_class.find.return_value = mock_query + + # Act + result = await list_graph_templates(mock_namespace, mock_request_id) + + # Assert + assert len(result) == 200 + mock_logger.info.assert_any_call( + f"Found 200 graph templates for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_return_type( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id, + mock_graph_templates + ): + """Test that the function returns the correct type""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_graph_templates) + mock_graph_template_class.find.return_value = mock_query + + # Act + result = await list_graph_templates(mock_namespace, mock_request_id) + + # Assert + assert isinstance(result, list) + for template in result: + assert isinstance(template, MagicMock) # Since we're using mocks + + # Verify each template has expected attributes (via mock) + for template in result: + assert hasattr(template, 'id') + assert hasattr(template, 'name') + assert hasattr(template, 'namespace') + assert hasattr(template, 'validation_status') + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_mixed_validation_statuses( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id + ): + """Test with templates having different validation statuses""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + # Create templates with different validation statuses + templates = [] + statuses = [GraphTemplateValidationStatus.VALID, + GraphTemplateValidationStatus.INVALID, + GraphTemplateValidationStatus.PENDING] + + for i, status in enumerate(statuses): + template = MagicMock(spec=GraphTemplate) + template.id = f"template_{i}" + template.name = f"template_{i}" + template.namespace = mock_namespace + template.validation_status = status + templates.append(template) + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=templates) + mock_graph_template_class.find.return_value = mock_query + + # Act + result = await list_graph_templates(mock_namespace, mock_request_id) + + # Assert + assert len(result) == 3 + assert result[0].validation_status == GraphTemplateValidationStatus.VALID + assert result[1].validation_status == GraphTemplateValidationStatus.INVALID + assert result[2].validation_status == GraphTemplateValidationStatus.PENDING + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_concurrent_requests( + self, + mock_logs_manager, + mock_graph_template_class, + mock_request_id + ): + """Test handling concurrent requests with different namespaces""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_graph_template_class.find.return_value = mock_query + + # Simulate concurrent requests to different namespaces + namespaces = ["namespace1", "namespace2", "namespace3"] + + # Act + import asyncio + tasks = [list_graph_templates(ns, f"{mock_request_id}_{i}") for i, ns in enumerate(namespaces)] + results = await asyncio.gather(*tasks) + + # Assert + assert len(results) == 3 + for result in results: + assert result == [] + + # Each namespace should have been queried + assert mock_graph_template_class.find.call_count == 3 + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_single_template( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id + ): + """Test with single template result""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + single_template = MagicMock(spec=GraphTemplate) + single_template.id = "single_template_id" + single_template.name = "single_template" + single_template.namespace = mock_namespace + single_template.validation_status = GraphTemplateValidationStatus.VALID + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[single_template]) + mock_graph_template_class.find.return_value = mock_query + + # Act + result = await list_graph_templates(mock_namespace, mock_request_id) + + # Assert + assert len(result) == 1 + assert result[0] == single_template + mock_logger.info.assert_any_call( + f"Found 1 graph templates for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.list_graph_templates.GraphTemplate') + @patch('app.controller.list_graph_templates.LogsManager') + async def test_list_graph_templates_with_complex_templates( + self, + mock_logs_manager, + mock_graph_template_class, + mock_namespace, + mock_request_id + ): + """Test with complex graph templates containing nodes and secrets""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + complex_template = MagicMock(spec=GraphTemplate) + complex_template.id = "complex_template" + complex_template.name = "complex_template" + complex_template.namespace = mock_namespace + complex_template.validation_status = GraphTemplateValidationStatus.VALID + complex_template.nodes = [MagicMock() for _ in range(5)] # Mock 5 nodes + complex_template.secrets = {"secret1": "value1", "secret2": "value2"} + complex_template.validation_errors = None + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[complex_template]) + mock_graph_template_class.find.return_value = mock_query + + # Act + result = await list_graph_templates(mock_namespace, mock_request_id) + + # Assert + assert len(result) == 1 + template = result[0] + assert template == complex_template + assert len(template.nodes) == 5 + assert len(template.secrets) == 2 \ No newline at end of file diff --git a/state-manager/tests/unit/controller/test_list_registered_nodes.py b/state-manager/tests/unit/controller/test_list_registered_nodes.py new file mode 100644 index 00000000..ae287924 --- /dev/null +++ b/state-manager/tests/unit/controller/test_list_registered_nodes.py @@ -0,0 +1,323 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.controller.list_registered_nodes import list_registered_nodes +from app.models.db.registered_node import RegisteredNode + + +class TestListRegisteredNodes: + """Test cases for list_registered_nodes function""" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_registered_nodes(self): + """Create mock registered nodes for testing""" + nodes = [] + for i in range(3): + node = MagicMock(spec=RegisteredNode) + node.id = f"node_id_{i}" + node.name = f"test_node_{i}" + node.namespace = "test_namespace" + node.runtime_name = f"runtime_{i}" + node.runtime_namespace = "test_namespace" + node.inputs_schema = {"type": "object", "properties": {"input": {"type": "string"}}} + node.outputs_schema = {"type": "object", "properties": {"output": {"type": "string"}}} + node.secrets = ["secret1", "secret2"] + nodes.append(node) + return nodes + + @patch('app.controller.list_registered_nodes.RegisteredNode') + @patch('app.controller.list_registered_nodes.LogsManager') + async def test_list_registered_nodes_success( + self, + mock_logs_manager, + mock_registered_node_class, + mock_namespace, + mock_request_id, + mock_registered_nodes + ): + """Test successful retrieval of registered nodes""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_registered_nodes) + mock_registered_node_class.find.return_value = mock_query + + # Act + result = await list_registered_nodes(mock_namespace, mock_request_id) + + # Assert + assert result == mock_registered_nodes + assert len(result) == 3 + mock_registered_node_class.find.assert_called_once() + mock_query.to_list.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Listing registered nodes for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Found {len(mock_registered_nodes)} registered nodes for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.list_registered_nodes.RegisteredNode') + @patch('app.controller.list_registered_nodes.LogsManager') + async def test_list_registered_nodes_empty_result( + self, + mock_logs_manager, + mock_registered_node_class, + mock_namespace, + mock_request_id + ): + """Test when no registered nodes are found""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_registered_node_class.find.return_value = mock_query + + # Act + result = await list_registered_nodes(mock_namespace, mock_request_id) + + # Assert + assert result == [] + assert len(result) == 0 + mock_registered_node_class.find.assert_called_once() + mock_query.to_list.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Listing registered nodes for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Found 0 registered nodes for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.list_registered_nodes.RegisteredNode') + @patch('app.controller.list_registered_nodes.LogsManager') + async def test_list_registered_nodes_database_error( + self, + mock_logs_manager, + mock_registered_node_class, + mock_namespace, + mock_request_id + ): + """Test handling of database errors""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(side_effect=Exception("Database connection error")) + mock_registered_node_class.find.return_value = mock_query + + # Act & Assert + with pytest.raises(Exception, match="Database connection error"): + await list_registered_nodes(mock_namespace, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + error_call = mock_logger.error.call_args + assert "Error listing registered nodes for namespace test_namespace" in str(error_call) + + @patch('app.controller.list_registered_nodes.RegisteredNode') + @patch('app.controller.list_registered_nodes.LogsManager') + async def test_list_registered_nodes_find_error( + self, + mock_logs_manager, + mock_registered_node_class, + mock_namespace, + mock_request_id + ): + """Test error during RegisteredNode.find operation""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_registered_node_class.find.side_effect = Exception("Find operation failed") + + # Act & Assert + with pytest.raises(Exception, match="Find operation failed"): + await list_registered_nodes(mock_namespace, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + + @patch('app.controller.list_registered_nodes.RegisteredNode') + @patch('app.controller.list_registered_nodes.LogsManager') + async def test_list_registered_nodes_filter_criteria( + self, + mock_logs_manager, + mock_registered_node_class, + mock_namespace, + mock_request_id, + mock_registered_nodes + ): + """Test that the correct filter criteria are used""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_registered_nodes) + mock_registered_node_class.find.return_value = mock_query + + # Act + await list_registered_nodes(mock_namespace, mock_request_id) + + # Assert that RegisteredNode.find was called with the correct namespace filter + mock_registered_node_class.find.assert_called_once() + call_args = mock_registered_node_class.find.call_args[0] + # The filter should match the namespace + assert len(call_args) == 1 # Should have one filter condition + + @patch('app.controller.list_registered_nodes.RegisteredNode') + @patch('app.controller.list_registered_nodes.LogsManager') + async def test_list_registered_nodes_different_namespaces( + self, + mock_logs_manager, + mock_registered_node_class, + mock_request_id + ): + """Test with different namespace values""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_registered_node_class.find.return_value = mock_query + + namespaces = ["prod", "staging", "dev", "test-123", ""] + + # Act & Assert + for namespace in namespaces: + mock_registered_node_class.reset_mock() + mock_logger.reset_mock() + + result = await list_registered_nodes(namespace, mock_request_id) + + assert result == [] + mock_registered_node_class.find.assert_called_once() + mock_logger.info.assert_any_call( + f"Listing registered nodes for namespace: {namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.list_registered_nodes.RegisteredNode') + @patch('app.controller.list_registered_nodes.LogsManager') + async def test_list_registered_nodes_large_result_set( + self, + mock_logs_manager, + mock_registered_node_class, + mock_namespace, + mock_request_id + ): + """Test with large number of registered nodes""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + # Create large number of mock nodes + large_nodes_list = [] + for i in range(500): + node = MagicMock(spec=RegisteredNode) + node.id = f"node_{i}" + node.name = f"node_{i}" + node.namespace = mock_namespace + large_nodes_list.append(node) + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=large_nodes_list) + mock_registered_node_class.find.return_value = mock_query + + # Act + result = await list_registered_nodes(mock_namespace, mock_request_id) + + # Assert + assert len(result) == 500 + mock_logger.info.assert_any_call( + f"Found 500 registered nodes for namespace: {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.list_registered_nodes.RegisteredNode') + @patch('app.controller.list_registered_nodes.LogsManager') + async def test_list_registered_nodes_return_type( + self, + mock_logs_manager, + mock_registered_node_class, + mock_namespace, + mock_request_id, + mock_registered_nodes + ): + """Test that the function returns the correct type""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=mock_registered_nodes) + mock_registered_node_class.find.return_value = mock_query + + # Act + result = await list_registered_nodes(mock_namespace, mock_request_id) + + # Assert + assert isinstance(result, list) + for node in result: + assert isinstance(node, MagicMock) # Since we're using mocks + + # Verify each node has expected attributes (via mock) + for node in result: + assert hasattr(node, 'id') + assert hasattr(node, 'name') + assert hasattr(node, 'namespace') + + @patch('app.controller.list_registered_nodes.RegisteredNode') + @patch('app.controller.list_registered_nodes.LogsManager') + async def test_list_registered_nodes_concurrent_requests( + self, + mock_logs_manager, + mock_registered_node_class, + mock_request_id + ): + """Test handling concurrent requests with different namespaces""" + # Arrange + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_query = MagicMock() + mock_query.to_list = AsyncMock(return_value=[]) + mock_registered_node_class.find.return_value = mock_query + + # Simulate concurrent requests to different namespaces + namespaces = ["namespace1", "namespace2", "namespace3"] + + # Act + import asyncio + tasks = [list_registered_nodes(ns, f"{mock_request_id}_{i}") for i, ns in enumerate(namespaces)] + results = await asyncio.gather(*tasks) + + # Assert + assert len(results) == 3 + for result in results: + assert result == [] + + # Each namespace should have been queried + assert mock_registered_node_class.find.call_count == 3 \ No newline at end of file diff --git a/state-manager/tests/unit/controller/test_register_nodes.py b/state-manager/tests/unit/controller/test_register_nodes.py new file mode 100644 index 00000000..e26a371c --- /dev/null +++ b/state-manager/tests/unit/controller/test_register_nodes.py @@ -0,0 +1,435 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from beanie.operators import Set + +from app.controller.register_nodes import register_nodes +from app.models.register_nodes_request import RegisterNodesRequestModel, NodeRegistrationModel +from app.models.register_nodes_response import RegisterNodesResponseModel, RegisteredNodeModel + + +class TestRegisterNodes: + """Test cases for register_nodes function""" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_runtime_name(self): + return "test-runtime" + + @pytest.fixture + def mock_node_registration(self): + """Create mock node registration data""" + return NodeRegistrationModel( + name="test_node", + inputs_schema={"type": "object", "properties": {"input": {"type": "string"}}}, + outputs_schema={"type": "object", "properties": {"output": {"type": "string"}}}, + secrets=["secret1", "secret2"] + ) + + @pytest.fixture + def mock_multiple_node_registrations(self): + """Create multiple mock node registration data""" + nodes = [] + for i in range(3): + node = NodeRegistrationModel( + name=f"test_node_{i}", + inputs_schema={"type": "object", "properties": {"input": {"type": "string"}}}, + outputs_schema={"type": "object", "properties": {"output": {"type": "string"}}}, + secrets=[f"secret{i}_1", f"secret{i}_2"] + ) + nodes.append(node) + return nodes + + @pytest.fixture + def mock_register_request(self, mock_runtime_name, mock_node_registration): + """Create mock register nodes request""" + return RegisterNodesRequestModel( + runtime_name=mock_runtime_name, + nodes=[mock_node_registration] + ) + + @pytest.fixture + def mock_multiple_register_request(self, mock_runtime_name, mock_multiple_node_registrations): + """Create mock register nodes request with multiple nodes""" + return RegisterNodesRequestModel( + runtime_name=mock_runtime_name, + nodes=mock_multiple_node_registrations + ) + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_create_new_node_success( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_register_request, + mock_request_id + ): + """Test successful creation of new node""" + # Arrange + # No existing node found + mock_registered_node_class.find_one = AsyncMock(return_value=None) + + # Mock new node creation + mock_new_node = MagicMock() + mock_new_node.insert = AsyncMock() + mock_registered_node_class.return_value = mock_new_node + + # Act + result = await register_nodes(mock_namespace, mock_register_request, mock_request_id) + + # Assert + assert isinstance(result, RegisterNodesResponseModel) + assert result.runtime_name == mock_register_request.runtime_name + assert len(result.registered_nodes) == 1 + + registered_node = result.registered_nodes[0] + assert registered_node.name == "test_node" + assert registered_node.inputs_schema == mock_register_request.nodes[0].inputs_schema + assert registered_node.outputs_schema == mock_register_request.nodes[0].outputs_schema + assert registered_node.secrets == mock_register_request.nodes[0].secrets + + # Verify database operations + mock_registered_node_class.find_one.assert_called_once() + mock_registered_node_class.assert_called_once() + mock_new_node.insert.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Registering nodes for namespace {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Created new node test_node in namespace {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_update_existing_node_success( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_register_request, + mock_request_id + ): + """Test successful update of existing node""" + # Arrange + # Mock existing node + mock_existing_node = MagicMock() + mock_existing_node.update = AsyncMock() + mock_registered_node_class.find_one = AsyncMock(return_value=mock_existing_node) + + # Act + result = await register_nodes(mock_namespace, mock_register_request, mock_request_id) + + # Assert + assert isinstance(result, RegisterNodesResponseModel) + assert result.runtime_name == mock_register_request.runtime_name + assert len(result.registered_nodes) == 1 + + # Verify database operations + mock_registered_node_class.find_one.assert_called_once() + mock_existing_node.update.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Updated existing node test_node in namespace {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_multiple_nodes_mixed_operations( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_multiple_register_request, + mock_request_id + ): + """Test registering multiple nodes with mixed create/update operations""" + # Arrange + # First node exists, second and third are new + mock_existing_node = MagicMock() + mock_existing_node.update = AsyncMock() + + mock_new_node_1 = MagicMock() + mock_new_node_1.insert = AsyncMock() + mock_new_node_2 = MagicMock() + mock_new_node_2.insert = AsyncMock() + + # Mock find_one to return existing for first call, None for others + mock_registered_node_class.find_one = AsyncMock(side_effect=[mock_existing_node, None, None]) + mock_registered_node_class.side_effect = [mock_new_node_1, mock_new_node_2] + + # Act + result = await register_nodes(mock_namespace, mock_multiple_register_request, mock_request_id) + + # Assert + assert isinstance(result, RegisterNodesResponseModel) + assert result.runtime_name == mock_multiple_register_request.runtime_name + assert len(result.registered_nodes) == 3 + + # Verify database operations + assert mock_registered_node_class.find_one.call_count == 3 + mock_existing_node.update.assert_called_once() + mock_new_node_1.insert.assert_called_once() + mock_new_node_2.insert.assert_called_once() + + # Verify logging + mock_logger.info.assert_any_call( + f"Updated existing node test_node_0 in namespace {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + f"Created new node test_node_1 in namespace {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_database_error_during_find( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_register_request, + mock_request_id + ): + """Test error handling during database find operation""" + # Arrange + mock_registered_node_class.find_one = AsyncMock(side_effect=Exception("Database error")) + + # Act & Assert + with pytest.raises(Exception, match="Database error"): + await register_nodes(mock_namespace, mock_register_request, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + error_call = mock_logger.error.call_args + assert f"Error registering nodes for namespace {mock_namespace}" in str(error_call) + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_database_error_during_update( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_register_request, + mock_request_id + ): + """Test error handling during database update operation""" + # Arrange + mock_existing_node = MagicMock() + mock_existing_node.update = AsyncMock(side_effect=Exception("Update failed")) + mock_registered_node_class.find_one = AsyncMock(return_value=mock_existing_node) + + # Act & Assert + with pytest.raises(Exception, match="Update failed"): + await register_nodes(mock_namespace, mock_register_request, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_database_error_during_insert( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_register_request, + mock_request_id + ): + """Test error handling during database insert operation""" + # Arrange + mock_registered_node_class.find_one = AsyncMock(return_value=None) + mock_new_node = MagicMock() + mock_new_node.insert = AsyncMock(side_effect=Exception("Insert failed")) + mock_registered_node_class.return_value = mock_new_node + + # Act & Assert + with pytest.raises(Exception, match="Insert failed"): + await register_nodes(mock_namespace, mock_register_request, mock_request_id) + + # Verify error logging + mock_logger.error.assert_called_once() + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_empty_node_list( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_runtime_name, + mock_request_id + ): + """Test registering with empty node list""" + # Arrange + empty_request = RegisterNodesRequestModel( + runtime_name=mock_runtime_name, + nodes=[] + ) + + # Act + result = await register_nodes(mock_namespace, empty_request, mock_request_id) + + # Assert + assert isinstance(result, RegisterNodesResponseModel) + assert result.runtime_name == mock_runtime_name + assert len(result.registered_nodes) == 0 + + # Verify no database operations were performed + mock_registered_node_class.find_one.assert_not_called() + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_update_fields_verification( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_register_request, + mock_request_id + ): + """Test that update operation includes all required fields""" + # Arrange + mock_existing_node = MagicMock() + mock_existing_node.update = AsyncMock() + mock_registered_node_class.find_one = AsyncMock(return_value=mock_existing_node) + + # Act + await register_nodes(mock_namespace, mock_register_request, mock_request_id) + + # Assert - Verify update was called with correct fields + mock_existing_node.update.assert_called_once() + update_call_args = mock_existing_node.update.call_args[0][0] + + # The update method is called with a Set object, not a dict + # We can't easily inspect the Set object contents, so just verify it was called + assert isinstance(update_call_args, type(Set({}))) + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_new_node_fields_verification( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_register_request, + mock_request_id + ): + """Test that new node creation includes all required fields""" + # Arrange + mock_registered_node_class.find_one = AsyncMock(return_value=None) + mock_new_node = MagicMock() + mock_new_node.insert = AsyncMock() + mock_registered_node_class.return_value = mock_new_node + + # Act + await register_nodes(mock_namespace, mock_register_request, mock_request_id) + + # Assert - Verify new node was created with correct fields + mock_registered_node_class.assert_called_once() + create_call_args = mock_registered_node_class.call_args[1] + + expected_fields = { + 'name': mock_register_request.nodes[0].name, + 'namespace': mock_namespace, + 'runtime_name': mock_register_request.runtime_name, + 'runtime_namespace': mock_namespace, + 'inputs_schema': mock_register_request.nodes[0].inputs_schema, + 'outputs_schema': mock_register_request.nodes[0].outputs_schema, + 'secrets': mock_register_request.nodes[0].secrets + } + + for field, expected_value in expected_fields.items(): + assert create_call_args[field] == expected_value + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_response_structure_verification( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_multiple_register_request, + mock_request_id + ): + """Test that response structure is correct""" + # Arrange + mock_registered_node_class.find_one = AsyncMock(return_value=None) + mock_new_node = MagicMock() + mock_new_node.insert = AsyncMock() + mock_registered_node_class.return_value = mock_new_node + + # Act + result = await register_nodes(mock_namespace, mock_multiple_register_request, mock_request_id) + + # Assert + assert isinstance(result, RegisterNodesResponseModel) + assert result.runtime_name == mock_multiple_register_request.runtime_name + assert isinstance(result.registered_nodes, list) + assert len(result.registered_nodes) == len(mock_multiple_register_request.nodes) + + for i, registered_node in enumerate(result.registered_nodes): + assert isinstance(registered_node, RegisteredNodeModel) + original_node = mock_multiple_register_request.nodes[i] + assert registered_node.name == original_node.name + assert registered_node.inputs_schema == original_node.inputs_schema + assert registered_node.outputs_schema == original_node.outputs_schema + assert registered_node.secrets == original_node.secrets + + @patch('app.controller.register_nodes.RegisteredNode') + @patch('app.controller.register_nodes.logger') + async def test_register_nodes_success_logging( + self, + mock_logger, + mock_registered_node_class, + mock_namespace, + mock_multiple_register_request, + mock_request_id + ): + """Test comprehensive logging for successful operations""" + # Arrange + mock_registered_node_class.find_one = AsyncMock(return_value=None) + mock_new_node = MagicMock() + mock_new_node.insert = AsyncMock() + mock_registered_node_class.return_value = mock_new_node + + # Act + result = await register_nodes(mock_namespace, mock_multiple_register_request, mock_request_id) + + # Assert logging calls + expected_log_calls = [ + f"Registering nodes for namespace {mock_namespace}", + f"Successfully registered {len(result.registered_nodes)} nodes for namespace {mock_namespace}" + ] + + # Verify initial and final logging + mock_logger.info.assert_any_call( + expected_log_calls[0], + x_exosphere_request_id=mock_request_id + ) + mock_logger.info.assert_any_call( + expected_log_calls[1], + x_exosphere_request_id=mock_request_id + ) + + # Verify per-node logging + for node in mock_multiple_register_request.nodes: + mock_logger.info.assert_any_call( + f"Created new node {node.name} in namespace {mock_namespace}", + x_exosphere_request_id=mock_request_id + ) \ No newline at end of file diff --git a/state-manager/tests/unit/middlewares/__init__.py b/state-manager/tests/unit/middlewares/__init__.py new file mode 100644 index 00000000..15b4c15e --- /dev/null +++ b/state-manager/tests/unit/middlewares/__init__.py @@ -0,0 +1 @@ +# Unit tests for middlewares package \ No newline at end of file diff --git a/state-manager/tests/unit/middlewares/test_request_id_middleware.py b/state-manager/tests/unit/middlewares/test_request_id_middleware.py new file mode 100644 index 00000000..446c7076 --- /dev/null +++ b/state-manager/tests/unit/middlewares/test_request_id_middleware.py @@ -0,0 +1,377 @@ +import uuid +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from starlette.requests import Request +from starlette.responses import Response + +from app.middlewares.request_id_middleware import RequestIdMiddleware + + +class TestRequestIdMiddleware: + """Test cases for RequestIdMiddleware""" + + def setup_method(self): + """Set up test fixtures before each test""" + self.middleware = RequestIdMiddleware(app=MagicMock()) + + @pytest.mark.asyncio + async def test_dispatch_with_valid_request_id_header(self): + """Test dispatch with valid UUID in x-exosphere-request-id header""" + # Setup + valid_uuid = str(uuid.uuid4()) + mock_request = MagicMock(spec=Request) + mock_request.headers = {"x-exosphere-request-id": valid_uuid} + mock_request.method = "GET" + mock_request.url.path = "/test" + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 200 + mock_response.headers = {} + + mock_call_next = AsyncMock(return_value=mock_response) + + # Mock time.time for consistent timing + with patch('time.time', side_effect=[1000.0, 1000.5]): # 500ms duration + with patch('app.middlewares.request_id_middleware.logger') as mock_logger: + result = await self.middleware.dispatch(mock_request, mock_call_next) + + # Assertions + assert mock_request.state.x_exosphere_request_id == valid_uuid + assert mock_response.headers["x-exosphere-request-id"] == valid_uuid + assert result == mock_response + + # Check logging calls + assert mock_logger.info.call_count == 2 + + # First log call - request received + first_call_args = mock_logger.info.call_args_list[0] + assert first_call_args[0][0] == "request received" + assert first_call_args[1]["x_exosphere_request_id"] == valid_uuid + assert first_call_args[1]["method"] == "GET" + assert first_call_args[1]["url"] == "/test" + + # Second log call - request processed + second_call_args = mock_logger.info.call_args_list[1] + assert second_call_args[0][0] == "request processed" + assert second_call_args[1]["x_exosphere_request_id"] == valid_uuid + assert second_call_args[1]["response_time"] == 500.0 # 500ms + assert second_call_args[1]["status_code"] == 200 + + @pytest.mark.asyncio + async def test_dispatch_without_request_id_header_generates_new_uuid(self): + """Test dispatch generates new UUID when no x-exosphere-request-id header""" + # Setup + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + mock_request.method = "POST" + mock_request.url.path = "/api/test" + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 201 + mock_response.headers = {} + + mock_call_next = AsyncMock(return_value=mock_response) + + with patch('time.time', side_effect=[2000.0, 2000.1]): # 100ms duration + with patch('app.middlewares.request_id_middleware.logger') as mock_logger: + result = await self.middleware.dispatch(mock_request, mock_call_next) + + # Assertions + generated_uuid = mock_request.state.x_exosphere_request_id + assert generated_uuid is not None + + # Verify it's a valid UUID + uuid.UUID(generated_uuid) # Should not raise exception + + assert mock_response.headers["x-exosphere-request-id"] == generated_uuid + assert result == mock_response + + # Check logging + assert mock_logger.info.call_count == 2 + first_call_args = mock_logger.info.call_args_list[0] + assert first_call_args[1]["x_exosphere_request_id"] == generated_uuid + assert first_call_args[1]["method"] == "POST" + assert first_call_args[1]["url"] == "/api/test" + + @pytest.mark.asyncio + async def test_dispatch_with_invalid_uuid_generates_new_uuid(self): + """Test dispatch generates new UUID when x-exosphere-request-id is invalid""" + # Setup + mock_request = MagicMock(spec=Request) + mock_request.headers = {"x-exosphere-request-id": "invalid-uuid"} + mock_request.method = "PUT" + mock_request.url.path = "/api/update" + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 200 + mock_response.headers = {} + + mock_call_next = AsyncMock(return_value=mock_response) + + with patch('time.time', side_effect=[3000.0, 3001.0]): # 1000ms duration + with patch('app.middlewares.request_id_middleware.logger'): + await self.middleware.dispatch(mock_request, mock_call_next) + + # Assertions + generated_uuid = mock_request.state.x_exosphere_request_id + assert generated_uuid != "invalid-uuid" + + # Verify it's a valid UUID + uuid.UUID(generated_uuid) # Should not raise exception + + assert mock_response.headers["x-exosphere-request-id"] == generated_uuid + + @pytest.mark.asyncio + async def test_dispatch_with_malformed_uuid_generates_new_uuid(self): + """Test dispatch generates new UUID when x-exosphere-request-id is malformed""" + test_cases = [ + "12345", # Too short + "not-a-uuid-at-all", # Not UUID format + "123e4567-e89b-12d3-a456-42661419", # Missing last part + "123e4567-e89b-12d3-a456-426614174000-extra", # Too long + "", # Empty string + " ", # Whitespace only + ] + + for invalid_uuid in test_cases: + mock_request = MagicMock(spec=Request) + mock_request.headers = {"x-exosphere-request-id": invalid_uuid} + mock_request.method = "GET" + mock_request.url.path = "/test" + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 200 + mock_response.headers = {} + + mock_call_next = AsyncMock(return_value=mock_response) + + with patch('time.time', side_effect=[1000.0, 1000.1]): + with patch('app.middlewares.request_id_middleware.logger'): + await self.middleware.dispatch(mock_request, mock_call_next) + + # Should have generated a new valid UUID + generated_uuid = mock_request.state.x_exosphere_request_id + assert generated_uuid != invalid_uuid + uuid.UUID(generated_uuid) # Should not raise exception + + @pytest.mark.asyncio + async def test_dispatch_response_time_calculation(self): + """Test that response time is calculated correctly in milliseconds""" + # Setup + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + mock_request.method = "GET" + mock_request.url.path = "/test" + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 200 + mock_response.headers = {} + + mock_call_next = AsyncMock(return_value=mock_response) + + # Test different time durations + test_cases = [ + (1000.0, 1000.0, 0.0), # 0ms + (1000.0, 1000.1, 100.0), # 100ms + (1000.0, 1001.0, 1000.0), # 1000ms (1 second) + (1000.0, 1002.5, 2500.0), # 2500ms (2.5 seconds) + ] + + for start_time, end_time, expected_ms in test_cases: + with patch('time.time', side_effect=[start_time, end_time]): + with patch('app.middlewares.request_id_middleware.logger') as mock_logger: + await self.middleware.dispatch(mock_request, mock_call_next) + + # Check the response time in the second log call + second_call_args = mock_logger.info.call_args_list[1] + assert abs(second_call_args[1]["response_time"] - expected_ms) < 0.1 + + @pytest.mark.asyncio + async def test_dispatch_preserves_response_properties(self): + """Test that dispatch preserves all response properties""" + # Setup + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + mock_request.method = "GET" + mock_request.url.path = "/test" + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 404 + mock_response.headers = {"Content-Type": "application/json", "Custom-Header": "custom-value"} + + mock_call_next = AsyncMock(return_value=mock_response) + + with patch('time.time', side_effect=[1000.0, 1000.1]): + with patch('app.middlewares.request_id_middleware.logger'): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + # Should preserve all response properties and add request ID header + assert result == mock_response + assert result.status_code == 404 + assert result.headers["Content-Type"] == "application/json" + assert result.headers["Custom-Header"] == "custom-value" + assert "x-exosphere-request-id" in result.headers + + @pytest.mark.asyncio + async def test_dispatch_logs_different_request_methods_and_paths(self): + """Test that dispatch logs different HTTP methods and paths correctly""" + test_cases = [ + ("GET", "/api/users"), + ("POST", "/api/users"), + ("PUT", "/api/users/123"), + ("DELETE", "/api/users/123"), + ("PATCH", "/api/users/123"), + ("HEAD", "/health"), + ("OPTIONS", "/api/cors"), + ] + + for method, path in test_cases: + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + mock_request.method = method + mock_request.url.path = path + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 200 + mock_response.headers = {} + + mock_call_next = AsyncMock(return_value=mock_response) + + with patch('time.time', side_effect=[1000.0, 1000.1]): + with patch('app.middlewares.request_id_middleware.logger') as mock_logger: + await self.middleware.dispatch(mock_request, mock_call_next) + + # Check first log call contains correct method and URL + first_call_args = mock_logger.info.call_args_list[0] + assert first_call_args[1]["method"] == method + assert first_call_args[1]["url"] == path + + @pytest.mark.asyncio + async def test_dispatch_logs_different_response_status_codes(self): + """Test that dispatch logs different response status codes correctly""" + status_codes = [200, 201, 400, 401, 404, 500, 502, 503] + + for status_code in status_codes: + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + mock_request.method = "GET" + mock_request.url.path = "/test" + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = status_code + mock_response.headers = {} + + mock_call_next = AsyncMock(return_value=mock_response) + + with patch('time.time', side_effect=[1000.0, 1000.1]): + with patch('app.middlewares.request_id_middleware.logger') as mock_logger: + await self.middleware.dispatch(mock_request, mock_call_next) + + # Check second log call contains correct status code + second_call_args = mock_logger.info.call_args_list[1] + assert second_call_args[1]["status_code"] == status_code + + @pytest.mark.asyncio + async def test_dispatch_uuid_consistency_throughout_request(self): + """Test that the same UUID is used throughout the request lifecycle""" + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + mock_request.method = "GET" + mock_request.url.path = "/test" + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 200 + mock_response.headers = {} + + mock_call_next = AsyncMock(return_value=mock_response) + + with patch('time.time', side_effect=[1000.0, 1000.1]): + with patch('app.middlewares.request_id_middleware.logger') as mock_logger: + await self.middleware.dispatch(mock_request, mock_call_next) + + # Get the UUID from request state + request_uuid = mock_request.state.x_exosphere_request_id + + # Get the UUID from response header + response_uuid = mock_response.headers["x-exosphere-request-id"] + + # Get UUIDs from both log calls + first_log_uuid = mock_logger.info.call_args_list[0][1]["x_exosphere_request_id"] + second_log_uuid = mock_logger.info.call_args_list[1][1]["x_exosphere_request_id"] + + # All should be the same + assert request_uuid == response_uuid == first_log_uuid == second_log_uuid + + @pytest.mark.asyncio + async def test_dispatch_handles_case_sensitive_header(self): + """Test that header matching is case-insensitive as per HTTP standards""" + # Setup with different case variations + header_variations = [ + "x-exosphere-request-id", + "X-Exosphere-Request-Id", + "X-EXOSPHERE-REQUEST-ID", + "x-Exosphere-Request-Id" + ] + + valid_uuid = str(uuid.uuid4()) + + for header_name in header_variations: + mock_request = MagicMock(spec=Request) + # Mock headers.get to be case-insensitive like real Starlette + def case_insensitive_get(key): + if key.lower() == "x-exosphere-request-id": + return valid_uuid + return None + + mock_request.headers.get = case_insensitive_get + mock_request.method = "GET" + mock_request.url.path = "/test" + mock_request.state = MagicMock() + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 200 + mock_response.headers = {} + + mock_call_next = AsyncMock(return_value=mock_response) + + with patch('time.time', side_effect=[1000.0, 1000.1]): + with patch('app.middlewares.request_id_middleware.logger'): + await self.middleware.dispatch(mock_request, mock_call_next) + + # Should use the provided UUID regardless of header case + assert mock_request.state.x_exosphere_request_id == valid_uuid + + @pytest.mark.asyncio + async def test_dispatch_exception_handling(self): + """Test middleware behavior when call_next raises an exception""" + mock_request = MagicMock(spec=Request) + mock_request.headers = {} + mock_request.method = "GET" + mock_request.url.path = "/test" + mock_request.state = MagicMock() + + # Mock call_next to raise an exception + mock_call_next = AsyncMock(side_effect=Exception("Test exception")) + + with patch('time.time', side_effect=[1000.0, 1000.1]): + with patch('app.middlewares.request_id_middleware.logger') as mock_logger: + with pytest.raises(Exception, match="Test exception"): + await self.middleware.dispatch(mock_request, mock_call_next) + + # Should still log the request received, but not the processed log + assert mock_logger.info.call_count == 1 + first_call_args = mock_logger.info.call_args_list[0] + assert first_call_args[0][0] == "request received" + + # Request state should still be set + assert hasattr(mock_request.state, 'x_exosphere_request_id') + uuid.UUID(mock_request.state.x_exosphere_request_id) # Should be valid UUID \ No newline at end of file diff --git a/state-manager/tests/unit/middlewares/test_unhandled_exceptions_middleware.py b/state-manager/tests/unit/middlewares/test_unhandled_exceptions_middleware.py new file mode 100644 index 00000000..ec0ed777 --- /dev/null +++ b/state-manager/tests/unit/middlewares/test_unhandled_exceptions_middleware.py @@ -0,0 +1,381 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from starlette.requests import Request +from starlette.responses import Response, JSONResponse + +from app.middlewares.unhandled_exceptions_middleware import UnhandledExceptionsMiddleware + + +class TestUnhandledExceptionsMiddleware: + """Test cases for UnhandledExceptionsMiddleware""" + + def setup_method(self): + """Set up test fixtures before each test""" + self.middleware = UnhandledExceptionsMiddleware(app=MagicMock()) + + @pytest.mark.asyncio + async def test_dispatch_success_no_exception(self): + """Test dispatch when no exception occurs""" + # Setup + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/test" + mock_request.method = "GET" + + mock_response = MagicMock(spec=Response) + mock_response.status_code = 200 + + mock_call_next = AsyncMock(return_value=mock_response) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + result = await self.middleware.dispatch(mock_request, mock_call_next) + + # Should return the original response + assert result == mock_response + + # Should not log any errors + mock_logger.error.assert_not_called() + + @pytest.mark.asyncio + async def test_dispatch_handles_generic_exception(self): + """Test dispatch handles generic exceptions""" + # Setup + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/error" + mock_request.method = "POST" + mock_request.state.x_exosphere_request_id = "test-request-id" + + test_exception = Exception("Generic test error") + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + with patch('traceback.format_exc', return_value="Mock traceback"): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + # Should return JSONResponse with 500 status + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + # Check response content + # Note: We can't easily test the actual JSON content without calling result.body, + # but we can verify it's a JSONResponse with the right status code + + # Should log the error + mock_logger.error.assert_called_once_with( + "unhandled global exception", + error="Generic test error", + traceback="Mock traceback", + path="/api/error", + method="POST", + x_exosphere_request_id="test-request-id" + ) + + @pytest.mark.asyncio + async def test_dispatch_handles_runtime_error(self): + """Test dispatch handles RuntimeError exceptions""" + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/runtime-error" + mock_request.method = "PUT" + mock_request.state.x_exosphere_request_id = "runtime-request-id" + + test_exception = RuntimeError("Runtime test error") + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + with patch('traceback.format_exc', return_value="Runtime traceback"): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + mock_logger.error.assert_called_once_with( + "unhandled global exception", + error="Runtime test error", + traceback="Runtime traceback", + path="/api/runtime-error", + method="PUT", + x_exosphere_request_id="runtime-request-id" + ) + + @pytest.mark.asyncio + async def test_dispatch_handles_value_error(self): + """Test dispatch handles ValueError exceptions""" + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/validation" + mock_request.method = "POST" + mock_request.state.x_exosphere_request_id = "validation-request-id" + + test_exception = ValueError("Invalid value provided") + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + with patch('traceback.format_exc', return_value="ValueError traceback"): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + mock_logger.error.assert_called_once_with( + "unhandled global exception", + error="Invalid value provided", + traceback="ValueError traceback", + path="/api/validation", + method="POST", + x_exosphere_request_id="validation-request-id" + ) + + @pytest.mark.asyncio + async def test_dispatch_handles_key_error(self): + """Test dispatch handles KeyError exceptions""" + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/data" + mock_request.method = "GET" + mock_request.state.x_exosphere_request_id = "key-error-request-id" + + test_exception = KeyError("missing_key") + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + with patch('traceback.format_exc', return_value="KeyError traceback"): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + mock_logger.error.assert_called_once_with( + "unhandled global exception", + error="'missing_key'", + traceback="KeyError traceback", + path="/api/data", + method="GET", + x_exosphere_request_id="key-error-request-id" + ) + + @pytest.mark.asyncio + async def test_dispatch_handles_attribute_error(self): + """Test dispatch handles AttributeError exceptions""" + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/object" + mock_request.method = "PATCH" + mock_request.state.x_exosphere_request_id = "attribute-request-id" + + test_exception = AttributeError("'NoneType' object has no attribute 'method'") + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + with patch('traceback.format_exc', return_value="AttributeError traceback"): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + mock_logger.error.assert_called_once_with( + "unhandled global exception", + error="'NoneType' object has no attribute 'method'", + traceback="AttributeError traceback", + path="/api/object", + method="PATCH", + x_exosphere_request_id="attribute-request-id" + ) + + @pytest.mark.asyncio + async def test_dispatch_without_request_id_logs_none(self): + """Test dispatch when request state has no x_exosphere_request_id""" + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/no-id" + mock_request.method = "DELETE" + # Mock state without x_exosphere_request_id attribute + mock_request.state = MagicMock() + del mock_request.state.x_exosphere_request_id # Simulate missing attribute + + test_exception = Exception("No request ID error") + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + with patch('traceback.format_exc', return_value="No ID traceback"): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + mock_logger.error.assert_called_once_with( + "unhandled global exception", + error="No request ID error", + traceback="No ID traceback", + path="/api/no-id", + method="DELETE", + x_exosphere_request_id=None + ) + + @pytest.mark.asyncio + async def test_dispatch_with_empty_request_id_logs_empty_string(self): + """Test dispatch when request has empty string request ID""" + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/empty-id" + mock_request.method = "OPTIONS" + mock_request.state.x_exosphere_request_id = "" + + test_exception = Exception("Empty ID error") + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + with patch('traceback.format_exc', return_value="Empty ID traceback"): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + mock_logger.error.assert_called_once_with( + "unhandled global exception", + error="Empty ID error", + traceback="Empty ID traceback", + path="/api/empty-id", + method="OPTIONS", + x_exosphere_request_id="" + ) + + @pytest.mark.asyncio + async def test_dispatch_logs_different_request_paths_and_methods(self): + """Test dispatch logs different paths and methods correctly during exceptions""" + test_cases = [ + ("GET", "/api/users", "Get users error"), + ("POST", "/api/users/create", "Create user error"), + ("PUT", "/api/users/123", "Update user error"), + ("DELETE", "/api/users/123", "Delete user error"), + ("PATCH", "/api/users/123/status", "Update status error"), + ("HEAD", "/health", "Health check error"), + ("OPTIONS", "/api/cors", "CORS preflight error"), + ] + + for method, path, error_message in test_cases: + mock_request = MagicMock(spec=Request) + mock_request.url.path = path + mock_request.method = method + mock_request.state.x_exosphere_request_id = f"test-id-{method.lower()}" + + test_exception = Exception(error_message) + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + with patch('traceback.format_exc', return_value=f"{method} traceback"): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + mock_logger.error.assert_called_once_with( + "unhandled global exception", + error=error_message, + traceback=f"{method} traceback", + path=path, + method=method, + x_exosphere_request_id=f"test-id-{method.lower()}" + ) + + @pytest.mark.asyncio + async def test_dispatch_response_content_structure(self): + """Test that the error response has the correct JSON structure""" + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/test" + mock_request.method = "GET" + mock_request.state.x_exosphere_request_id = "response-test-id" + + test_exception = Exception("Response structure test") + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger'): + with patch('traceback.format_exc'): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + # Verify it's a JSONResponse with correct structure + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + # The actual content validation would require calling result.body or similar, + # but we can verify the key properties of the JSONResponse + assert hasattr(result, 'status_code') + assert result.status_code == 500 + + @pytest.mark.asyncio + async def test_dispatch_uses_actual_traceback_format_exc(self): + """Test that dispatch uses actual traceback.format_exc() when not mocked""" + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/traceback-test" + mock_request.method = "POST" + mock_request.state.x_exosphere_request_id = "traceback-test-id" + + test_exception = ValueError("Traceback test error") + mock_call_next = AsyncMock(side_effect=test_exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + # Don't mock traceback.format_exc to test actual behavior + result = await self.middleware.dispatch(mock_request, mock_call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + # Verify the logger was called with actual traceback + mock_logger.error.assert_called_once() + call_args = mock_logger.error.call_args[1] + assert call_args["error"] == "Traceback test error" + assert "traceback" in call_args + # The actual traceback should contain information about the ValueError + assert "ValueError: Traceback test error" in call_args["traceback"] + + @pytest.mark.asyncio + async def test_dispatch_exception_in_exception_handling(self): + """Test middleware behavior when logging itself fails""" + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/logging-error" + mock_request.method = "GET" + mock_request.state.x_exosphere_request_id = "logging-error-id" + + test_exception = Exception("Original error") + mock_call_next = AsyncMock(side_effect=test_exception) + + # Mock logger.error to raise an exception + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + mock_logger.error.side_effect = Exception("Logging failed") + + # The middleware should still return a JSONResponse even if logging fails + # This tests the robustness of error handling + with pytest.raises(Exception, match="Logging failed"): + await self.middleware.dispatch(mock_request, mock_call_next) + + @pytest.mark.asyncio + async def test_dispatch_preserves_original_exception_type_in_logs(self): + """Test that different exception types are logged with their original string representation""" + exception_test_cases = [ + (ValueError("Invalid input"), "Invalid input"), + (KeyError("missing_key"), "'missing_key'"), + (AttributeError("'str' object has no attribute 'nonexistent'"), "'str' object has no attribute 'nonexistent'"), + (TypeError("unsupported operand type(s)"), "unsupported operand type(s)"), + (IndexError("list index out of range"), "list index out of range"), + (FileNotFoundError("No such file or directory"), "No such file or directory"), + (ConnectionError("Connection failed"), "Connection failed"), + (TimeoutError("Operation timed out"), "Operation timed out"), + ] + + for exception, expected_error_message in exception_test_cases: + mock_request = MagicMock(spec=Request) + mock_request.url.path = "/api/exception-types" + mock_request.method = "GET" + mock_request.state.x_exosphere_request_id = "exception-types-id" + + mock_call_next = AsyncMock(side_effect=exception) + + with patch('app.middlewares.unhandled_exceptions_middleware.logger') as mock_logger: + with patch('traceback.format_exc', return_value="Mock traceback"): + result = await self.middleware.dispatch(mock_request, mock_call_next) + + assert isinstance(result, JSONResponse) + assert result.status_code == 500 + + # Verify the specific error message is logged correctly + mock_logger.error.assert_called_once_with( + "unhandled global exception", + error=expected_error_message, + traceback="Mock traceback", + path="/api/exception-types", + method="GET", + x_exosphere_request_id="exception-types-id" + ) \ No newline at end of file diff --git a/state-manager/tests/unit/singletons/__init__.py b/state-manager/tests/unit/singletons/__init__.py new file mode 100644 index 00000000..318c5661 --- /dev/null +++ b/state-manager/tests/unit/singletons/__init__.py @@ -0,0 +1 @@ +# Unit tests for singletons package \ No newline at end of file diff --git a/state-manager/tests/unit/singletons/test_singleton_decorator.py b/state-manager/tests/unit/singletons/test_singleton_decorator.py new file mode 100644 index 00000000..b0239df4 --- /dev/null +++ b/state-manager/tests/unit/singletons/test_singleton_decorator.py @@ -0,0 +1,320 @@ +import pytest +from app.singletons.SingletonDecorator import singleton + + +class TestSingletonDecorator: + """Test cases for singleton decorator function""" + + def test_singleton_decorator_creates_single_instance(self): + """Test that singleton decorator ensures only one instance is created""" + + @singleton + class TestClass: + def __init__(self): + self.value = "test" + + # Create multiple instances + instance1 = TestClass() + instance2 = TestClass() + instance3 = TestClass() + + # All should be the same object + assert instance1 is instance2 + assert instance2 is instance3 + assert instance1 is instance3 + + def test_singleton_decorator_preserves_class_functionality(self): + """Test that singleton decorator preserves class methods and attributes""" + + @singleton + class TestClass: + def __init__(self, value): + self.value = value + self.counter = 0 + + def increment(self): + self.counter += 1 + return self.counter + + def get_value(self): + return self.value + + instance1 = TestClass("first") + instance2 = TestClass("second") # This should be ignored due to singleton + + # Should be the same instance + assert instance1 is instance2 + + # Should preserve the original initialization (first call) + assert instance1.get_value() == "first" + assert instance2.get_value() == "first" + + # Method calls should work and share state + assert instance1.increment() == 1 + assert instance2.increment() == 2 # Same counter, incremented + assert instance1.counter == 2 + + def test_singleton_decorator_with_no_args_constructor(self): + """Test singleton decorator with class that has no constructor arguments""" + + @singleton + class SimpleClass: + def __init__(self): + self.created = True + + instance1 = SimpleClass() + instance2 = SimpleClass() + + assert instance1 is instance2 + assert instance1.created is True + assert instance2.created is True + + def test_singleton_decorator_with_multiple_args(self): + """Test singleton decorator with class that accepts multiple arguments""" + + @singleton + class MultiArgClass: + def __init__(self, arg1, arg2, kwarg1=None, kwarg2="default"): + self.arg1 = arg1 + self.arg2 = arg2 + self.kwarg1 = kwarg1 + self.kwarg2 = kwarg2 + + # First instance with specific args + instance1 = MultiArgClass("first", "second", kwarg1="kw1", kwarg2="kw2") + + # Second instance with different args (should be ignored) + instance2 = MultiArgClass("different", "args", kwarg1="ignored", kwarg2="ignored") + + assert instance1 is instance2 + + # Should preserve first initialization + assert instance1.arg1 == "first" + assert instance1.arg2 == "second" + assert instance1.kwarg1 == "kw1" + assert instance1.kwarg2 == "kw2" + + # instance2 should have the same values + assert instance2.arg1 == "first" + assert instance2.arg2 == "second" + assert instance2.kwarg1 == "kw1" + assert instance2.kwarg2 == "kw2" + + def test_singleton_decorator_with_different_classes(self): + """Test that singleton decorator works independently for different classes""" + + @singleton + class ClassA: + def __init__(self): + self.type = "A" + + @singleton + class ClassB: + def __init__(self): + self.type = "B" + + # Each class should have its own singleton instance + a1 = ClassA() + a2 = ClassA() + b1 = ClassB() + b2 = ClassB() + + # Same class instances should be identical + assert a1 is a2 + assert b1 is b2 + + # Different class instances should be different + assert a1 is not b1 + assert a2 is not b2 + + # Each should preserve their own properties + assert a1.type == "A" + assert b1.type == "B" + + def test_singleton_decorator_preserves_class_name(self): + """Test that singleton decorator preserves original class name""" + + @singleton + class NamedClass: + pass + + instance = NamedClass() + + # The returned function should still reference the original class + # Note: The decorator returns a function, not a class, but the instance + # should still be of the original class type + assert instance.__class__.__name__ == "NamedClass" + + def test_singleton_decorator_thread_safety_simulation(self): + """Test singleton decorator behavior under simulated concurrent access""" + + call_count = 0 + + @singleton + class CountedClass: + def __init__(self): + nonlocal call_count + call_count += 1 + self.instance_id = call_count + + # Simulate multiple "concurrent" calls + instances = [] + for _ in range(10): + instances.append(CountedClass()) + + # All instances should be the same + first_instance = instances[0] + for instance in instances[1:]: + assert instance is first_instance + + # Constructor should only be called once + assert call_count == 1 + assert first_instance.instance_id == 1 + + def test_singleton_decorator_with_methods_and_properties(self): + """Test singleton decorator preserves methods and properties""" + + @singleton + class MethodClass: + def __init__(self): + self._internal_value = 42 + + @property + def value(self): + return self._internal_value + + @value.setter + def value(self, new_value): + self._internal_value = new_value + + def calculate(self, multiplier): + return self._internal_value * multiplier + + @staticmethod + def static_method(): + return "static" + + @classmethod + def class_method(cls): + return cls.__name__ + + instance1 = MethodClass() + instance2 = MethodClass() + + assert instance1 is instance2 + + # Test property access + assert instance1.value == 42 + assert instance2.value == 42 + + # Test property setting + instance1.value = 100 + assert instance2.value == 100 # Should be shared + + # Test method calls + assert instance1.calculate(2) == 200 + assert instance2.calculate(3) == 300 + + # Test static and class methods + assert instance1.static_method() == "static" + assert instance2.static_method() == "static" + assert instance1.class_method() == "MethodClass" + + def test_singleton_decorator_with_exception_in_constructor(self): + """Test singleton decorator behavior when constructor raises exception""" + + @singleton + class FailingClass: + def __init__(self, should_fail=True): + if should_fail: + raise ValueError("Constructor failed") + self.success = True + + # First call with failure + with pytest.raises(ValueError, match="Constructor failed"): + FailingClass() + + # Second call with failure (should try to create again since first failed) + with pytest.raises(ValueError, match="Constructor failed"): + FailingClass() + + # The singleton pattern should handle constructor failures gracefully + # After failure, the class should still not be in instances dict + + def test_singleton_decorator_instances_isolation(self): + """Test that singleton decorator maintains separate instances dict per decorated class""" + + # Test that different decorations maintain separate state + @singleton + class FirstSingleton: + def __init__(self): + self.name = "first" + + @singleton + class SecondSingleton: + def __init__(self): + self.name = "second" + + first = FirstSingleton() + second = SecondSingleton() + + # Should be different instances + assert first is not second + assert first.name != second.name + + # Multiple calls should return same instance for each class + first_again = FirstSingleton() + second_again = SecondSingleton() + + assert first is first_again + assert second is second_again + + def test_singleton_decorator_callable_return_value(self): + """Test that singleton decorator returns a callable""" + + @singleton + class TestClass: + pass + + # The decorator should return a callable (the get_instance function) + assert callable(TestClass) + + # Calling it should return an instance + instance = TestClass() + assert instance is not None + assert hasattr(instance, '__class__') + + def test_singleton_decorator_with_complex_initialization(self): + """Test singleton decorator with complex initialization logic""" + + initialization_count = 0 + + @singleton + class ComplexClass: + def __init__(self, config=None): + nonlocal initialization_count + initialization_count += 1 + + self.config = config or {} + self.initialized_at = initialization_count + self.cache = {} + + # Simulate complex initialization + self._setup_internal_state() + + def _setup_internal_state(self): + self.internal_state = "configured" + + # First initialization + instance1 = ComplexClass({"setting": "value1"}) + assert initialization_count == 1 + assert instance1.config == {"setting": "value1"} + assert instance1.initialized_at == 1 + assert instance1.internal_state == "configured" + + # Second call (should return same instance, ignore new config) + instance2 = ComplexClass({"setting": "value2"}) + assert initialization_count == 1 # No additional initialization + assert instance1 is instance2 + assert instance2.config == {"setting": "value1"} # Original config preserved + assert instance2.initialized_at == 1 \ No newline at end of file diff --git a/state-manager/tests/unit/tasks/test_create_next_states.py b/state-manager/tests/unit/tasks/test_create_next_states.py new file mode 100644 index 00000000..f7f28855 --- /dev/null +++ b/state-manager/tests/unit/tasks/test_create_next_states.py @@ -0,0 +1,162 @@ +import pytest +from unittest.mock import AsyncMock, patch + +from pydantic import BaseModel + +from app.tasks import create_next_states as cns +from app.models.node_template_model import NodeTemplate, Unites +from app.models.state_status_enum import StateStatusEnum + + +# --------------------------------------------------------------------------- +# Helper fixtures & stubs +# --------------------------------------------------------------------------- + +class DummyState: + """Very small stand-in for the real *State* ODM model. + + Only the minimal surface required by the functions under test is + implemented (``status``, ``outputs`` and an async ``save`` method). + """ + + def __init__(self, sid, outputs=None): + self.id = sid + self.status = None + self.outputs = outputs or {} + self.error = None + # ``save`` must be awaitable because the real method is awaited. + self.save = AsyncMock() + + +class DummyQuery: + """Mimics the chain returned by ``State.find()`` inside the helpers.""" + + def __init__(self, count_value: int = 0): + self._count_value = count_value + self.set = AsyncMock() + + async def count(self): + return self._count_value + + +# --------------------------------------------------------------------------- +# Tests for *get_dependents* +# --------------------------------------------------------------------------- + +def test_get_dependents_success(): + src = "Hello ${{parent1.outputs.field1}} world ${{current.outputs.answer}}!" + result = cns.get_dependents(src) + + # Head extraction + assert result.head == "Hello " + + # Two placeholders discovered in order + assert list(result.dependents.keys()) == [0, 1] + + d0 = result.dependents[0] + assert (d0.identifier, d0.field, d0.tail) == ("parent1", "field1", " world ") + + d1 = result.dependents[1] + assert (d1.identifier, d1.field, d1.tail) == ("current", "answer", "!") + + +def test_get_dependents_invalid_format(): + # Missing the mandatory ``.outputs.`` part should error out. + with pytest.raises(ValueError): + cns.get_dependents("Broken ${{parent.outputs_missing}} snippet") + + +# --------------------------------------------------------------------------- +# Tests for *validate_dependencies* +# --------------------------------------------------------------------------- + +class _InputModel(BaseModel): + greeting: str + + +@pytest.fixture +def parent_state(): + return DummyState("parent-sid", outputs={"msg": "hi"}) + + +def _make_node_template(dep_string: str) -> NodeTemplate: + return NodeTemplate( + node_name="next_node", + namespace="ns", + identifier="next_id", + inputs={"greeting": dep_string}, + next_nodes=[], + unites=None, + ) + + +def test_validate_dependencies_success(parent_state): + node_tpl = _make_node_template("${{parent.outputs.msg}}") + # Should not raise. + cns.validate_dependencies(node_tpl, _InputModel, "current", {"parent": parent_state}) + + +def test_validate_dependencies_missing_parent(parent_state): + node_tpl = _make_node_template("${{missing_parent.outputs.msg}}") + with pytest.raises(KeyError): + cns.validate_dependencies(node_tpl, _InputModel, "current", {"parent": parent_state}) + + +# --------------------------------------------------------------------------- +# Tests for *check_unites_satisfied* +# --------------------------------------------------------------------------- + +async def _run_check_unites(count_value): + unit = Unites(identifier="parent") + node_tpl = NodeTemplate( + node_name="node", + namespace="ns", + identifier="id", + inputs={}, + next_nodes=[], + unites=[unit], + ) + + # Patch *State.find()* to deliver the dummy query with desired count. + with patch.object(cns, "State") as mock_state: + mock_state.find.return_value = DummyQuery(count_value) + result = await cns.check_unites_satisfied( + "ns", "graph", node_tpl, {"parent": "parent-sid"} # type: ignore + ) + return result + + +@pytest.mark.asyncio +async def test_check_unites_satisfied_true(): + assert await _run_check_unites(0) is True + + +@pytest.mark.asyncio +async def test_check_unites_satisfied_false(): + assert await _run_check_unites(1) is False + + +# --------------------------------------------------------------------------- +# Tests for *mark_success_states* +# --------------------------------------------------------------------------- + +@pytest.mark.asyncio +async def test_mark_success_states_updates_status(): + state_ids = ["sid-1", "sid-2"] + created = {} + + async def _get(sid): + created[sid] = DummyState(sid) + return created[sid] + + with patch.object(cns, "State") as mock_state: + # Provide *get* and *find* replacements. + mock_state.get = AsyncMock(side_effect=_get) + mock_state.find.return_value = DummyQuery() + + # Execute. + await cns.mark_success_states(state_ids) # type: ignore + + for st in created.values(): + assert st.status == StateStatusEnum.SUCCESS + st.save.assert_awaited() \ No newline at end of file diff --git a/state-manager/tests/unit/test_create_next_state.py b/state-manager/tests/unit/test_create_next_state.py deleted file mode 100644 index 65bc17b1..00000000 --- a/state-manager/tests/unit/test_create_next_state.py +++ /dev/null @@ -1,132 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from beanie import PydanticObjectId - -from app.tasks.create_next_state import create_next_state -from app.models.db.state import State -from app.models.db.graph_template_model import GraphTemplate -from app.models.db.registered_node import RegisteredNode -from app.models.graph_template_validation_status import GraphTemplateValidationStatus -from app.models.state_status_enum import StateStatusEnum - - -class TestCreateNextState: - """Test cases for create_next_state function""" - - @pytest.fixture - def mock_state(self): - """Create a mock state object""" - state = MagicMock(spec=State) - state.id = PydanticObjectId() - state.identifier = "test_node" - state.namespace_name = "test_namespace" - state.graph_name = "test_graph" - state.run_id = "test_run_id" - state.status = StateStatusEnum.EXECUTED - state.inputs = {"input1": "value1"} - state.outputs = {"output1": "result1"} - state.error = None - state.parents = {"parent_node": PydanticObjectId()} - state.save = AsyncMock() - return state - - @pytest.fixture - def mock_graph_template(self): - """Create a mock graph template""" - template = MagicMock(spec=GraphTemplate) - template.validation_status = GraphTemplateValidationStatus.VALID - template.get_node_by_identifier = MagicMock() - return template - - @pytest.fixture - def mock_registered_node(self): - """Create a mock registered node""" - node = MagicMock(spec=RegisteredNode) - node.inputs_schema = { - "type": "object", - "properties": { - "field1": {"type": "string"}, - "field2": {"type": "string"} - } - } - return node - - @patch('app.tasks.create_next_state.GraphTemplate') - async def test_create_next_state_none_id(self, mock_graph_template_class): - """Test create_next_state with state having None id""" - # Arrange - state_with_none_id = MagicMock() - state_with_none_id.id = None - - # Act & Assert - with pytest.raises(ValueError, match="State is not valid"): - await create_next_state(state_with_none_id) - - @patch('app.tasks.create_next_state.GraphTemplate') - @patch('app.tasks.create_next_state.asyncio.sleep') - async def test_create_next_state_wait_for_validation( - self, - mock_sleep, - mock_graph_template_class, - mock_state, - mock_graph_template - ): - """Test waiting for graph template to become valid""" - # Arrange - # First call returns invalid template, second call returns valid - invalid_template = MagicMock() - invalid_template.validation_status = GraphTemplateValidationStatus.INVALID - - mock_graph_template_class.find_one = AsyncMock(side_effect=[invalid_template, mock_graph_template]) - - # Mock node template with no next nodes - node_template = MagicMock() - node_template.next_nodes = None - mock_graph_template.get_node_by_identifier.return_value = node_template - - # Act - await create_next_state(mock_state) - - # Assert - assert mock_graph_template_class.find_one.call_count == 2 - mock_sleep.assert_called_once_with(1) - assert mock_state.status == StateStatusEnum.SUCCESS - - @patch('app.tasks.create_next_state.GraphTemplate') - async def test_create_next_state_no_next_nodes( - self, - mock_graph_template_class, - mock_state, - mock_graph_template - ): - """Test when there are no next nodes""" - # Arrange - mock_graph_template_class.find_one = AsyncMock(return_value=mock_graph_template) - - node_template = MagicMock() - node_template.next_nodes = None - mock_graph_template.get_node_by_identifier.return_value = node_template - - # Act - await create_next_state(mock_state) - - # Assert - assert mock_state.status == StateStatusEnum.SUCCESS - - @patch('app.tasks.create_next_state.GraphTemplate') - async def test_create_next_state_general_exception( - self, - mock_graph_template_class, - mock_state - ): - """Test general exception handling""" - # Arrange - mock_graph_template_class.find_one = AsyncMock(side_effect=Exception("General error")) - - # Act - await create_next_state(mock_state) - - # Assert - assert mock_state.status == StateStatusEnum.ERRORED - assert mock_state.error == "General error" - mock_state.save.assert_called_once() \ No newline at end of file diff --git a/state-manager/tests/unit/test_main.py b/state-manager/tests/unit/test_main.py new file mode 100644 index 00000000..2ae8ba48 --- /dev/null +++ b/state-manager/tests/unit/test_main.py @@ -0,0 +1,276 @@ +import os +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from app import main as app_main + + +class TestMainApp: + """Test cases for main FastAPI application setup""" + + def test_app_initialization(self): + """Test that FastAPI app is initialized correctly""" + app = app_main.app + + assert isinstance(app, FastAPI) + assert app.title == "Exosphere State Manager" + assert app.description == "Exosphere State Manager" + assert app.version == "0.1.0" + + # Check contact info + assert app.contact is not None + assert app.contact["name"] == "Nivedit Jain (Founder exosphere.host)" + assert app.contact["email"] == "nivedit@exosphere.host" + + # Check license info + assert app.license_info is not None + assert app.license_info["name"] == "Elastic License 2.0 (ELv2)" + assert "github.com/exospherehost/exosphere-api-server/blob/main/LICENSE" in app.license_info["url"] + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'test-secret'}) + @patch('app.main.init_beanie') + @patch('app.main.AsyncMongoClient') + @patch('app.main.LogsManager') + def test_health_endpoint(self, mock_logs_manager, mock_mongo_client, mock_init_beanie): + """Test the health endpoint""" + # Setup mocks to avoid database connections + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + mock_client = MagicMock() + mock_mongo_client.return_value = mock_client + mock_init_beanie.return_value = AsyncMock() + + with TestClient(app_main.app) as client: + response = client.get("/health") + + assert response.status_code == 200 + assert response.json() == {"message": "OK"} + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'test-secret'}) + @patch('app.main.init_beanie') + @patch('app.main.AsyncMongoClient') + @patch('app.main.LogsManager') + def test_health_endpoint_content_type(self, mock_logs_manager, mock_mongo_client, mock_init_beanie): + """Test the health endpoint returns JSON""" + # Setup mocks to avoid database connections + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + mock_client = MagicMock() + mock_mongo_client.return_value = mock_client + mock_init_beanie.return_value = AsyncMock() + + with TestClient(app_main.app) as client: + response = client.get("/health") + + assert response.headers["content-type"] == "application/json" + + @patch('app.main.LogsManager') + def test_middlewares_added_to_app(self, mock_logs_manager): + """Test that middlewares are added to the application""" + # Since middlewares are added during app creation, we need to check + # if they're present in the middleware stack + app = app_main.app + + # FastAPI stores middleware in app.user_middleware + middleware_classes = [middleware.cls for middleware in app.user_middleware] + + # Import the middleware classes for comparison + from app.middlewares.request_id_middleware import RequestIdMiddleware + from app.middlewares.unhandled_exceptions_middleware import UnhandledExceptionsMiddleware + + assert RequestIdMiddleware in middleware_classes + assert UnhandledExceptionsMiddleware in middleware_classes + + def test_middleware_order(self): + """Test that middlewares are added in correct order""" + app = app_main.app + + # FastAPI stores middleware in reverse order (last added is first executed) + middleware_classes = [middleware.cls for middleware in app.user_middleware] + + from app.middlewares.request_id_middleware import RequestIdMiddleware + from app.middlewares.unhandled_exceptions_middleware import UnhandledExceptionsMiddleware + + # UnhandledExceptionsMiddleware should be added last (executed first) + # RequestIdMiddleware should be added first (executed after UnhandledExceptionsMiddleware) + request_id_index = middleware_classes.index(RequestIdMiddleware) # type: ignore + unhandled_exceptions_index = middleware_classes.index(UnhandledExceptionsMiddleware) # type: ignore + + # Since middleware is stored in reverse order, UnhandledExceptions should have lower index + assert unhandled_exceptions_index < request_id_index + + def test_router_included(self): + """Test that the main router is included in the app""" + app = app_main.app + + # Check that routes from the router are present + # The exact routes depend on what's in routes.py, but we can check if routes exist + assert len(app.routes) > 1 # Should have at least health + routes from router + + +class TestLifespan: + """Test cases for lifespan context manager""" + + @patch.dict(os.environ, { + 'MONGO_URI': 'mongodb://test:27017', + 'MONGO_DATABASE_NAME': 'test_db', + 'STATE_MANAGER_SECRET': 'test_secret' + }) + @patch('app.main.init_beanie') + @patch('app.main.AsyncMongoClient') + @patch('app.main.LogsManager') + async def test_lifespan_startup_success(self, mock_logs_manager, mock_mongo_client, mock_init_beanie): + """Test successful lifespan startup""" + # Setup mocks + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_client = MagicMock() + mock_mongo_client.return_value = mock_client + mock_db = MagicMock() + mock_client.__getitem__.return_value = mock_db + + mock_init_beanie.return_value = AsyncMock() + + # Create a mock FastAPI app for the lifespan + mock_app = MagicMock() + + # Test the lifespan context manager + async with app_main.lifespan(mock_app): + # During startup, these should be called + mock_logs_manager.assert_called() + mock_logger.info.assert_any_call("server starting") + mock_mongo_client.assert_called_with('mongodb://test:27017') + mock_client.__getitem__.assert_called_with('test_db') + mock_init_beanie.assert_called() + mock_logger.info.assert_any_call("beanie dbs initialized") + mock_logger.info.assert_any_call("secret initialized") + + # After context manager exits (shutdown) + mock_logger.info.assert_any_call("server shutting down") + + @patch.dict(os.environ, { + 'MONGO_URI': 'mongodb://test:27017', + 'MONGO_DATABASE_NAME': 'test_db', + 'STATE_MANAGER_SECRET': '' # Empty secret + }) + @patch('app.main.init_beanie') + @patch('app.main.AsyncMongoClient') + @patch('app.main.LogsManager') + async def test_lifespan_empty_secret_raises_error(self, mock_logs_manager, mock_mongo_client, mock_init_beanie): + """Test that empty STATE_MANAGER_SECRET raises ValueError""" + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_client = MagicMock() + mock_mongo_client.return_value = mock_client + mock_db = MagicMock() + mock_client.__getitem__.return_value = mock_db + + mock_init_beanie.return_value = AsyncMock() + + mock_app = MagicMock() + + with pytest.raises(ValueError, match="STATE_MANAGER_SECRET is not set"): + async with app_main.lifespan(mock_app): + pass + + @patch.dict(os.environ, { + 'MONGO_URI': 'mongodb://test:27017', + 'MONGO_DATABASE_NAME': 'test_db', + 'STATE_MANAGER_SECRET': 'test_secret' + }) + @patch('app.main.init_beanie') + @patch('app.main.AsyncMongoClient') + @patch('app.main.LogsManager') + async def test_lifespan_init_beanie_with_correct_models(self, mock_logs_manager, mock_mongo_client, mock_init_beanie): + """Test that init_beanie is called with correct document models""" + mock_logger = MagicMock() + mock_logs_manager.return_value.get_logger.return_value = mock_logger + + mock_client = MagicMock() + mock_mongo_client.return_value = mock_client + mock_db = MagicMock() + mock_client.__getitem__.return_value = mock_db + + mock_init_beanie.return_value = AsyncMock() + + mock_app = MagicMock() + + async with app_main.lifespan(mock_app): + pass + + # Check that init_beanie was called with the database and correct models + mock_init_beanie.assert_called_once() + call_args = mock_init_beanie.call_args + + # First argument should be the database + assert call_args[0][0] == mock_db + + # Second argument should be document_models with the expected models + document_models = call_args[1]['document_models'] + + # Import the expected models + from app.models.db.state import State + from app.models.db.namespace import Namespace + from app.models.db.graph_template_model import GraphTemplate + from app.models.db.registered_node import RegisteredNode + + expected_models = [State, Namespace, GraphTemplate, RegisteredNode] + assert document_models == expected_models + + +class TestEnvironmentIntegration: + """Test cases for environment variable integration""" + + def test_load_dotenv_called(self): + """Test that load_dotenv is called during module import""" + # This test ensures that .env files are loaded + # We can't easily test this without reimporting the module, + # but we can verify the import doesn't fail + assert hasattr(app_main, 'load_dotenv') + + @patch.dict(os.environ, { + 'MONGO_URI': 'mongodb://custom:27017', + 'MONGO_DATABASE_NAME': 'custom_db', + 'STATE_MANAGER_SECRET': 'custom_secret' + }) + def test_environment_variables_usage(self): + """Test that environment variables are properly accessed""" + # Test that the module can access environment variables + assert os.getenv("MONGO_URI") == 'mongodb://custom:27017' + assert os.getenv("MONGO_DATABASE_NAME") == 'custom_db' + assert os.getenv("STATE_MANAGER_SECRET") == 'custom_secret' + + +class TestAppConfiguration: + """Test cases for application configuration""" + + def test_app_has_lifespan(self): + """Test that app is configured with lifespan""" + app = app_main.app + assert app.router.lifespan_context is not None + + def test_app_routes_configuration(self): + """Test that app routes are properly configured""" + app = app_main.app + + # Should have at least the health route + health_route_found = False + for route in app.routes: + if hasattr(route, 'path') and route.path == '/health': # type: ignore + health_route_found = True + break + + assert health_route_found, "Health route not found in app routes" + + def test_app_has_router_included(self): + """Test that main router is included""" + app = app_main.app + + # The app should have routes beyond just the health endpoint + # This indicates that the main router has been included + route_count = len([route for route in app.routes if hasattr(route, 'path')]) + assert route_count > 1, "Main router appears not to be included" \ No newline at end of file diff --git a/state-manager/tests/unit/utils/__init__.py b/state-manager/tests/unit/utils/__init__.py new file mode 100644 index 00000000..4f1a40a2 --- /dev/null +++ b/state-manager/tests/unit/utils/__init__.py @@ -0,0 +1 @@ +# Unit tests for utils package \ No newline at end of file diff --git a/state-manager/tests/unit/utils/test_check_secret.py b/state-manager/tests/unit/utils/test_check_secret.py new file mode 100644 index 00000000..cc946887 --- /dev/null +++ b/state-manager/tests/unit/utils/test_check_secret.py @@ -0,0 +1,216 @@ +import os +import pytest +from unittest.mock import patch +from fastapi import HTTPException +from fastapi.security.api_key import APIKeyHeader +from starlette.status import HTTP_401_UNAUTHORIZED + +from app.utils.check_secret import api_key_header, API_KEY_NAME + + +class TestCheckApiKey: + """Test cases for check_api_key function""" + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'test-secret-key'}) + @pytest.mark.asyncio + async def test_check_api_key_success_with_valid_key(self): + """Test check_api_key succeeds with valid API key""" + # Import here to get the updated environment variable + from app.utils.check_secret import check_api_key + + # Reload the module to pick up the new environment variable + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + + result = await check_api_key('test-secret-key') + assert result == 'test-secret-key' + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'test-secret-key'}) + @pytest.mark.asyncio + async def test_check_api_key_fails_with_invalid_key(self): + """Test check_api_key fails with invalid API key""" + # Import here to get the updated environment variable + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + with pytest.raises(HTTPException) as exc_info: + await check_api_key('wrong-key') + + assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'test-secret-key'}) + @pytest.mark.asyncio + async def test_check_api_key_fails_with_none_key(self): + """Test check_api_key fails with None API key""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + with pytest.raises(HTTPException) as exc_info: + await check_api_key(None) # type: ignore + + assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'test-secret-key'}) + @pytest.mark.asyncio + async def test_check_api_key_fails_with_empty_string_key(self): + """Test check_api_key fails with empty string API key""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + with pytest.raises(HTTPException) as exc_info: + await check_api_key('') + + assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'case-sensitive-key'}) + @pytest.mark.asyncio + async def test_check_api_key_is_case_sensitive(self): + """Test check_api_key is case sensitive""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + with pytest.raises(HTTPException) as exc_info: + await check_api_key('CASE-SENSITIVE-KEY') + + assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'whitespace-key'}) + @pytest.mark.asyncio + async def test_check_api_key_whitespace_sensitive(self): + """Test check_api_key is sensitive to whitespace""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + with pytest.raises(HTTPException) as exc_info: + await check_api_key(' whitespace-key ') + + assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED + assert exc_info.value.detail == "Invalid API key" + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'special-chars-!@#$%^&*()'}) + @pytest.mark.asyncio + async def test_check_api_key_with_special_characters(self): + """Test check_api_key works with special characters""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + result = await check_api_key('special-chars-!@#$%^&*()') + assert result == 'special-chars-!@#$%^&*()' + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'unicode-key-你好'}) + @pytest.mark.asyncio + async def test_check_api_key_with_unicode_characters(self): + """Test check_api_key works with unicode characters""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + result = await check_api_key('unicode-key-你好') + assert result == 'unicode-key-你好' + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': ''}) + @pytest.mark.asyncio + async def test_check_api_key_with_empty_env_variable(self): + """Test check_api_key when STATE_MANAGER_SECRET is empty string""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + # Empty string should match empty string + result = await check_api_key('') + assert result == '' + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'very-long-key-with-many-characters-1234567890-abcdefghijklmnopqrstuvwxyz-ABCDEFGHIJKLMNOPQRSTUVWXYZ'}) + @pytest.mark.asyncio + async def test_check_api_key_with_very_long_key(self): + """Test check_api_key works with very long keys""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + long_key = 'very-long-key-with-many-characters-1234567890-abcdefghijklmnopqrstuvwxyz-ABCDEFGHIJKLMNOPQRSTUVWXYZ' + result = await check_api_key(long_key) + assert result == long_key + + +class TestModuleConstants: + """Test cases for module constants and configuration""" + + def test_api_key_name_constant(self): + """Test API_KEY_NAME constant is correct""" + assert API_KEY_NAME == "x-api-key" + + def test_api_key_header_configuration(self): + """Test api_key_header is configured correctly""" + assert isinstance(api_key_header, APIKeyHeader) + assert api_key_header.model.name == "x-api-key" + assert api_key_header.auto_error is False + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'test-constant-key'}) + def test_api_key_loads_from_environment(self): + """Test API_KEY loads from environment variable""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + + # Access the reloaded module's API_KEY + assert app.utils.check_secret.API_KEY == 'test-constant-key' + +class TestIntegrationWithFastAPI: + """Integration tests with FastAPI dependency system""" + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'integration-test-key'}) + @pytest.mark.asyncio + async def test_dependency_integration_success(self): + """Test successful integration as FastAPI dependency""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + # Simulate FastAPI calling the dependency with the correct header value + result = await check_api_key('integration-test-key') + assert result == 'integration-test-key' + + @patch.dict(os.environ, {'STATE_MANAGER_SECRET': 'integration-test-key'}) + @pytest.mark.asyncio + async def test_dependency_integration_failure(self): + """Test failed integration as FastAPI dependency""" + import importlib + import app.utils.check_secret + importlib.reload(app.utils.check_secret) + from app.utils.check_secret import check_api_key + + # Simulate FastAPI calling the dependency with wrong header value + with pytest.raises(HTTPException) as exc_info: + await check_api_key('wrong-integration-key') + + assert exc_info.value.status_code == HTTP_401_UNAUTHORIZED + assert "Invalid API key" in exc_info.value.detail + + def test_api_key_header_accepts_none_when_auto_error_false(self): + """Test api_key_header configuration allows None when auto_error is False""" + # This tests the configuration, not the actual FastAPI behavior + # but ensures our APIKeyHeader is set up to not auto-error + assert api_key_header.auto_error is False + # This means FastAPI won't automatically raise 403 when header is missing \ No newline at end of file diff --git a/state-manager/tests/unit/utils/test_encrypter.py b/state-manager/tests/unit/utils/test_encrypter.py new file mode 100644 index 00000000..4d62ce1d --- /dev/null +++ b/state-manager/tests/unit/utils/test_encrypter.py @@ -0,0 +1,244 @@ +import os +import base64 +import pytest +from unittest.mock import patch, MagicMock +from cryptography.hazmat.primitives.ciphers.aead import AESGCM + +from app.utils.encrypter import Encrypter, get_encrypter + + +class TestEncrypter: + """Test cases for Encrypter class""" + + def setup_method(self): + """Reset the global encrypter instance before each test""" + import app.utils.encrypter + app.utils.encrypter._encrypter_instance = None + + def teardown_method(self): + """Clean up after each test""" + import app.utils.encrypter + app.utils.encrypter._encrypter_instance = None + + def test_generate_key_returns_valid_base64_key(self): + """Test that generate_key returns a valid base64 encoded key""" + key = Encrypter.generate_key() + + # Should be base64 encoded string + assert isinstance(key, str) + # Should be able to decode without exception + decoded_key = base64.urlsafe_b64decode(key) + # Should be 32 bytes (256 bits) + assert len(decoded_key) == 32 + + def test_generate_key_creates_different_keys(self): + """Test that generate_key creates different keys each time""" + key1 = Encrypter.generate_key() + key2 = Encrypter.generate_key() + + assert key1 != key2 + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'a' * 32).decode()}) + def test_encrypter_init_with_valid_key(self): + """Test Encrypter initialization with valid key""" + encrypter = Encrypter() + + assert encrypter._key == b'a' * 32 + assert isinstance(encrypter._aesgcm, AESGCM) + + @patch.dict(os.environ, {}, clear=True) + def test_encrypter_init_without_key_raises_error(self): + """Test Encrypter initialization without SECRETS_ENCRYPTION_KEY raises ValueError""" + with pytest.raises(ValueError, match="SECRETS_ENCRYPTION_KEY is not set"): + Encrypter() + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': 'invalid-base64!@#'}) + def test_encrypter_init_with_invalid_base64_raises_error(self): + """Test Encrypter initialization with invalid base64 key""" + with pytest.raises(ValueError, match="Key must be URL-safe base64"): + Encrypter() + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'too_short').decode()}) + def test_encrypter_init_with_wrong_key_length_raises_error(self): + """Test Encrypter initialization with wrong key length""" + with pytest.raises(ValueError, match="Key must be 32 raw bytes"): + Encrypter() + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_encrypt_returns_base64_string(self): + """Test that encrypt returns a base64 encoded string""" + encrypter = Encrypter() + secret = "my secret message" + + encrypted = encrypter.encrypt(secret) + + assert isinstance(encrypted, str) + # Should be able to decode without exception + base64.urlsafe_b64decode(encrypted) + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_encrypt_different_secrets_produce_different_results(self): + """Test that different secrets produce different encrypted results""" + encrypter = Encrypter() + + encrypted1 = encrypter.encrypt("secret1") + encrypted2 = encrypter.encrypt("secret2") + + assert encrypted1 != encrypted2 + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_encrypt_same_secret_produces_different_results(self): + """Test that same secret produces different encrypted results due to nonce""" + encrypter = Encrypter() + secret = "same secret" + + encrypted1 = encrypter.encrypt(secret) + encrypted2 = encrypter.encrypt(secret) + + # Should be different due to different nonces + assert encrypted1 != encrypted2 + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_decrypt_returns_original_secret(self): + """Test that decrypt returns the original secret""" + encrypter = Encrypter() + original_secret = "my secret message" + + encrypted = encrypter.encrypt(original_secret) + decrypted = encrypter.decrypt(encrypted) + + assert decrypted == original_secret + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_encrypt_decrypt_roundtrip_with_special_characters(self): + """Test encrypt/decrypt with special characters""" + encrypter = Encrypter() + original_secret = "Special chars: !@#$%^&*()_+-={}[]|\\:;\"'<>?,./" + + encrypted = encrypter.encrypt(original_secret) + decrypted = encrypter.decrypt(encrypted) + + assert decrypted == original_secret + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_encrypt_decrypt_roundtrip_with_unicode(self): + """Test encrypt/decrypt with unicode characters""" + encrypter = Encrypter() + original_secret = "Unicode: 你好世界 🌍 ñáéíóú" + + encrypted = encrypter.encrypt(original_secret) + decrypted = encrypter.decrypt(encrypted) + + assert decrypted == original_secret + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_encrypt_decrypt_empty_string(self): + """Test encrypt/decrypt with empty string""" + encrypter = Encrypter() + original_secret = "" + + encrypted = encrypter.encrypt(original_secret) + decrypted = encrypter.decrypt(encrypted) + + assert decrypted == original_secret + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_decrypt_with_invalid_base64_raises_error(self): + """Test decrypt with invalid base64 data raises exception""" + encrypter = Encrypter() + + with pytest.raises(Exception): # base64 decode error + encrypter.decrypt("invalid-base64!@#") + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_decrypt_with_wrong_key_raises_error(self): + """Test decrypt with data encrypted with different key raises exception""" + # Encrypt with one key + encrypter1 = Encrypter() + encrypted = encrypter1.encrypt("secret") + + # Try to decrypt with different key + with patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'y' * 32).decode()}): + encrypter2 = Encrypter() + with pytest.raises(Exception): # AESGCM decrypt error + encrypter2.decrypt(encrypted) + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_decrypt_with_corrupted_data_raises_error(self): + """Test decrypt with corrupted encrypted data raises exception""" + encrypter = Encrypter() + + # Create invalid encrypted data (too short) + invalid_encrypted = base64.urlsafe_b64encode(b'too_short').decode() + + with pytest.raises(Exception): # AESGCM decrypt error + encrypter.decrypt(invalid_encrypted) + + +class TestGetEncrypter: + """Test cases for get_encrypter function""" + + def setup_method(self): + """Reset the global encrypter instance before each test""" + import app.utils.encrypter + app.utils.encrypter._encrypter_instance = None + + def teardown_method(self): + """Clean up after each test""" + import app.utils.encrypter + app.utils.encrypter._encrypter_instance = None + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_get_encrypter_returns_encrypter_instance(self): + """Test get_encrypter returns an Encrypter instance""" + encrypter = get_encrypter() + + assert isinstance(encrypter, Encrypter) + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_get_encrypter_returns_same_instance_singleton(self): + """Test get_encrypter returns the same instance (singleton pattern)""" + encrypter1 = get_encrypter() + encrypter2 = get_encrypter() + + assert encrypter1 is encrypter2 + + @patch.dict(os.environ, {}, clear=True) + def test_get_encrypter_without_key_raises_error(self): + """Test get_encrypter without SECRETS_ENCRYPTION_KEY raises ValueError""" + with pytest.raises(ValueError, match="SECRETS_ENCRYPTION_KEY is not set"): + get_encrypter() + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': 'invalid-key'}) + def test_get_encrypter_with_invalid_key_raises_error(self): + """Test get_encrypter with invalid key raises ValueError""" + with pytest.raises(ValueError, match="Key must be URL-safe base64"): + get_encrypter() + + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_get_encrypter_functional_test(self): + """Test that get_encrypter returns a functional encrypter""" + encrypter = get_encrypter() + original_secret = "functional test secret" + + encrypted = encrypter.encrypt(original_secret) + decrypted = encrypter.decrypt(encrypted) + + assert decrypted == original_secret + + @patch('app.utils.encrypter.Encrypter') + @patch.dict(os.environ, {'SECRETS_ENCRYPTION_KEY': base64.urlsafe_b64encode(b'x' * 32).decode()}) + def test_get_encrypter_creates_instance_only_once(self, mock_encrypter_class): + """Test that get_encrypter creates Encrypter instance only once""" + mock_instance = MagicMock() + mock_encrypter_class.return_value = mock_instance + + # Call get_encrypter multiple times + result1 = get_encrypter() + result2 = get_encrypter() + result3 = get_encrypter() + + # Encrypter constructor should be called only once + assert mock_encrypter_class.call_count == 1 + # All calls should return the same instance + assert result1 is result2 is result3 is mock_instance \ No newline at end of file