Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions state-manager/app/controller/executed_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from beanie import PydanticObjectId
from beanie.operators import In
from app.models.executed_models import ExecutedRequestModel, ExecutedResponseModel

from fastapi import HTTPException, status, BackgroundTasks
Expand Down Expand Up @@ -36,11 +37,11 @@ async def executed_state(namespace_name: str, state_id: PydanticObjectId, body:
state.parents = {**state.parents, state.identifier: state.id}

await state.save()

background_tasks.add_task(create_next_state, state)

new_states = []
for output in body.outputs[1:]:
new_state = State(
new_states.append(State(
node_name=state.node_name,
namespace_name=state.namespace_name,
identifier=state.identifier,
Expand All @@ -54,9 +55,20 @@ async def executed_state(namespace_name: str, state_id: PydanticObjectId, body:
**state.parents,
state.identifier: state.id
}
)
await new_state.save()
background_tasks.add_task(create_next_state, new_state)
))

if len(new_states) > 0:
inserted_ids = (await State.insert_many(new_states)).inserted_ids

inserted_states = await State.find(
In(State.id, inserted_ids)
).to_list()

if len(inserted_states) != len(new_states):
raise RuntimeError(f"Failed to insert all new states. Expected {len(new_states)} states, but only {len(inserted_states)} were inserted")

for inserted_state in inserted_states:
background_tasks.add_task(create_next_state, inserted_state)

return ExecutedResponseModel(status=StateStatusEnum.EXECUTED)

Expand Down
32 changes: 31 additions & 1 deletion state-manager/app/singletons/logs_manager.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import structlog
import logging
import os
import sys
from .SingletonDecorator import singleton


Expand Down Expand Up @@ -28,9 +30,37 @@ def __init__(self):
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler)
logger.setLevel(logging.INFO)

# Check if running in development mode
# Development mode is determined by the --mode argument passed to run.py
is_development = self._is_development_mode()

if is_development:
# In development mode, set level to WARNING to disable INFO logs
logger.setLevel(logging.WARNING)
else:
# In production mode, keep INFO level
logger.setLevel(logging.INFO)

self.logger = structlog.get_logger()

def _is_development_mode(self) -> bool:
"""
Check if the application is running in development mode.
Development mode is determined by checking if '--mode' 'development'
is in the command line arguments.
"""
# Check command line arguments for development mode
if '--mode' in sys.argv:
try:
mode_index = sys.argv.index('--mode')
if mode_index + 1 < len(sys.argv) and sys.argv[mode_index + 1] == 'development':
return True
except (ValueError, IndexError):
pass

# Fallback: check environment variable
return os.getenv('MODE', '').lower() == 'development'

def get_logger(self):
return self.logger
1 change: 1 addition & 0 deletions state-manager/app/tasks/create_next_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
logger = LogsManager().get_logger()

async def create_next_state(state: State):
logger.info(f"Creating next state for {state.identifier}")
graph_template = None

if state is None or state.id is None:
Expand Down
4 changes: 2 additions & 2 deletions state-manager/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

def serve():
mode = args.mode
workers = args.workers
if mode == "development":
uvicorn.run("app.main:app", reload=True, host="0.0.0.0", port=8000)
uvicorn.run("app.main:app", workers=workers, reload=True, host="0.0.0.0", port=8000)
elif mode == "production":
workers = args.workers
print(f"Running with {workers} workers")
uvicorn.run("app.main:app", workers=workers, host="0.0.0.0", port=8000)
else:
Expand Down
33 changes: 33 additions & 0 deletions state-manager/tests/unit/controller/test_executed_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ async def test_executed_state_success_multiple_outputs(
# Additional calls in the loop also return query objects with set method
mock_state_class.find_one = AsyncMock(return_value=mock_state)
mock_state.save = AsyncMock()
new_ids = [PydanticObjectId(), PydanticObjectId()]
mock_state_class.insert_many = AsyncMock(return_value=MagicMock(inserted_ids=new_ids))
mock_state_class.find = MagicMock(return_value=AsyncMock(to_list=AsyncMock(return_value=[mock_state, mock_state])))

# Mock State.save() for new states
mock_new_state = MagicMock()
Expand Down Expand Up @@ -264,3 +267,33 @@ async def test_executed_state_database_error(

assert str(exc_info.value) == "Database error"

@patch('app.controller.executed_state.State')
@patch('app.controller.executed_state.create_next_state')
async def test_executed_state_general_exception_handling(
self,
mock_create_next_state,
mock_state_class,
mock_namespace,
mock_state_id,
mock_executed_request,
mock_state,
mock_background_tasks,
mock_request_id
):
"""Test general exception handling in executed_state function"""
# Arrange
mock_state_class.find_one = AsyncMock(return_value=mock_state)
mock_state.save = AsyncMock(side_effect=Exception("Save error"))

# Act & Assert
with pytest.raises(Exception) as exc_info:
await executed_state(
mock_namespace,
mock_state_id,
mock_executed_request,
mock_request_id,
mock_background_tasks
)

assert str(exc_info.value) == "Save error"

132 changes: 132 additions & 0 deletions state-manager/tests/unit/test_create_next_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from beanie import PydanticObjectId

from app.tasks.create_next_state import create_next_state
from app.models.db.state import State
from app.models.db.graph_template_model import GraphTemplate
from app.models.db.registered_node import RegisteredNode
from app.models.graph_template_validation_status import GraphTemplateValidationStatus
from app.models.state_status_enum import StateStatusEnum


class TestCreateNextState:
"""Test cases for create_next_state function"""

@pytest.fixture
def mock_state(self):
"""Create a mock state object"""
state = MagicMock(spec=State)
state.id = PydanticObjectId()
state.identifier = "test_node"
state.namespace_name = "test_namespace"
state.graph_name = "test_graph"
state.run_id = "test_run_id"
state.status = StateStatusEnum.EXECUTED
state.inputs = {"input1": "value1"}
state.outputs = {"output1": "result1"}
state.error = None
state.parents = {"parent_node": PydanticObjectId()}
state.save = AsyncMock()
return state

@pytest.fixture
def mock_graph_template(self):
"""Create a mock graph template"""
template = MagicMock(spec=GraphTemplate)
template.validation_status = GraphTemplateValidationStatus.VALID
template.get_node_by_identifier = MagicMock()
return template

@pytest.fixture
def mock_registered_node(self):
"""Create a mock registered node"""
node = MagicMock(spec=RegisteredNode)
node.inputs_schema = {
"type": "object",
"properties": {
"field1": {"type": "string"},
"field2": {"type": "string"}
}
}
return node

@patch('app.tasks.create_next_state.GraphTemplate')
async def test_create_next_state_none_id(self, mock_graph_template_class):
"""Test create_next_state with state having None id"""
# Arrange
state_with_none_id = MagicMock()
state_with_none_id.id = None

# Act & Assert
with pytest.raises(ValueError, match="State is not valid"):
await create_next_state(state_with_none_id)

@patch('app.tasks.create_next_state.GraphTemplate')
@patch('app.tasks.create_next_state.asyncio.sleep')
async def test_create_next_state_wait_for_validation(
self,
mock_sleep,
mock_graph_template_class,
mock_state,
mock_graph_template
):
"""Test waiting for graph template to become valid"""
# Arrange
# First call returns invalid template, second call returns valid
invalid_template = MagicMock()
invalid_template.validation_status = GraphTemplateValidationStatus.INVALID

mock_graph_template_class.find_one = AsyncMock(side_effect=[invalid_template, mock_graph_template])

# Mock node template with no next nodes
node_template = MagicMock()
node_template.next_nodes = None
mock_graph_template.get_node_by_identifier.return_value = node_template

# Act
await create_next_state(mock_state)

# Assert
assert mock_graph_template_class.find_one.call_count == 2
mock_sleep.assert_called_once_with(1)
assert mock_state.status == StateStatusEnum.SUCCESS

@patch('app.tasks.create_next_state.GraphTemplate')
async def test_create_next_state_no_next_nodes(
self,
mock_graph_template_class,
mock_state,
mock_graph_template
):
"""Test when there are no next nodes"""
# Arrange
mock_graph_template_class.find_one = AsyncMock(return_value=mock_graph_template)

node_template = MagicMock()
node_template.next_nodes = None
mock_graph_template.get_node_by_identifier.return_value = node_template

# Act
await create_next_state(mock_state)

# Assert
assert mock_state.status == StateStatusEnum.SUCCESS

@patch('app.tasks.create_next_state.GraphTemplate')
async def test_create_next_state_general_exception(
self,
mock_graph_template_class,
mock_state
):
"""Test general exception handling"""
# Arrange
mock_graph_template_class.find_one = AsyncMock(side_effect=Exception("General error"))

# Act
await create_next_state(mock_state)

# Assert
assert mock_state.status == StateStatusEnum.ERRORED
assert mock_state.error == "General error"
mock_state.save.assert_called_once()
63 changes: 63 additions & 0 deletions state-manager/tests/unit/test_logs_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
from unittest.mock import patch
from app.singletons.logs_manager import LogsManager


class TestLogsManager:
"""Test cases for LogsManager singleton"""

def setup_method(self):
"""Reset the singleton and logging before each test"""
# Clear the singleton instance
if hasattr(LogsManager, '_instance'):
delattr(LogsManager, '_instance')

# Reset logging level to INFO
logging.getLogger().setLevel(logging.INFO)

def teardown_method(self):
"""Clean up after each test"""
# Clear the singleton instance
if hasattr(LogsManager, '_instance'):
delattr(LogsManager, '_instance')

@patch('app.singletons.logs_manager.sys.argv', ['python', 'run.py', '--mode', 'production'])
def test_logs_manager_production_mode_command_line(self):
"""Test LogsManager sets INFO level in production mode via command line"""
# Check that the logging level is set to INFO in production mode
root_logger = logging.getLogger()
assert root_logger.level == logging.INFO

@patch('app.singletons.logs_manager.sys.argv', ['python', 'run.py', '--mode'])
def test_logs_manager_invalid_command_line_format(self):
"""Test LogsManager handles invalid command line format gracefully"""
# Should default to INFO level when command line format is invalid
root_logger = logging.getLogger()
assert root_logger.level == logging.INFO

@patch('app.singletons.logs_manager.sys.argv', ['python', 'run.py', '--mode', 'invalid'])
def test_logs_manager_invalid_mode_command_line(self):
"""Test LogsManager handles invalid mode in command line"""
# Should default to INFO level when mode is invalid
root_logger = logging.getLogger()
assert root_logger.level == logging.INFO

def test_logs_manager_singleton_pattern(self):
"""Test LogsManager follows singleton pattern"""
logs_manager1 = LogsManager()
logs_manager2 = LogsManager()

# Both instances should be the same object
assert logs_manager1 is logs_manager2

def test_get_logger_returns_structlog_logger(self):
"""Test get_logger returns a structlog logger"""
logs_manager = LogsManager()
logger = logs_manager.get_logger()

# Should return a structlog logger
assert logger is not None
# Check that it's a structlog logger by checking for structlog-specific attributes
assert hasattr(logger, 'info')
assert hasattr(logger, 'error')
assert hasattr(logger, 'warning')
Loading
Loading