diff --git a/state-manager/.coverage b/state-manager/.coverage deleted file mode 100644 index e0975eb1..00000000 Binary files a/state-manager/.coverage and /dev/null differ diff --git a/state-manager/app/controller/manual_retry_state.py b/state-manager/app/controller/manual_retry_state.py new file mode 100644 index 00000000..17c926e5 --- /dev/null +++ b/state-manager/app/controller/manual_retry_state.py @@ -0,0 +1,49 @@ +from pymongo.errors import DuplicateKeyError +from app.models.manual_retry import ManualRetryRequestModel, ManualRetryResponseModel +from beanie import PydanticObjectId +from app.singletons.logs_manager import LogsManager +from app.models.state_status_enum import StateStatusEnum +from fastapi import HTTPException, status +from app.models.db.state import State + + +logger = LogsManager().get_logger() + +async def manual_retry_state(namespace_name: str, state_id: PydanticObjectId, body: ManualRetryRequestModel, x_exosphere_request_id: str): + try: + logger.info(f"Manual retry state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + state = await State.find_one(State.id == state_id, State.namespace_name == namespace_name) + if not state: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") + + try: + retry_state = State( + node_name=state.node_name, + namespace_name=state.namespace_name, + identifier=state.identifier, + graph_name=state.graph_name, + run_id=state.run_id, + status=StateStatusEnum.CREATED, + inputs=state.inputs, + outputs={}, + error=None, + parents=state.parents, + does_unites=state.does_unites, + fanout_id=body.fanout_id # this will ensure that multiple unwanted retries are not formed because of index in database + ) + retry_state = await retry_state.insert() + logger.info(f"Retry state {retry_state.id} created for state {state_id}", x_exosphere_request_id=x_exosphere_request_id) + + state.status = StateStatusEnum.RETRY_CREATED + await state.save() + + return ManualRetryResponseModel(id=str(retry_state.id), status=retry_state.status) + except DuplicateKeyError: + logger.info(f"Duplicate retry state detected for state {state_id}. A retry state with the same unique key already exists.", x_exosphere_request_id=x_exosphere_request_id) + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="Duplicate retry state detected") + + + except Exception as _: + logger.error(f"Error manual retry state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise diff --git a/state-manager/app/models/manual_retry.py b/state-manager/app/models/manual_retry.py new file mode 100644 index 00000000..0aec686b --- /dev/null +++ b/state-manager/app/models/manual_retry.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, Field +from .state_status_enum import StateStatusEnum + + +class ManualRetryRequestModel(BaseModel): + fanout_id: str = Field(..., description="Fanout ID of the state") + + +class ManualRetryResponseModel(BaseModel): + id: str = Field(..., description="ID of the state") + status: StateStatusEnum = Field(..., description="Status of the state") diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 53f73ff5..e552081f 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -50,6 +50,10 @@ from .models.signal_models import ReEnqueueAfterRequestModel from .controller.re_queue_after_signal import re_queue_after_signal +# manual_retry +from .models.manual_retry import ManualRetryRequestModel, ManualRetryResponseModel +from .controller.manual_retry_state import manual_retry_state + logger = LogsManager().get_logger() @@ -176,6 +180,24 @@ async def re_enqueue_after_state_route(namespace_name: str, state_id: str, body: return await re_queue_after_signal(namespace_name, PydanticObjectId(state_id), body, x_exosphere_request_id) +@router.post( + "/state/{state_id}/manual-retry", + response_model=ManualRetryResponseModel, + status_code=status.HTTP_200_OK, + response_description="State manual retry successfully", + tags=["state"] +) +async def manual_retry_state_route(namespace_name: str, state_id: str, body: ManualRetryRequestModel, request: Request, api_key: str = Depends(check_api_key)): + x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) + + if api_key: + logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + else: + logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return await manual_retry_state(namespace_name, PydanticObjectId(state_id), body, x_exosphere_request_id) + @router.put( "/graph/{graph_name}", diff --git a/state-manager/tests/README.md b/state-manager/tests/README.md index aafe1dff..c2a89da1 100644 --- a/state-manager/tests/README.md +++ b/state-manager/tests/README.md @@ -14,6 +14,7 @@ tests/ │ ├── test_errored_state.py │ ├── test_get_graph_template.py │ ├── test_get_secrets.py +│ ├── test_manual_retry_state.py │ ├── test_register_nodes.py │ └── test_upsert_graph_template.py └── README.md @@ -80,7 +81,21 @@ The unit tests cover all controller functions in the state-manager: - ✅ Complex schema handling - ✅ Database error handling -### 8. `upsert_graph_template.py` +### 8. `manual_retry_state.py` +- ✅ Successful manual retry state creation +- ✅ State not found scenarios +- ✅ Duplicate retry state detection (DuplicateKeyError) +- ✅ Different fanout_id handling +- ✅ Complex inputs and multiple parents preservation +- ✅ Database errors during state lookup +- ✅ Database errors during state save +- ✅ Database errors during retry state insert +- ✅ Empty inputs and parents handling +- ✅ Namespace mismatch scenarios +- ✅ Field preservation and reset logic +- ✅ Logging verification + +### 9. `upsert_graph_template.py` - ✅ Existing template updates - ✅ New template creation - ✅ Empty nodes handling diff --git a/state-manager/tests/unit/controller/test_manual_retry_state.py b/state-manager/tests/unit/controller/test_manual_retry_state.py new file mode 100644 index 00000000..072372e0 --- /dev/null +++ b/state-manager/tests/unit/controller/test_manual_retry_state.py @@ -0,0 +1,518 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from fastapi import HTTPException, status +from beanie import PydanticObjectId +from pymongo.errors import DuplicateKeyError + +from app.controller.manual_retry_state import manual_retry_state +from app.models.manual_retry import ManualRetryRequestModel, ManualRetryResponseModel +from app.models.state_status_enum import StateStatusEnum + + +class TestManualRetryState: + """Test cases for manual_retry_state function""" + + @pytest.fixture + def mock_request_id(self): + return "test-request-id" + + @pytest.fixture + def mock_namespace(self): + return "test_namespace" + + @pytest.fixture + def mock_state_id(self): + return PydanticObjectId() + + @pytest.fixture + def mock_manual_retry_request(self): + return ManualRetryRequestModel( + fanout_id="test-fanout-id-123" + ) + + @pytest.fixture + def mock_original_state(self): + state = MagicMock() + state.id = PydanticObjectId() + state.node_name = "test_node" + state.namespace_name = "test_namespace" + state.identifier = "test_identifier" + state.graph_name = "test_graph" + state.run_id = "test_run_id" + state.status = StateStatusEnum.EXECUTED + state.inputs = {"key": "value"} + state.outputs = {"result": "success"} + state.error = "Original error" + state.parents = {"parent1": PydanticObjectId()} + state.does_unites = False + state.save = AsyncMock() + return state + + @pytest.fixture + def mock_retry_state(self): + retry_state = MagicMock() + retry_state.id = PydanticObjectId() + retry_state.status = StateStatusEnum.CREATED + retry_state.insert = AsyncMock(return_value=retry_state) + return retry_state + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_success( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_original_state, + mock_retry_state, + mock_request_id + ): + """Test successful manual retry state creation""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=mock_original_state) + mock_state_class.return_value = mock_retry_state + + # Act + result = await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + # Assert + assert isinstance(result, ManualRetryResponseModel) + assert result.id == str(mock_retry_state.id) + assert result.status == StateStatusEnum.CREATED + + # Verify State.find_one was called with correct parameters + mock_state_class.find_one.assert_called_once() + call_args = mock_state_class.find_one.call_args[0] + # Check that both conditions were passed + assert len(call_args) == 2 + + # Verify original state was updated to RETRY_CREATED + assert mock_original_state.status == StateStatusEnum.RETRY_CREATED + mock_original_state.save.assert_called_once() + + # Verify retry state was created with correct attributes + mock_state_class.assert_called_once() + retry_state_args = mock_state_class.call_args[1] + assert retry_state_args['node_name'] == mock_original_state.node_name + assert retry_state_args['namespace_name'] == mock_original_state.namespace_name + assert retry_state_args['identifier'] == mock_original_state.identifier + assert retry_state_args['graph_name'] == mock_original_state.graph_name + assert retry_state_args['run_id'] == mock_original_state.run_id + assert retry_state_args['status'] == StateStatusEnum.CREATED + assert retry_state_args['inputs'] == mock_original_state.inputs + assert retry_state_args['outputs'] == {} + assert retry_state_args['error'] is None + assert retry_state_args['parents'] == mock_original_state.parents + assert retry_state_args['does_unites'] == mock_original_state.does_unites + assert retry_state_args['fanout_id'] == mock_manual_retry_request.fanout_id + + # Verify retry state was inserted + mock_retry_state.insert.assert_called_once() + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_not_found( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ): + """Test when original state is not found""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=None) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert exc_info.value.detail == "State not found" + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_duplicate_key_error( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_original_state, + mock_retry_state, + mock_request_id + ): + """Test when duplicate retry state is detected""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=mock_original_state) + mock_retry_state.insert = AsyncMock(side_effect=DuplicateKeyError("Duplicate key")) + mock_state_class.return_value = mock_retry_state + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_409_CONFLICT + assert exc_info.value.detail == "Duplicate retry state detected" + + # Verify original state was not updated since duplicate was detected + mock_original_state.save.assert_not_called() + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_with_different_fanout_id( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_original_state, + mock_retry_state, + mock_request_id + ): + """Test manual retry with different fanout_id""" + # Arrange + different_fanout_request = ManualRetryRequestModel( + fanout_id="different-fanout-id-456" + ) + mock_state_class.find_one = AsyncMock(return_value=mock_original_state) + mock_state_class.return_value = mock_retry_state + + # Act + result = await manual_retry_state( + mock_namespace, + mock_state_id, + different_fanout_request, + mock_request_id + ) + + # Assert + assert isinstance(result, ManualRetryResponseModel) + assert result.id == str(mock_retry_state.id) + assert result.status == StateStatusEnum.CREATED + + # Verify retry state was created with the different fanout_id + retry_state_args = mock_state_class.call_args[1] + assert retry_state_args['fanout_id'] == "different-fanout-id-456" + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_with_complex_inputs_and_parents( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_retry_state, + mock_request_id + ): + """Test manual retry with complex inputs and multiple parents""" + # Arrange + complex_state = MagicMock() + complex_state.id = PydanticObjectId() + complex_state.node_name = "complex_node" + complex_state.namespace_name = "test_namespace" + complex_state.identifier = "complex_identifier" + complex_state.graph_name = "complex_graph" + complex_state.run_id = "complex_run_id" + complex_state.status = StateStatusEnum.ERRORED + complex_state.inputs = { + "nested_data": {"key1": "value1", "key2": [1, 2, 3]}, + "simple_value": "test", + "number": 42 + } + complex_state.outputs = {"previous_result": "some_output"} + complex_state.error = "Complex error message" + complex_state.parents = { + "parent1": PydanticObjectId(), + "parent2": PydanticObjectId(), + "parent3": PydanticObjectId() + } + complex_state.does_unites = True + complex_state.save = AsyncMock() + + mock_state_class.find_one = AsyncMock(return_value=complex_state) + mock_state_class.return_value = mock_retry_state + + # Act + result = await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + # Assert + assert isinstance(result, ManualRetryResponseModel) + + # Verify retry state preserves complex data structures + retry_state_args = mock_state_class.call_args[1] + assert retry_state_args['inputs'] == complex_state.inputs + assert retry_state_args['parents'] == complex_state.parents + assert retry_state_args['does_unites'] == complex_state.does_unites + assert retry_state_args['outputs'] == {} # Should be reset + assert retry_state_args['error'] is None # Should be reset + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_database_error_on_find( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ): + """Test handling of database error during state lookup""" + # Arrange + mock_state_class.find_one = AsyncMock(side_effect=Exception("Database connection error")) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + assert str(exc_info.value) == "Database connection error" + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_database_error_on_save( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_original_state, + mock_retry_state, + mock_request_id + ): + """Test handling of database error during original state save""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=mock_original_state) + mock_state_class.return_value = mock_retry_state + mock_original_state.save = AsyncMock(side_effect=Exception("Save operation failed")) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + assert str(exc_info.value) == "Save operation failed" + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_database_error_on_insert( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_original_state, + mock_retry_state, + mock_request_id + ): + """Test handling of database error during retry state insert (non-duplicate)""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=mock_original_state) + mock_retry_state.insert = AsyncMock(side_effect=Exception("Insert operation failed")) + mock_state_class.return_value = mock_retry_state + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + assert str(exc_info.value) == "Insert operation failed" + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_empty_inputs_and_parents( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_retry_state, + mock_request_id + ): + """Test manual retry with empty inputs and parents""" + # Arrange + empty_state = MagicMock() + empty_state.id = PydanticObjectId() + empty_state.node_name = "empty_node" + empty_state.namespace_name = "test_namespace" + empty_state.identifier = "empty_identifier" + empty_state.graph_name = "empty_graph" + empty_state.run_id = "empty_run_id" + empty_state.status = StateStatusEnum.EXECUTED + empty_state.inputs = {} + empty_state.outputs = {} + empty_state.error = None + empty_state.parents = {} + empty_state.does_unites = False + empty_state.save = AsyncMock() + + mock_state_class.find_one = AsyncMock(return_value=empty_state) + mock_state_class.return_value = mock_retry_state + + # Act + result = await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + # Assert + assert isinstance(result, ManualRetryResponseModel) + + # Verify retry state handles empty collections correctly + retry_state_args = mock_state_class.call_args[1] + assert retry_state_args['inputs'] == {} + assert retry_state_args['parents'] == {} + assert retry_state_args['outputs'] == {} + assert retry_state_args['error'] is None + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_namespace_mismatch( + self, + mock_state_class, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ): + """Test manual retry with namespace that doesn't match any state""" + # Arrange + different_namespace = "different_namespace" + mock_state_class.find_one = AsyncMock(return_value=None) + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await manual_retry_state( + different_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND + assert exc_info.value.detail == "State not found" + + # Verify find_one was called with the different namespace + mock_state_class.find_one.assert_called_once() + + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_preserves_all_original_fields( + self, + mock_state_class, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_retry_state, + mock_request_id + ): + """Test that all relevant fields from original state are preserved in retry state""" + # Arrange + original_state = MagicMock() + original_state.id = PydanticObjectId() + original_state.node_name = "preserve_test_node" + original_state.namespace_name = "preserve_test_namespace" + original_state.identifier = "preserve_test_identifier" + original_state.graph_name = "preserve_test_graph" + original_state.run_id = "preserve_test_run_id" + original_state.status = StateStatusEnum.EXECUTED + original_state.inputs = {"preserve": "input_data"} + original_state.outputs = {"should_be": "reset"} + original_state.error = "should_be_reset" + original_state.parents = {"preserve_parent": PydanticObjectId()} + original_state.does_unites = True + original_state.save = AsyncMock() + + mock_state_class.find_one = AsyncMock(return_value=original_state) + mock_state_class.return_value = mock_retry_state + + # Act + await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + # Assert - verify all fields are correctly set + retry_state_args = mock_state_class.call_args[1] + + # Fields that should be preserved + assert retry_state_args['node_name'] == original_state.node_name + assert retry_state_args['namespace_name'] == original_state.namespace_name + assert retry_state_args['identifier'] == original_state.identifier + assert retry_state_args['graph_name'] == original_state.graph_name + assert retry_state_args['run_id'] == original_state.run_id + assert retry_state_args['inputs'] == original_state.inputs + assert retry_state_args['parents'] == original_state.parents + assert retry_state_args['does_unites'] == original_state.does_unites + assert retry_state_args['fanout_id'] == mock_manual_retry_request.fanout_id + + # Fields that should be reset/set to specific values + assert retry_state_args['status'] == StateStatusEnum.CREATED + assert retry_state_args['outputs'] == {} + assert retry_state_args['error'] is None + + @patch('app.controller.manual_retry_state.logger') + @patch('app.controller.manual_retry_state.State') + async def test_manual_retry_state_logging_calls( + self, + mock_state_class, + mock_logger, + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_original_state, + mock_retry_state, + mock_request_id + ): + """Test that appropriate logging calls are made""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=mock_original_state) + mock_state_class.return_value = mock_retry_state + + # Act + await manual_retry_state( + mock_namespace, + mock_state_id, + mock_manual_retry_request, + mock_request_id + ) + + # Assert - verify logging calls were made + assert mock_logger.info.call_count >= 2 # At least initial log and success log + + # Check that the initial log contains expected information + first_call_args = mock_logger.info.call_args_list[0] + assert str(mock_state_id) in first_call_args[0][0] + assert mock_namespace in first_call_args[0][0] + assert first_call_args[1]['x_exosphere_request_id'] == mock_request_id + + # Check that the success log contains retry state id + second_call_args = mock_logger.info.call_args_list[1] + assert str(mock_retry_state.id) in second_call_args[0][0] + assert str(mock_state_id) in second_call_args[0][0] + assert second_call_args[1]['x_exosphere_request_id'] == mock_request_id diff --git a/state-manager/tests/unit/models/test_manual_retry.py b/state-manager/tests/unit/models/test_manual_retry.py new file mode 100644 index 00000000..5869702c --- /dev/null +++ b/state-manager/tests/unit/models/test_manual_retry.py @@ -0,0 +1,241 @@ +import pytest +from pydantic import ValidationError + +from app.models.manual_retry import ManualRetryRequestModel, ManualRetryResponseModel +from app.models.state_status_enum import StateStatusEnum + + +class TestManualRetryRequestModel: + """Test cases for ManualRetryRequestModel""" + + def test_manual_retry_request_model_valid_data(self): + """Test ManualRetryRequestModel with valid fanout_id""" + # Arrange & Act + fanout_id = "test-fanout-id-123" + model = ManualRetryRequestModel(fanout_id=fanout_id) + + # Assert + assert model.fanout_id == fanout_id + + def test_manual_retry_request_model_empty_fanout_id(self): + """Test ManualRetryRequestModel with empty fanout_id""" + # Arrange & Act + fanout_id = "" + model = ManualRetryRequestModel(fanout_id=fanout_id) + + # Assert + assert model.fanout_id == fanout_id + + def test_manual_retry_request_model_uuid_fanout_id(self): + """Test ManualRetryRequestModel with UUID fanout_id""" + # Arrange & Act + fanout_id = "550e8400-e29b-41d4-a716-446655440000" + model = ManualRetryRequestModel(fanout_id=fanout_id) + + # Assert + assert model.fanout_id == fanout_id + + def test_manual_retry_request_model_long_fanout_id(self): + """Test ManualRetryRequestModel with long fanout_id""" + # Arrange & Act + fanout_id = "a" * 1000 # Very long string + model = ManualRetryRequestModel(fanout_id=fanout_id) + + # Assert + assert model.fanout_id == fanout_id + + def test_manual_retry_request_model_special_characters_fanout_id(self): + """Test ManualRetryRequestModel with special characters in fanout_id""" + # Arrange & Act + fanout_id = "test-fanout@#$%^&*()_+-={}[]|\\:;\"'<>?,./" + model = ManualRetryRequestModel(fanout_id=fanout_id) + + # Assert + assert model.fanout_id == fanout_id + + def test_manual_retry_request_model_missing_fanout_id(self): + """Test ManualRetryRequestModel with missing fanout_id field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ManualRetryRequestModel() # type: ignore + + assert "fanout_id" in str(exc_info.value) + assert "Field required" in str(exc_info.value) + + def test_manual_retry_request_model_none_fanout_id(self): + """Test ManualRetryRequestModel with None fanout_id""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ManualRetryRequestModel(fanout_id=None) # type: ignore + + assert "fanout_id" in str(exc_info.value) + + def test_manual_retry_request_model_numeric_fanout_id(self): + """Test ManualRetryRequestModel with numeric fanout_id (should fail validation)""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ManualRetryRequestModel(fanout_id=12345) # type: ignore + + assert "string_type" in str(exc_info.value) + + def test_manual_retry_request_model_dict_representation(self): + """Test ManualRetryRequestModel dict representation""" + # Arrange & Act + fanout_id = "test-fanout-id" + model = ManualRetryRequestModel(fanout_id=fanout_id) + + # Assert + expected_dict = {"fanout_id": fanout_id} + assert model.model_dump() == expected_dict + + def test_manual_retry_request_model_json_serialization(self): + """Test ManualRetryRequestModel JSON serialization""" + # Arrange & Act + fanout_id = "test-fanout-id" + model = ManualRetryRequestModel(fanout_id=fanout_id) + + # Assert + json_str = model.model_dump_json() + assert f'"fanout_id":"{fanout_id}"' in json_str + + +class TestManualRetryResponseModel: + """Test cases for ManualRetryResponseModel""" + + def test_manual_retry_response_model_valid_data(self): + """Test ManualRetryResponseModel with valid data""" + # Arrange & Act + state_id = "507f1f77bcf86cd799439011" + status = StateStatusEnum.CREATED + model = ManualRetryResponseModel(id=state_id, status=status) + + # Assert + assert model.id == state_id + assert model.status == status + + def test_manual_retry_response_model_all_status_types(self): + """Test ManualRetryResponseModel with all possible status values""" + # Arrange & Act & Assert + state_id = "507f1f77bcf86cd799439011" + + for status in StateStatusEnum: + model = ManualRetryResponseModel(id=state_id, status=status) + assert model.id == state_id + assert model.status == status + + def test_manual_retry_response_model_created_status(self): + """Test ManualRetryResponseModel with CREATED status""" + # Arrange & Act + state_id = "507f1f77bcf86cd799439011" + status = StateStatusEnum.CREATED + model = ManualRetryResponseModel(id=state_id, status=status) + + # Assert + assert model.id == state_id + assert model.status == StateStatusEnum.CREATED + + def test_manual_retry_response_model_retry_created_status(self): + """Test ManualRetryResponseModel with RETRY_CREATED status""" + # Arrange & Act + state_id = "507f1f77bcf86cd799439011" + status = StateStatusEnum.RETRY_CREATED + model = ManualRetryResponseModel(id=state_id, status=status) + + # Assert + assert model.id == state_id + assert model.status == StateStatusEnum.RETRY_CREATED + + def test_manual_retry_response_model_missing_id(self): + """Test ManualRetryResponseModel with missing id field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ManualRetryResponseModel(status=StateStatusEnum.CREATED) # type: ignore + + assert "id" in str(exc_info.value) + assert "Field required" in str(exc_info.value) + + def test_manual_retry_response_model_missing_status(self): + """Test ManualRetryResponseModel with missing status field""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ManualRetryResponseModel(id="507f1f77bcf86cd799439011") # type: ignore + + assert "status" in str(exc_info.value) + assert "Field required" in str(exc_info.value) + + def test_manual_retry_response_model_none_id(self): + """Test ManualRetryResponseModel with None id""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ManualRetryResponseModel(id=None, status=StateStatusEnum.CREATED) # type: ignore + + assert "id" in str(exc_info.value) + + def test_manual_retry_response_model_none_status(self): + """Test ManualRetryResponseModel with None status""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ManualRetryResponseModel(id="507f1f77bcf86cd799439011", status=None) # type: ignore + + assert "status" in str(exc_info.value) + + def test_manual_retry_response_model_invalid_status(self): + """Test ManualRetryResponseModel with invalid status""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ManualRetryResponseModel(id="507f1f77bcf86cd799439011", status="INVALID_STATUS") # type: ignore + + assert "status" in str(exc_info.value) + + def test_manual_retry_response_model_numeric_id(self): + """Test ManualRetryResponseModel with numeric id (should fail validation)""" + # Arrange & Act & Assert + with pytest.raises(ValidationError) as exc_info: + ManualRetryResponseModel(id=12345, status=StateStatusEnum.CREATED) # type: ignore + + assert "string_type" in str(exc_info.value) + + def test_manual_retry_response_model_dict_representation(self): + """Test ManualRetryResponseModel dict representation""" + # Arrange & Act + state_id = "507f1f77bcf86cd799439011" + status = StateStatusEnum.CREATED + model = ManualRetryResponseModel(id=state_id, status=status) + + # Assert + expected_dict = {"id": state_id, "status": status} + assert model.model_dump() == expected_dict + + def test_manual_retry_response_model_json_serialization(self): + """Test ManualRetryResponseModel JSON serialization""" + # Arrange & Act + state_id = "507f1f77bcf86cd799439011" + status = StateStatusEnum.CREATED + model = ManualRetryResponseModel(id=state_id, status=status) + + # Assert + json_str = model.model_dump_json() + assert f'"id":"{state_id}"' in json_str + assert f'"status":"{status.value}"' in json_str + + def test_manual_retry_response_model_empty_id(self): + """Test ManualRetryResponseModel with empty string id""" + # Arrange & Act + state_id = "" + status = StateStatusEnum.CREATED + model = ManualRetryResponseModel(id=state_id, status=status) + + # Assert + assert model.id == state_id + assert model.status == status + + def test_manual_retry_response_model_long_id(self): + """Test ManualRetryResponseModel with very long id""" + # Arrange & Act + state_id = "a" * 1000 # Very long string + status = StateStatusEnum.CREATED + model = ManualRetryResponseModel(id=state_id, status=status) + + # Assert + assert model.id == state_id + assert model.status == status \ No newline at end of file diff --git a/state-manager/tests/unit/test_routes.py b/state-manager/tests/unit/test_routes.py index af50f066..97568887 100644 --- a/state-manager/tests/unit/test_routes.py +++ b/state-manager/tests/unit/test_routes.py @@ -8,6 +8,7 @@ from app.models.secrets_response import SecretsResponseModel from app.models.list_models import ListRegisteredNodesResponse, ListGraphTemplatesResponse from app.models.run_models import RunsResponse, RunListItem, RunStatusEnum +from app.models.manual_retry import ManualRetryRequestModel, ManualRetryResponseModel import pytest @@ -32,6 +33,7 @@ def test_router_has_correct_routes(self): assert any('/v0/namespace/{namespace_name}/state/{state_id}/errored' in path for path in paths) assert any('/v0/namespace/{namespace_name}/state/{state_id}/prune' in path for path in paths) assert any('/v0/namespace/{namespace_name}/state/{state_id}/re-enqueue-after' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/state/{state_id}/manual-retry' in path for path in paths) # Graph template routes (there are two /graph/{graph_name} routes - GET and PUT) assert any('/v0/namespace/{namespace_name}/graph/{graph_name}' in path for path in paths) @@ -89,7 +91,7 @@ def test_trigger_graph_request_model_validation(self): "store": {"s1": "v1"}, "inputs": {"input1": "value1"} } - model = TriggerGraphRequestModel(**valid_data) + model = TriggerGraphRequestModel(**valid_data) # type: ignore assert model.store == {"s1": "v1"} assert model.inputs == {"input1": "value1"} @@ -273,6 +275,26 @@ def test_list_graph_templates_response_validation(self): assert model.namespace == "test" assert model.count == 0 + def test_manual_retry_request_model_validation(self): + """Test ManualRetryRequestModel validation""" + # Test with valid data + valid_data = {"fanout_id": "test-fanout-id-123"} + model = ManualRetryRequestModel(**valid_data) + assert model.fanout_id == "test-fanout-id-123" + + def test_manual_retry_response_model_validation(self): + """Test ManualRetryResponseModel validation""" + from app.models.state_status_enum import StateStatusEnum + + # Test with valid data + valid_data = { + "id": "507f1f77bcf86cd799439011", + "status": StateStatusEnum.CREATED + } + model = ManualRetryResponseModel(**valid_data) + assert model.id == "507f1f77bcf86cd799439011" + assert model.status == StateStatusEnum.CREATED + @@ -295,7 +317,8 @@ def test_route_handlers_exist(self): list_graph_templates_route, get_runs_route, get_graph_structure_route, - get_node_run_details_route + get_node_run_details_route, + manual_retry_state_route ) @@ -313,6 +336,7 @@ def test_route_handlers_exist(self): assert callable(get_runs_route) assert callable(get_graph_structure_route) assert callable(get_node_run_details_route) + assert callable(manual_retry_state_route) @@ -1033,4 +1057,64 @@ async def test_get_node_run_details_route_with_invalid_api_key(self, mock_get_no assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid API key" - mock_get_node_run_details.assert_not_called() \ No newline at end of file + mock_get_node_run_details.assert_not_called() + + @patch('app.routes.manual_retry_state') + async def test_manual_retry_state_route_with_valid_api_key(self, mock_manual_retry_state, mock_request): + """Test manual_retry_state_route with valid API key""" + from app.routes import manual_retry_state_route + + # Arrange + mock_manual_retry_state.return_value = MagicMock() + body = ManualRetryRequestModel(fanout_id="test-fanout-id") + + # Act + result = await manual_retry_state_route("test_namespace", "507f1f77bcf86cd799439011", body, mock_request, "valid_key") + + # Assert + mock_manual_retry_state.assert_called_once() + call_args = mock_manual_retry_state.call_args + assert call_args[0][0] == "test_namespace" # namespace_name + assert str(call_args[0][1]) == "507f1f77bcf86cd799439011" # state_id as PydanticObjectId + assert call_args[0][2] == body # body + assert call_args[0][3] == "test-request-id" # x_exosphere_request_id + assert result == mock_manual_retry_state.return_value + + @patch('app.routes.manual_retry_state') + async def test_manual_retry_state_route_with_invalid_api_key(self, mock_manual_retry_state, mock_request): + """Test manual_retry_state_route with invalid API key""" + from app.routes import manual_retry_state_route + from fastapi import HTTPException + + # Arrange + body = ManualRetryRequestModel(fanout_id="test-fanout-id") + + # Act & Assert + with pytest.raises(HTTPException) as exc_info: + await manual_retry_state_route("test_namespace", "507f1f77bcf86cd799439011", body, mock_request, None) # type: ignore + + assert exc_info.value.status_code == 401 + assert exc_info.value.detail == "Invalid API key" + mock_manual_retry_state.assert_not_called() + + @patch('app.routes.manual_retry_state') + async def test_manual_retry_state_route_without_request_id(self, mock_manual_retry_state, mock_request_no_id): + """Test manual_retry_state_route without x_exosphere_request_id""" + from app.routes import manual_retry_state_route + + # Arrange + mock_manual_retry_state.return_value = MagicMock() + body = ManualRetryRequestModel(fanout_id="test-fanout-id") + + # Act + result = await manual_retry_state_route("test_namespace", "507f1f77bcf86cd799439011", body, mock_request_no_id, "valid_key") + + # Assert + mock_manual_retry_state.assert_called_once() + call_args = mock_manual_retry_state.call_args + assert call_args[0][0] == "test_namespace" # namespace_name + assert str(call_args[0][1]) == "507f1f77bcf86cd799439011" # state_id as PydanticObjectId + assert call_args[0][2] == body # body + # Should generate a UUID when no request ID is present + assert len(call_args[0][3]) > 0 # x_exosphere_request_id should be generated + assert result == mock_manual_retry_state.return_value \ No newline at end of file