diff --git a/state-manager/app/controller/executed_state.py b/state-manager/app/controller/executed_state.py index 7a38bd0d..f255d746 100644 --- a/state-manager/app/controller/executed_state.py +++ b/state-manager/app/controller/executed_state.py @@ -1,4 +1,5 @@ from beanie import PydanticObjectId +from beanie.operators import In from app.models.executed_models import ExecutedRequestModel, ExecutedResponseModel from fastapi import HTTPException, status, BackgroundTasks @@ -36,11 +37,11 @@ async def executed_state(namespace_name: str, state_id: PydanticObjectId, body: state.parents = {**state.parents, state.identifier: state.id} await state.save() - background_tasks.add_task(create_next_state, state) + new_states = [] for output in body.outputs[1:]: - new_state = State( + new_states.append(State( node_name=state.node_name, namespace_name=state.namespace_name, identifier=state.identifier, @@ -54,9 +55,20 @@ async def executed_state(namespace_name: str, state_id: PydanticObjectId, body: **state.parents, state.identifier: state.id } - ) - await new_state.save() - background_tasks.add_task(create_next_state, new_state) + )) + + if len(new_states) > 0: + inserted_ids = (await State.insert_many(new_states)).inserted_ids + + inserted_states = await State.find( + In(State.id, inserted_ids) + ).to_list() + + if len(inserted_states) != len(new_states): + raise RuntimeError(f"Failed to insert all new states. Expected {len(new_states)} states, but only {len(inserted_states)} were inserted") + + for inserted_state in inserted_states: + background_tasks.add_task(create_next_state, inserted_state) return ExecutedResponseModel(status=StateStatusEnum.EXECUTED) diff --git a/state-manager/app/singletons/logs_manager.py b/state-manager/app/singletons/logs_manager.py index 6474cb3b..cc8573cf 100644 --- a/state-manager/app/singletons/logs_manager.py +++ b/state-manager/app/singletons/logs_manager.py @@ -1,5 +1,7 @@ import structlog import logging +import os +import sys from .SingletonDecorator import singleton @@ -28,9 +30,37 @@ def __init__(self): handler.setFormatter(formatter) logger = logging.getLogger() logger.addHandler(handler) - logger.setLevel(logging.INFO) + + # Check if running in development mode + # Development mode is determined by the --mode argument passed to run.py + is_development = self._is_development_mode() + + if is_development: + # In development mode, set level to WARNING to disable INFO logs + logger.setLevel(logging.WARNING) + else: + # In production mode, keep INFO level + logger.setLevel(logging.INFO) self.logger = structlog.get_logger() + def _is_development_mode(self) -> bool: + """ + Check if the application is running in development mode. + Development mode is determined by checking if '--mode' 'development' + is in the command line arguments. + """ + # Check command line arguments for development mode + if '--mode' in sys.argv: + try: + mode_index = sys.argv.index('--mode') + if mode_index + 1 < len(sys.argv) and sys.argv[mode_index + 1] == 'development': + return True + except (ValueError, IndexError): + pass + + # Fallback: check environment variable + return os.getenv('MODE', '').lower() == 'development' + def get_logger(self): return self.logger diff --git a/state-manager/app/tasks/create_next_state.py b/state-manager/app/tasks/create_next_state.py index e7d67238..bf03679e 100644 --- a/state-manager/app/tasks/create_next_state.py +++ b/state-manager/app/tasks/create_next_state.py @@ -14,6 +14,7 @@ logger = LogsManager().get_logger() async def create_next_state(state: State): + logger.info(f"Creating next state for {state.identifier}") graph_template = None if state is None or state.id is None: diff --git a/state-manager/run.py b/state-manager/run.py index edcd515d..c3b0ccb7 100644 --- a/state-manager/run.py +++ b/state-manager/run.py @@ -12,10 +12,10 @@ def serve(): mode = args.mode + workers = args.workers if mode == "development": - uvicorn.run("app.main:app", reload=True, host="0.0.0.0", port=8000) + uvicorn.run("app.main:app", workers=workers, reload=True, host="0.0.0.0", port=8000) elif mode == "production": - workers = args.workers print(f"Running with {workers} workers") uvicorn.run("app.main:app", workers=workers, host="0.0.0.0", port=8000) else: diff --git a/state-manager/tests/unit/controller/test_executed_state.py b/state-manager/tests/unit/controller/test_executed_state.py index be7fbb82..14936d78 100644 --- a/state-manager/tests/unit/controller/test_executed_state.py +++ b/state-manager/tests/unit/controller/test_executed_state.py @@ -118,6 +118,9 @@ async def test_executed_state_success_multiple_outputs( # Additional calls in the loop also return query objects with set method mock_state_class.find_one = AsyncMock(return_value=mock_state) mock_state.save = AsyncMock() + new_ids = [PydanticObjectId(), PydanticObjectId()] + mock_state_class.insert_many = AsyncMock(return_value=MagicMock(inserted_ids=new_ids)) + mock_state_class.find = MagicMock(return_value=AsyncMock(to_list=AsyncMock(return_value=[mock_state, mock_state]))) # Mock State.save() for new states mock_new_state = MagicMock() @@ -264,3 +267,33 @@ async def test_executed_state_database_error( assert str(exc_info.value) == "Database error" + @patch('app.controller.executed_state.State') + @patch('app.controller.executed_state.create_next_state') + async def test_executed_state_general_exception_handling( + self, + mock_create_next_state, + mock_state_class, + mock_namespace, + mock_state_id, + mock_executed_request, + mock_state, + mock_background_tasks, + mock_request_id + ): + """Test general exception handling in executed_state function""" + # Arrange + mock_state_class.find_one = AsyncMock(return_value=mock_state) + mock_state.save = AsyncMock(side_effect=Exception("Save error")) + + # Act & Assert + with pytest.raises(Exception) as exc_info: + await executed_state( + mock_namespace, + mock_state_id, + mock_executed_request, + mock_request_id, + mock_background_tasks + ) + + assert str(exc_info.value) == "Save error" + diff --git a/state-manager/tests/unit/test_create_next_state.py b/state-manager/tests/unit/test_create_next_state.py new file mode 100644 index 00000000..65bc17b1 --- /dev/null +++ b/state-manager/tests/unit/test_create_next_state.py @@ -0,0 +1,132 @@ +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from beanie import PydanticObjectId + +from app.tasks.create_next_state import create_next_state +from app.models.db.state import State +from app.models.db.graph_template_model import GraphTemplate +from app.models.db.registered_node import RegisteredNode +from app.models.graph_template_validation_status import GraphTemplateValidationStatus +from app.models.state_status_enum import StateStatusEnum + + +class TestCreateNextState: + """Test cases for create_next_state function""" + + @pytest.fixture + def mock_state(self): + """Create a mock state object""" + state = MagicMock(spec=State) + state.id = PydanticObjectId() + state.identifier = "test_node" + state.namespace_name = "test_namespace" + state.graph_name = "test_graph" + state.run_id = "test_run_id" + state.status = StateStatusEnum.EXECUTED + state.inputs = {"input1": "value1"} + state.outputs = {"output1": "result1"} + state.error = None + state.parents = {"parent_node": PydanticObjectId()} + state.save = AsyncMock() + return state + + @pytest.fixture + def mock_graph_template(self): + """Create a mock graph template""" + template = MagicMock(spec=GraphTemplate) + template.validation_status = GraphTemplateValidationStatus.VALID + template.get_node_by_identifier = MagicMock() + return template + + @pytest.fixture + def mock_registered_node(self): + """Create a mock registered node""" + node = MagicMock(spec=RegisteredNode) + node.inputs_schema = { + "type": "object", + "properties": { + "field1": {"type": "string"}, + "field2": {"type": "string"} + } + } + return node + + @patch('app.tasks.create_next_state.GraphTemplate') + async def test_create_next_state_none_id(self, mock_graph_template_class): + """Test create_next_state with state having None id""" + # Arrange + state_with_none_id = MagicMock() + state_with_none_id.id = None + + # Act & Assert + with pytest.raises(ValueError, match="State is not valid"): + await create_next_state(state_with_none_id) + + @patch('app.tasks.create_next_state.GraphTemplate') + @patch('app.tasks.create_next_state.asyncio.sleep') + async def test_create_next_state_wait_for_validation( + self, + mock_sleep, + mock_graph_template_class, + mock_state, + mock_graph_template + ): + """Test waiting for graph template to become valid""" + # Arrange + # First call returns invalid template, second call returns valid + invalid_template = MagicMock() + invalid_template.validation_status = GraphTemplateValidationStatus.INVALID + + mock_graph_template_class.find_one = AsyncMock(side_effect=[invalid_template, mock_graph_template]) + + # Mock node template with no next nodes + node_template = MagicMock() + node_template.next_nodes = None + mock_graph_template.get_node_by_identifier.return_value = node_template + + # Act + await create_next_state(mock_state) + + # Assert + assert mock_graph_template_class.find_one.call_count == 2 + mock_sleep.assert_called_once_with(1) + assert mock_state.status == StateStatusEnum.SUCCESS + + @patch('app.tasks.create_next_state.GraphTemplate') + async def test_create_next_state_no_next_nodes( + self, + mock_graph_template_class, + mock_state, + mock_graph_template + ): + """Test when there are no next nodes""" + # Arrange + mock_graph_template_class.find_one = AsyncMock(return_value=mock_graph_template) + + node_template = MagicMock() + node_template.next_nodes = None + mock_graph_template.get_node_by_identifier.return_value = node_template + + # Act + await create_next_state(mock_state) + + # Assert + assert mock_state.status == StateStatusEnum.SUCCESS + + @patch('app.tasks.create_next_state.GraphTemplate') + async def test_create_next_state_general_exception( + self, + mock_graph_template_class, + mock_state + ): + """Test general exception handling""" + # Arrange + mock_graph_template_class.find_one = AsyncMock(side_effect=Exception("General error")) + + # Act + await create_next_state(mock_state) + + # Assert + assert mock_state.status == StateStatusEnum.ERRORED + assert mock_state.error == "General error" + mock_state.save.assert_called_once() \ No newline at end of file diff --git a/state-manager/tests/unit/test_logs_manager.py b/state-manager/tests/unit/test_logs_manager.py new file mode 100644 index 00000000..ccc49b98 --- /dev/null +++ b/state-manager/tests/unit/test_logs_manager.py @@ -0,0 +1,63 @@ +import logging +from unittest.mock import patch +from app.singletons.logs_manager import LogsManager + + +class TestLogsManager: + """Test cases for LogsManager singleton""" + + def setup_method(self): + """Reset the singleton and logging before each test""" + # Clear the singleton instance + if hasattr(LogsManager, '_instance'): + delattr(LogsManager, '_instance') + + # Reset logging level to INFO + logging.getLogger().setLevel(logging.INFO) + + def teardown_method(self): + """Clean up after each test""" + # Clear the singleton instance + if hasattr(LogsManager, '_instance'): + delattr(LogsManager, '_instance') + + @patch('app.singletons.logs_manager.sys.argv', ['python', 'run.py', '--mode', 'production']) + def test_logs_manager_production_mode_command_line(self): + """Test LogsManager sets INFO level in production mode via command line""" + # Check that the logging level is set to INFO in production mode + root_logger = logging.getLogger() + assert root_logger.level == logging.INFO + + @patch('app.singletons.logs_manager.sys.argv', ['python', 'run.py', '--mode']) + def test_logs_manager_invalid_command_line_format(self): + """Test LogsManager handles invalid command line format gracefully""" + # Should default to INFO level when command line format is invalid + root_logger = logging.getLogger() + assert root_logger.level == logging.INFO + + @patch('app.singletons.logs_manager.sys.argv', ['python', 'run.py', '--mode', 'invalid']) + def test_logs_manager_invalid_mode_command_line(self): + """Test LogsManager handles invalid mode in command line""" + # Should default to INFO level when mode is invalid + root_logger = logging.getLogger() + assert root_logger.level == logging.INFO + + def test_logs_manager_singleton_pattern(self): + """Test LogsManager follows singleton pattern""" + logs_manager1 = LogsManager() + logs_manager2 = LogsManager() + + # Both instances should be the same object + assert logs_manager1 is logs_manager2 + + def test_get_logger_returns_structlog_logger(self): + """Test get_logger returns a structlog logger""" + logs_manager = LogsManager() + logger = logs_manager.get_logger() + + # Should return a structlog logger + assert logger is not None + # Check that it's a structlog logger by checking for structlog-specific attributes + assert hasattr(logger, 'info') + assert hasattr(logger, 'error') + assert hasattr(logger, 'warning') diff --git a/state-manager/tests/unit/test_logs_manager_simple.py b/state-manager/tests/unit/test_logs_manager_simple.py new file mode 100644 index 00000000..9c22eea4 --- /dev/null +++ b/state-manager/tests/unit/test_logs_manager_simple.py @@ -0,0 +1,101 @@ +import logging +from unittest.mock import patch +from app.singletons.logs_manager import LogsManager + + +class TestLogsManagerSimple: + """Simplified test cases for LogsManager singleton""" + + def setup_method(self): + """Reset logging before each test""" + # Reset logging level to INFO + logging.getLogger().setLevel(logging.INFO) + + def test_logs_manager_singleton_pattern(self): + """Test LogsManager follows singleton pattern""" + logs_manager1 = LogsManager() + logs_manager2 = LogsManager() + + # Both instances should be the same object + assert logs_manager1 is logs_manager2 + + def test_get_logger_returns_structlog_logger(self): + """Test get_logger returns a structlog logger""" + logs_manager = LogsManager() + logger = logs_manager.get_logger() + + # Should return a structlog logger + assert logger is not None + # Check that it's a structlog logger by checking for structlog-specific attributes + assert hasattr(logger, 'info') + assert hasattr(logger, 'error') + assert hasattr(logger, 'warning') + + def test_is_development_mode_command_line_development(self): + """Test _is_development_mode with development command line argument""" + with patch('sys.argv', ['python', 'run.py', '--mode', 'development']): + logs_manager = LogsManager() + # Access the private method through the instance + result = logs_manager._is_development_mode() + assert result is True + + def test_is_development_mode_command_line_production(self): + """Test _is_development_mode with production command line argument""" + with patch('sys.argv', ['python', 'run.py', '--mode', 'production']): + logs_manager = LogsManager() + result = logs_manager._is_development_mode() + assert result is False + + def test_is_development_mode_env_var_development(self): + """Test _is_development_mode with development environment variable""" + with patch('sys.argv', ['python', 'run.py']): + with patch('os.getenv', return_value='development'): + logs_manager = LogsManager() + result = logs_manager._is_development_mode() + assert result is True + + def test_is_development_mode_env_var_production(self): + """Test _is_development_mode with production environment variable""" + with patch('sys.argv', ['python', 'run.py']): + with patch('os.getenv', return_value='production'): + logs_manager = LogsManager() + result = logs_manager._is_development_mode() + assert result is False + + def test_is_development_mode_env_var_case_insensitive(self): + """Test _is_development_mode with case insensitive environment variable""" + with patch('sys.argv', ['python', 'run.py']): + with patch('os.getenv', return_value='DEVELOPMENT'): + logs_manager = LogsManager() + result = logs_manager._is_development_mode() + assert result is True + + def test_is_development_mode_env_var_empty(self): + """Test _is_development_mode with empty environment variable""" + with patch('sys.argv', ['python', 'run.py']): + with patch('os.getenv', return_value=''): + logs_manager = LogsManager() + result = logs_manager._is_development_mode() + assert result is False + + def test_is_development_mode_invalid_command_line_format(self): + """Test _is_development_mode with invalid command line format""" + with patch('sys.argv', ['python', 'run.py', '--mode']): + logs_manager = LogsManager() + result = logs_manager._is_development_mode() + assert result is False + + def test_is_development_mode_invalid_mode(self): + """Test _is_development_mode with invalid mode""" + with patch('sys.argv', ['python', 'run.py', '--mode', 'invalid']): + logs_manager = LogsManager() + result = logs_manager._is_development_mode() + assert result is False + + def test_is_development_mode_no_mode_arg(self): + """Test _is_development_mode with no mode argument""" + with patch('sys.argv', ['python', 'run.py']): + with patch('os.getenv', return_value=''): + logs_manager = LogsManager() + result = logs_manager._is_development_mode() + assert result is False \ No newline at end of file