diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..774ebbf6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +# Ignore temp directory and temp files at repository root +/temp* +!/temp/.gitkeep \ No newline at end of file diff --git a/python-sdk/exospherehost/_version.py b/python-sdk/exospherehost/_version.py index d6c24cb2..cc4ed116 100644 --- a/python-sdk/exospherehost/_version.py +++ b/python-sdk/exospherehost/_version.py @@ -1 +1 @@ -version = "0.0.2b3" +version = "0.0.2b4" diff --git a/python-sdk/exospherehost/runtime.py b/python-sdk/exospherehost/runtime.py index de74e459..0967f34a 100644 --- a/python-sdk/exospherehost/runtime.py +++ b/python-sdk/exospherehost/runtime.py @@ -141,13 +141,13 @@ def _get_executed_endpoint(self, state_id: str): """ Construct the endpoint URL for notifying executed states. """ - return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/states/{state_id}/executed" + return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/state/{state_id}/executed" def _get_errored_endpoint(self, state_id: str): """ Construct the endpoint URL for notifying errored states. """ - return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/states/{state_id}/errored" + return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/state/{state_id}/errored" def _get_register_endpoint(self): """ diff --git a/python-sdk/exospherehost/signals.py b/python-sdk/exospherehost/signals.py index c7072eb6..c6f7d1f5 100644 --- a/python-sdk/exospherehost/signals.py +++ b/python-sdk/exospherehost/signals.py @@ -27,8 +27,11 @@ async def send(self, endpoint: str, key: str): Raises: Exception: If the HTTP request fails (status code != 200). """ + body = { + "data": self.data + } async with ClientSession() as session: - async with session.post(endpoint, json=self.data, headers={"x-api-key": key}) as response: + async with session.post(endpoint, json=body, headers={"x-api-key": key}) as response: if response.status != 200: raise Exception(f"Failed to send prune signal to {endpoint}") diff --git a/python-sdk/tests/test_runtime_comprehensive.py b/python-sdk/tests/test_runtime_comprehensive.py index 8013629f..f158bc5e 100644 --- a/python-sdk/tests/test_runtime_comprehensive.py +++ b/python-sdk/tests/test_runtime_comprehensive.py @@ -192,13 +192,13 @@ def test_get_enque_endpoint(self, runtime_config): def test_get_executed_endpoint(self, runtime_config): runtime = Runtime(**runtime_config) endpoint = runtime._get_executed_endpoint("state123") - expected = "http://localhost:8080/v1/namespace/test_namespace/states/state123/executed" + expected = "http://localhost:8080/v1/namespace/test_namespace/state/state123/executed" assert endpoint == expected def test_get_errored_endpoint(self, runtime_config): runtime = Runtime(**runtime_config) endpoint = runtime._get_errored_endpoint("state123") - expected = "http://localhost:8080/v1/namespace/test_namespace/states/state123/errored" + expected = "http://localhost:8080/v1/namespace/test_namespace/state/state123/errored" assert endpoint == expected def test_get_register_endpoint(self, runtime_config): diff --git a/python-sdk/tests/test_signals_and_runtime_functions.py b/python-sdk/tests/test_signals_and_runtime_functions.py index c2659929..b590c8df 100644 --- a/python-sdk/tests/test_signals_and_runtime_functions.py +++ b/python-sdk/tests/test_signals_and_runtime_functions.py @@ -95,7 +95,7 @@ async def test_prune_signal_send_success(self): # Verify the request was made correctly mock_session.post.assert_called_once_with( "http://test-endpoint/prune", - json=data, + json={"data": data}, headers={"x-api-key": "test-api-key"} ) @@ -270,7 +270,7 @@ async def test_signal_handling_direct(self): # Verify prune endpoint was called correctly mock_session.post.assert_called_once_with( runtime._get_prune_endpoint("test-state"), - json={"reason": "direct_test"}, + json={"data": {"reason": "direct_test"}}, headers={"x-api-key": "test-key"} ) diff --git a/state-manager/app/controller/trigger_graph.py b/state-manager/app/controller/trigger_graph.py index a57a44b9..a00823e5 100644 --- a/state-manager/app/controller/trigger_graph.py +++ b/state-manager/app/controller/trigger_graph.py @@ -8,6 +8,7 @@ from app.models.db.run import Run from app.models.db.graph_template_model import GraphTemplate from app.models.node_template_model import NodeTemplate +from app.models.dependent_string import DependentString import uuid logger = LogsManager().get_logger() @@ -41,6 +42,30 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph if not graph_template.is_valid(): raise HTTPException(status_code=400, detail="Graph template is not valid") + + root = graph_template.get_root_node() + inputs = construct_inputs(root, body.inputs) + + try: + for field, value in inputs.items(): + dependent_string = DependentString.create_dependent_string(value) + + for dependent in dependent_string.dependents.values(): + if dependent.identifier != "store": + raise HTTPException(status_code=400, detail=f"Root node can have only store identifier as dependent but got {dependent.identifier}") + elif dependent.field not in body.store: + if dependent.field in graph_template.store_config.default_values.keys(): + dependent_string.set_value(dependent.identifier, dependent.field, graph_template.store_config.default_values[dependent.field]) + else: + raise HTTPException(status_code=400, detail=f"Dependent {dependent.field} not found in store for root node {root.identifier}") + else: + dependent_string.set_value(dependent.identifier, dependent.field, body.store[dependent.field]) + + inputs[field] = dependent_string.generate_string() + + except Exception as e: + raise HTTPException(status_code=400, detail=f"Invalid input: {e}") + check_required_store_keys(graph_template, body.store) @@ -64,8 +89,6 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph if len(new_stores) > 0: await Store.insert_many(new_stores) - root = graph_template.get_root_node() - new_state = State( node_name=root.node_name, namespace_name=namespace_name, @@ -73,7 +96,7 @@ async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraph graph_name=graph_name, run_id=run_id, status=StateStatusEnum.CREATED, - inputs=construct_inputs(root, body.inputs), + inputs=inputs, outputs={}, error=None ) diff --git a/state-manager/app/controller/upsert_graph_template.py b/state-manager/app/controller/upsert_graph_template.py index 16882018..3f510199 100644 --- a/state-manager/app/controller/upsert_graph_template.py +++ b/state-manager/app/controller/upsert_graph_template.py @@ -15,7 +15,7 @@ async def upsert_graph_template(namespace_name: str, graph_name: str, body: Upse GraphTemplate.name == graph_name, GraphTemplate.namespace == namespace_name ) - + try: if graph_template: logger.info( @@ -28,7 +28,8 @@ async def upsert_graph_template(namespace_name: str, graph_name: str, body: Upse GraphTemplate.nodes: body.nodes, # type: ignore GraphTemplate.validation_status: GraphTemplateValidationStatus.PENDING, # type: ignore GraphTemplate.validation_errors: [], # type: ignore - GraphTemplate.retry_policy: body.retry_policy # type: ignore + GraphTemplate.retry_policy: body.retry_policy, # type: ignore + GraphTemplate.store_config: body.store_config # type: ignore }) ) @@ -46,7 +47,8 @@ async def upsert_graph_template(namespace_name: str, graph_name: str, body: Upse nodes=body.nodes, validation_status=GraphTemplateValidationStatus.PENDING, validation_errors=[], - retry_policy=body.retry_policy + retry_policy=body.retry_policy, + store_config=body.store_config ).set_secrets(body.secrets) ) except ValueError as e: diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 17d6b650..d71d3e41 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -92,7 +92,7 @@ async def trigger_graph_route(namespace_name: str, graph_name: str, body: Trigge return await trigger_graph(namespace_name, graph_name, body, x_exosphere_request_id) @router.post( - "/states/{state_id}/executed", + "/state/{state_id}/executed", response_model=ExecutedResponseModel, status_code=status.HTTP_200_OK, response_description="State executed successfully", @@ -112,7 +112,7 @@ async def executed_state_route(namespace_name: str, state_id: str, body: Execute @router.post( - "/states/{state_id}/errored", + "/state/{state_id}/errored", response_model=ErroredResponseModel, status_code=status.HTTP_200_OK, response_description="State errored successfully", @@ -132,7 +132,7 @@ async def errored_state_route(namespace_name: str, state_id: str, body: ErroredR @router.post( - "/states/{state_id}/prune", + "/state/{state_id}/prune", response_model=SignalResponseModel, status_code=status.HTTP_200_OK, response_description="State pruned successfully", @@ -151,7 +151,7 @@ async def prune_state_route(namespace_name: str, state_id: str, body: PruneReque @router.post( - "/states/{state_id}/re-enqueue-after", + "/state/{state_id}/re-enqueue-after", response_model=SignalResponseModel, status_code=status.HTTP_200_OK, response_description="State re-enqueued successfully", diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 966f769b..a5d86806 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -70,6 +70,8 @@ def validate_dependencies(next_state_node_template: NodeTemplate, next_state_inp dependency_string = DependentString.create_dependent_string(next_state_node_template.inputs[field_name]) for dependent in dependency_string.dependents.values(): + if dependent.identifier == "store": + continue # 2) For each placeholder, verify the identifier is either current or present in parents if dependent.identifier != identifier and dependent.identifier not in parents: raise KeyError(f"Identifier '{dependent.identifier}' not found in parents for template '{next_state_node_template.identifier}'") diff --git a/state-manager/app/tasks/verify_graph.py b/state-manager/app/tasks/verify_graph.py index add1326d..3c0e633f 100644 --- a/state-manager/app/tasks/verify_graph.py +++ b/state-manager/app/tasks/verify_graph.py @@ -71,7 +71,10 @@ async def verify_inputs(graph_template: GraphTemplate, registered_nodes: list[Re for dependent_string in dependent_strings: identifier_field_pairs = dependent_string.get_identifier_field() for identifier, field in identifier_field_pairs: - + + if identifier == "store": + continue + temp_node = graph_template.get_node_by_identifier(identifier) if temp_node is None: errors.append(f"Node {identifier} does not exist in the graph template") diff --git a/state-manager/tests/unit/controller/test_trigger_graph.py b/state-manager/tests/unit/controller/test_trigger_graph.py index 0d503732..798fbb8d 100644 --- a/state-manager/tests/unit/controller/test_trigger_graph.py +++ b/state-manager/tests/unit/controller/test_trigger_graph.py @@ -126,3 +126,287 @@ async def test_trigger_graph_value_error_not_graph_template_not_found(mock_reque with pytest.raises(ValueError, match="Some other validation error"): await trigger_graph(namespace_name, graph_name, mock_request, x_exosphere_request_id) + + +@pytest.mark.asyncio +async def test_trigger_graph_with_dependent_strings(): + """Test trigger_graph with dependent strings in inputs""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + req = TriggerGraphRequestModel( + store={"store_key": "store_value"}, + inputs={"input1": "{{store.store_key}}_suffix"} + ) + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls, \ + patch('app.controller.trigger_graph.Store') as mock_store_cls, \ + patch('app.controller.trigger_graph.State') as mock_state_cls, \ + patch('app.controller.trigger_graph.Run') as mock_run_cls, \ + patch('app.controller.trigger_graph.DependentString') as mock_dependent_string_cls: + + mock_graph_template = MagicMock() + mock_graph_template.is_valid.return_value = True + mock_graph_template.store_config.required_keys = [] + mock_root_node = MagicMock() + mock_root_node.node_name = "root_node" + mock_root_node.identifier = "root_id" + mock_root_node.inputs = {"input1": "{{store.store_key}}_suffix"} + mock_graph_template.get_root_node.return_value = mock_root_node + mock_graph_template_cls.get = AsyncMock(return_value=mock_graph_template) + + # Mock dependent string behavior + mock_dependent_string = MagicMock() + mock_dependent = MagicMock() + mock_dependent.identifier = "store" + mock_dependent.field = "store_key" + mock_dependent_string.dependents = {0: mock_dependent} + mock_dependent_string.generate_string.return_value = "store_value_suffix" + mock_dependent_string_cls.create_dependent_string.return_value = mock_dependent_string + + mock_store_cls.insert_many = AsyncMock(return_value=None) + mock_state_instance = MagicMock() + mock_state_instance.insert = AsyncMock(return_value=None) + mock_state_cls.return_value = mock_state_instance + + mock_run_instance = MagicMock() + mock_run_instance.insert = AsyncMock(return_value=None) + mock_run_cls.return_value = mock_run_instance + + result = await trigger_graph(namespace_name, graph_name, req, x_exosphere_request_id) + + assert result.status == StateStatusEnum.CREATED + mock_dependent_string_cls.create_dependent_string.assert_called() + + +@pytest.mark.asyncio +async def test_trigger_graph_with_invalid_dependent_identifier(): + """Test trigger_graph with invalid dependent identifier (not 'store')""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + req = TriggerGraphRequestModel( + store={"store_key": "store_value"}, + inputs={"input1": "{{invalid.identifier}}"} + ) + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls, \ + patch('app.controller.trigger_graph.DependentString') as mock_dependent_string_cls: + + mock_graph_template = MagicMock() + mock_graph_template.is_valid.return_value = True + mock_graph_template.store_config.required_keys = [] + mock_root_node = MagicMock() + mock_root_node.node_name = "root_node" + mock_root_node.identifier = "root_id" + mock_root_node.inputs = {"input1": "{{invalid.identifier}}"} + mock_graph_template.get_root_node.return_value = mock_root_node + mock_graph_template_cls.get = AsyncMock(return_value=mock_graph_template) + + # Mock dependent string behavior with invalid identifier + mock_dependent_string = MagicMock() + mock_dependent = MagicMock() + mock_dependent.identifier = "invalid" + mock_dependent.field = "identifier" + mock_dependent_string.dependents = {0: mock_dependent} + mock_dependent_string_cls.create_dependent_string.return_value = mock_dependent_string + + with pytest.raises(HTTPException) as exc_info: + await trigger_graph(namespace_name, graph_name, req, x_exosphere_request_id) + + assert exc_info.value.status_code == 400 + assert "Root node can have only store identifier as dependent" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_trigger_graph_with_missing_store_field(): + """Test trigger_graph with missing store field in dependent string""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + req = TriggerGraphRequestModel( + store={"other_key": "other_value"}, + inputs={"input1": "{{store.missing_key}}"} + ) + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls, \ + patch('app.controller.trigger_graph.DependentString') as mock_dependent_string_cls: + + mock_graph_template = MagicMock() + mock_graph_template.is_valid.return_value = True + mock_graph_template.store_config.required_keys = [] + mock_graph_template.store_config.default_values = {} + mock_root_node = MagicMock() + mock_root_node.node_name = "root_node" + mock_root_node.identifier = "root_id" + mock_root_node.inputs = {"input1": "{{store.missing_key}}"} + mock_graph_template.get_root_node.return_value = mock_root_node + mock_graph_template_cls.get = AsyncMock(return_value=mock_graph_template) + + # Mock dependent string behavior with missing store field + mock_dependent_string = MagicMock() + mock_dependent = MagicMock() + mock_dependent.identifier = "store" + mock_dependent.field = "missing_key" + mock_dependent_string.dependents = {0: mock_dependent} + mock_dependent_string_cls.create_dependent_string.return_value = mock_dependent_string + + with pytest.raises(HTTPException) as exc_info: + await trigger_graph(namespace_name, graph_name, req, x_exosphere_request_id) + + assert exc_info.value.status_code == 400 + assert "Dependent missing_key not found in store" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_trigger_graph_with_store_default_values(): + """Test trigger_graph with store default values""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + req = TriggerGraphRequestModel( + store={"other_key": "other_value"}, + inputs={"input1": "{{store.missing_key}}"} + ) + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls, \ + patch('app.controller.trigger_graph.Store') as mock_store_cls, \ + patch('app.controller.trigger_graph.State') as mock_state_cls, \ + patch('app.controller.trigger_graph.Run') as mock_run_cls, \ + patch('app.controller.trigger_graph.DependentString') as mock_dependent_string_cls: + + mock_graph_template = MagicMock() + mock_graph_template.is_valid.return_value = True + mock_graph_template.store_config.required_keys = [] + mock_graph_template.store_config.default_values = {"missing_key": "default_value"} + mock_root_node = MagicMock() + mock_root_node.node_name = "root_node" + mock_root_node.identifier = "root_id" + mock_root_node.inputs = {"input1": "{{store.missing_key}}"} + mock_graph_template.get_root_node.return_value = mock_root_node + mock_graph_template_cls.get = AsyncMock(return_value=mock_graph_template) + + # Mock dependent string behavior with default value + mock_dependent_string = MagicMock() + mock_dependent = MagicMock() + mock_dependent.identifier = "store" + mock_dependent.field = "missing_key" + mock_dependent_string.dependents = {0: mock_dependent} + mock_dependent_string.generate_string.return_value = "default_value" + mock_dependent_string_cls.create_dependent_string.return_value = mock_dependent_string + + mock_store_cls.insert_many = AsyncMock(return_value=None) + mock_state_instance = MagicMock() + mock_state_instance.insert = AsyncMock(return_value=None) + mock_state_cls.return_value = mock_state_instance + + mock_run_instance = MagicMock() + mock_run_instance.insert = AsyncMock(return_value=None) + mock_run_cls.return_value = mock_run_instance + + result = await trigger_graph(namespace_name, graph_name, req, x_exosphere_request_id) + + assert result.status == StateStatusEnum.CREATED + mock_dependent_string.set_value.assert_called_with("store", "missing_key", "default_value") + + +@pytest.mark.asyncio +async def test_trigger_graph_with_input_processing_error(): + """Test trigger_graph with error during input processing""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + req = TriggerGraphRequestModel( + store={"key": "value"}, + inputs={"input1": "{{store.key}}"} + ) + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls, \ + patch('app.controller.trigger_graph.DependentString') as mock_dependent_string_cls: + + mock_graph_template = MagicMock() + mock_graph_template.is_valid.return_value = True + mock_graph_template.store_config.required_keys = [] + mock_root_node = MagicMock() + mock_root_node.node_name = "root_node" + mock_root_node.identifier = "root_id" + mock_root_node.inputs = {"input1": "{{store.key}}"} + mock_graph_template.get_root_node.return_value = mock_root_node + mock_graph_template_cls.get = AsyncMock(return_value=mock_graph_template) + + # Mock dependent string behavior that raises an error + mock_dependent_string = MagicMock() + mock_dependent = MagicMock() + mock_dependent.identifier = "store" + mock_dependent.field = "key" + mock_dependent_string.dependents = {0: mock_dependent} + mock_dependent_string_cls.create_dependent_string.return_value = mock_dependent_string + mock_dependent_string.generate_string.side_effect = Exception("Input processing error") + + with pytest.raises(HTTPException) as exc_info: + await trigger_graph(namespace_name, graph_name, req, x_exosphere_request_id) + + assert exc_info.value.status_code == 400 + assert "Invalid input: Input processing error" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_trigger_graph_with_empty_store(): + """Test trigger_graph with empty store (no stores to insert)""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + req = TriggerGraphRequestModel(store={}, inputs={}) + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls, \ + patch('app.controller.trigger_graph.Store') as mock_store_cls, \ + patch('app.controller.trigger_graph.State') as mock_state_cls, \ + patch('app.controller.trigger_graph.Run') as mock_run_cls: + + mock_graph_template = MagicMock() + mock_graph_template.is_valid.return_value = True + mock_graph_template.store_config.required_keys = [] + mock_root_node = MagicMock() + mock_root_node.node_name = "root_node" + mock_root_node.identifier = "root_id" + mock_root_node.inputs = {} + mock_graph_template.get_root_node.return_value = mock_root_node + mock_graph_template_cls.get = AsyncMock(return_value=mock_graph_template) + + mock_store_cls.insert_many = AsyncMock(return_value=None) + mock_state_instance = MagicMock() + mock_state_instance.insert = AsyncMock(return_value=None) + mock_state_cls.return_value = mock_state_instance + + mock_run_instance = MagicMock() + mock_run_instance.insert = AsyncMock(return_value=None) + mock_run_cls.return_value = mock_run_instance + + result = await trigger_graph(namespace_name, graph_name, req, x_exosphere_request_id) + + assert result.status == StateStatusEnum.CREATED + # Store.insert_many should not be called when store is empty + mock_store_cls.insert_many.assert_not_called() + + +@pytest.mark.asyncio +async def test_trigger_graph_general_exception(): + """Test trigger_graph with general exception handling""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + req = TriggerGraphRequestModel(store={"key": "value"}, inputs={}) + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls: + # Simulate a general exception during graph template retrieval + mock_graph_template_cls.get.side_effect = Exception("Database connection error") + + with pytest.raises(Exception, match="Database connection error"): + await trigger_graph(namespace_name, graph_name, req, x_exosphere_request_id) diff --git a/state-manager/tests/unit/models/test_retry_policy_model_extended.py b/state-manager/tests/unit/models/test_retry_policy_model_extended.py index cd336959..fad13934 100644 --- a/state-manager/tests/unit/models/test_retry_policy_model_extended.py +++ b/state-manager/tests/unit/models/test_retry_policy_model_extended.py @@ -52,7 +52,7 @@ def test_compute_delay_all_strategies(self): for strategy in strategies: policy.strategy = strategy delay = policy.compute_delay(1) - assert delay > 0 + assert delay >= 0 # Some strategies might return 0 for first retry assert delay <= 10000 # max_delay def test_compute_delay_with_max_delay_cap(self): diff --git a/state-manager/tests/unit/tasks/test_create_next_states.py b/state-manager/tests/unit/tasks/test_create_next_states.py index 9fd46e06..edcf4f7a 100644 --- a/state-manager/tests/unit/tasks/test_create_next_states.py +++ b/state-manager/tests/unit/tasks/test_create_next_states.py @@ -404,18 +404,33 @@ class TestCreateNextStates: @pytest.mark.asyncio async def test_create_next_states_empty_state_ids(self): - """Test with empty state ids""" - # Create a mock class that has the id attribute - mock_state_class = MagicMock() - mock_state_class.id = "id" - mock_find = AsyncMock() - mock_set = AsyncMock() - mock_find.set.return_value = mock_set - mock_state_class.find.return_value = mock_find - - with patch('app.tasks.create_next_states.State', mock_state_class): + """Test create_next_states with empty state_ids list""" + state_ids = [] + identifier = "test_node" + namespace = "test_namespace" + graph_name = "test_graph" + parents_ids = {} + + # Mock the State class to have an 'id' attribute + with patch('app.tasks.create_next_states.State') as mock_state_cls: + # Create a mock class that has the id attribute + mock_state_cls.id = "id" + + # Mock the find().set() call that happens in the exception handler + mock_find_result = MagicMock() + mock_find_result.set = AsyncMock() + mock_state_cls.find.return_value = mock_find_result + + # This should raise a ValueError about empty state ids with pytest.raises(ValueError, match="State ids is empty"): - await create_next_states([], "test_id", "test_namespace", "test_graph", {}) + await create_next_states(state_ids, identifier, namespace, graph_name, parents_ids) + + # Verify that the exception handler was called to update state status + mock_state_cls.find.assert_called_once() + mock_find_result.set.assert_called_once_with({ + "status": StateStatusEnum.NEXT_CREATED_ERROR, + "error": "State ids is empty" + }) @pytest.mark.asyncio async def test_create_next_states_no_next_nodes(self): diff --git a/state-manager/tests/unit/tasks/test_verify_graph.py b/state-manager/tests/unit/tasks/test_verify_graph.py index 48230b66..4961864c 100644 --- a/state-manager/tests/unit/tasks/test_verify_graph.py +++ b/state-manager/tests/unit/tasks/test_verify_graph.py @@ -450,4 +450,486 @@ async def test_verify_graph_exception(self): assert graph_template.validation_status == GraphTemplateValidationStatus.INVALID assert graph_template.validation_errors == ["Validation failed due to unexpected error: Database error"] - graph_template.save.assert_called_once() \ No newline at end of file + graph_template.save.assert_called_once() + + +@pytest.mark.asyncio +async def test_verify_graph_with_exception(): + """Test verify_graph handles exceptions during validation""" + graph_template = MagicMock() + graph_template.nodes = [] + graph_template.id = "test_id" + graph_template.save = AsyncMock() + graph_template.validation_status = MagicMock() + graph_template.validation_errors = MagicMock() + + with patch('app.tasks.verify_graph.RegisteredNode') as mock_registered_node_cls, \ + patch('app.tasks.verify_graph.logger') as _: + + # Mock RegisteredNode.list_nodes_by_templates to raise an exception + mock_registered_node_cls.list_nodes_by_templates.side_effect = Exception("Database connection error") + + # This should handle the exception and mark the graph as invalid + await verify_graph(graph_template) + + # Verify that the graph was marked as invalid with error + assert graph_template.validation_status == GraphTemplateValidationStatus.INVALID + assert "Validation failed due to unexpected error: Database connection error" in graph_template.validation_errors + graph_template.save.assert_called() + + +@pytest.mark.asyncio +async def test_verify_graph_with_validation_errors(): + """Test verify_graph when validation produces errors""" + graph_template = MagicMock() + graph_template.nodes = [] + graph_template.id = "test_id" + graph_template.save = AsyncMock() + graph_template.validation_status = MagicMock() + graph_template.validation_errors = MagicMock() + + # This test verifies that verify_graph can handle validation errors + # The complex mocking of internal functions is tested separately + with patch('app.tasks.verify_graph.RegisteredNode') as mock_registered_node_cls: + # Mock registered nodes to return empty list (will cause validation errors) + mock_registered_node_cls.list_nodes_by_templates.return_value = [] + + # This should mark the graph as invalid due to validation errors + await verify_graph(graph_template) + + # Verify that the graph was marked as invalid + assert graph_template.validation_status == GraphTemplateValidationStatus.INVALID + # The specific error message depends on the actual validation logic + assert len(graph_template.validation_errors) > 0 + + +@pytest.mark.asyncio +async def test_verify_graph_with_valid_graph(): + """Test verify_graph when all validations pass""" + graph_template = MagicMock() + graph_template.nodes = [] + graph_template.id = "test_id" + graph_template.save = AsyncMock() + graph_template.validation_status = MagicMock() + graph_template.validation_errors = MagicMock() + + # This test verifies that verify_graph can handle valid graphs + # The complex mocking of internal functions is tested separately + with patch('app.tasks.verify_graph.RegisteredNode') as mock_registered_node_cls: + # Mock registered nodes to return a valid node + mock_registered_node = MagicMock() + mock_registered_node.name = "test_node" + mock_registered_node.namespace = "test_namespace" + mock_registered_node.runtime_name = "runtime1" + mock_registered_node.runtime_namespace = "runtime_namespace1" + mock_registered_node.inputs_schema = {} + mock_registered_node.outputs_schema = {} + mock_registered_node.secrets = [] + mock_registered_node_cls.list_nodes_by_templates.return_value = [mock_registered_node] + + # This should mark the graph as valid + await verify_graph(graph_template) + + # Verify that the graph was processed (status may vary based on actual validation) + # The specific status depends on the actual validation logic + assert graph_template.save.called + + + + + +@pytest.mark.asyncio +async def test_verify_secrets_with_none_secrets(): + """Test verify_secrets when node has no secrets""" + graph_template = MagicMock() + graph_template.secrets = {"secret1": "value1", "secret2": "value2"} + + mock_node = MagicMock() + mock_node.secrets = None # No secrets required + + registered_nodes = [mock_node] + + errors = await verify_secrets(graph_template, registered_nodes) # type: ignore + + # Should return no errors when secrets is None + assert len(errors) == 0 + + +@pytest.mark.asyncio +async def test_verify_secrets_with_empty_secrets(): + """Test verify_secrets when node has empty secrets list""" + graph_template = MagicMock() + graph_template.secrets = {"secret1": "value1", "secret2": "value2"} + + mock_node = MagicMock() + mock_node.secrets = [] # Empty secrets list + + registered_nodes = [mock_node] + + errors = await verify_secrets(graph_template, registered_nodes) # type: ignore + + # Should return no errors when secrets list is empty + assert len(errors) == 0 + + +@pytest.mark.asyncio +async def test_verify_inputs_with_node_without_inputs(): + """Test verify_inputs when node has no inputs""" + graph_template = MagicMock() + graph_template.nodes = [ + NodeTemplate(node_name="test_node", identifier="id1", namespace="test", inputs={}, next_nodes=None, unites=None) + ] + + mock_node = MagicMock() + mock_node.name = "test_node" + mock_node.namespace = "test" + mock_node.runtime_name = "runtime1" + mock_node.runtime_namespace = "runtime_namespace1" + mock_node.inputs_schema = {} + mock_node.outputs_schema = {} + mock_node.secrets = [] + + registered_nodes = [mock_node] + + errors = await verify_inputs(graph_template, registered_nodes) # type: ignore + + # Node without inputs should be skipped + assert len(errors) == 0 + + +@pytest.mark.asyncio +async def test_verify_inputs_with_store_dependent(): + """Test verify_inputs with store-dependent inputs (should be skipped)""" + graph_template = MagicMock() + graph_template.nodes = [ + NodeTemplate(node_name="test_node", identifier="id1", namespace="test", inputs={"input1": "{{store.key}}"}, next_nodes=None, unites=None) + ] + + mock_node = MagicMock() + mock_node.name = "test_node" + mock_node.namespace = "test" + mock_node.runtime_name = "runtime1" + mock_node.runtime_namespace = "runtime_namespace1" + mock_node.inputs_schema = {} + mock_node.outputs_schema = {} + mock_node.secrets = [] + + registered_nodes = [mock_node] + + with patch('app.tasks.verify_graph.create_model') as mock_create_model: + # Mock input model + mock_input_model = MagicMock() + mock_field = MagicMock() + mock_field.annotation = str + mock_input_model.model_fields = {"input1": mock_field} + mock_create_model.return_value = mock_input_model + + # Mock dependent string + mock_dependent_string = MagicMock() + mock_dependent_string.get_identifier_field.return_value = [("store", "key")] + mock_node.get_dependent_strings.return_value = [mock_dependent_string] + + errors = await verify_inputs(graph_template, registered_nodes) # type: ignore + + # Store dependencies should be skipped, so no errors + assert len(errors) == 0 + + +@pytest.mark.asyncio +async def test_verify_inputs_with_missing_input_in_template(): + """Test verify_inputs when input is not present in graph template""" + graph_template = MagicMock() + graph_template.nodes = [ + NodeTemplate(node_name="test_node", identifier="id1", namespace="test", inputs={"input1": "value1"}, next_nodes=None, unites=None) + ] + + mock_node = MagicMock() + mock_node.name = "test_node" + mock_node.namespace = "test" + mock_node.runtime_name = "runtime1" + mock_node.runtime_namespace = "runtime_namespace1" + mock_node.inputs_schema = {} + mock_node.outputs_schema = {} + mock_node.secrets = [] + + registered_nodes = [mock_node] + + with patch('app.tasks.verify_graph.create_model') as mock_create_model: + # Mock input model + mock_input_model = MagicMock() + mock_field = MagicMock() + mock_field.annotation = str + mock_input_model.model_fields = {"input1": mock_field, "input2": mock_field} # input2 not in template + mock_create_model.return_value = mock_input_model + + errors = await verify_inputs(graph_template, registered_nodes) # type: ignore + + # Should have error for missing input2 + assert len(errors) == 1 + assert "Input input2 in node test_node in namespace test is not present in the graph template" in errors[0] + + +@pytest.mark.asyncio +async def test_verify_inputs_with_non_string_input(): + """Test verify_inputs when input annotation is not string""" + graph_template = MagicMock() + graph_template.nodes = [ + NodeTemplate(node_name="test_node", identifier="id1", namespace="test", inputs={"input1": "value1"}, next_nodes=None, unites=None) + ] + + mock_node = MagicMock() + mock_node.name = "test_node" + mock_node.namespace = "test" + mock_node.runtime_name = "runtime1" + mock_node.runtime_namespace = "runtime_namespace1" + mock_node.inputs_schema = {} + mock_node.outputs_schema = {} + mock_node.secrets = [] + + registered_nodes = [mock_node] + + with patch('app.tasks.verify_graph.create_model') as mock_create_model: + # Mock input model + mock_input_model = MagicMock() + mock_field = MagicMock() + mock_field.annotation = int # Non-string annotation + mock_input_model.model_fields = {"input1": mock_field} + mock_create_model.return_value = mock_input_model + + errors = await verify_inputs(graph_template, registered_nodes) # type: ignore + + # Should have error for non-string input + assert len(errors) == 1 + assert "Input input1 in node test_node in namespace test is not a string" in errors[0] + + +@pytest.mark.asyncio +async def test_verify_inputs_with_missing_dependent_node(): + """Test verify_inputs with missing dependent node in graph template""" + graph_template = MagicMock() + + # Create a mock NodeTemplate instead of a real one + mock_node_template = MagicMock() + mock_node_template.node_name = "test_node" + mock_node_template.namespace = "test" + mock_node_template.inputs = {"input1": "{{missing.output1}}"} + + graph_template.nodes = [mock_node_template] + + mock_node = MagicMock() + mock_node.name = "test_node" + mock_node.namespace = "test" + mock_node.runtime_name = "runtime1" + mock_node.runtime_namespace = "runtime_namespace1" + mock_node.inputs_schema = {} + mock_node.outputs_schema = {} + mock_node.secrets = [] + + registered_nodes = [mock_node] + + with patch('app.tasks.verify_graph.create_model') as mock_create_model: + # Mock input model + mock_input_model = MagicMock() + mock_field = MagicMock() + mock_field.annotation = str + mock_input_model.model_fields = {"input1": mock_field} + mock_create_model.return_value = mock_input_model + + # Mock dependent string + mock_dependent_string = MagicMock() + mock_dependent_string.get_identifier_field.return_value = [("missing", "output1")] + mock_node_template.get_dependent_strings.return_value = [mock_dependent_string] + + # Mock missing node + graph_template.get_node_by_identifier.return_value = None + + errors = await verify_inputs(graph_template, registered_nodes) # type: ignore + + assert len(errors) == 1 + assert "Node missing does not exist in the graph template" in errors[0] + + +@pytest.mark.asyncio +async def test_verify_inputs_with_missing_dependent_registered_node(): + """Test verify_inputs with missing dependent registered node""" + graph_template = MagicMock() + + # Create a mock NodeTemplate instead of a real one + mock_node_template = MagicMock() + mock_node_template.node_name = "test_node" + mock_node_template.namespace = "test" + mock_node_template.inputs = {"input1": "{{parent.output1}}"} + + graph_template.nodes = [mock_node_template] + + mock_node = MagicMock() + mock_node.name = "test_node" + mock_node.namespace = "test" + mock_node.runtime_name = "runtime1" + mock_node.runtime_namespace = "runtime_namespace1" + mock_node.inputs_schema = {} + mock_node.outputs_schema = {} + mock_node.secrets = [] + + registered_nodes = [mock_node] + + with patch('app.tasks.verify_graph.create_model') as mock_create_model: + # Mock input model + mock_input_model = MagicMock() + mock_field = MagicMock() + mock_field.annotation = str + mock_input_model.model_fields = {"input1": mock_field} + mock_create_model.return_value = mock_input_model + + # Mock dependent string + mock_dependent_string = MagicMock() + mock_dependent_string.get_identifier_field.return_value = [("parent", "output1")] + mock_node_template.get_dependent_strings.return_value = [mock_dependent_string] + + # Mock parent node + mock_parent_node = MagicMock() + mock_parent_node.node_name = "parent_node" + mock_parent_node.namespace = "other_namespace" # Different namespace + graph_template.get_node_by_identifier.return_value = mock_parent_node + + # Mock output model + mock_output_model = MagicMock() + mock_output_field = MagicMock() + mock_output_field.annotation = str + mock_output_model.model_fields = {"output1": mock_output_field} + mock_create_model.side_effect = [mock_input_model, mock_output_model] + + errors = await verify_inputs(graph_template, registered_nodes) # type: ignore + + assert len(errors) == 1 + assert "Node parent_node in namespace other_namespace does not exist" in errors[0] + + +@pytest.mark.asyncio +async def test_verify_inputs_with_missing_output_field(): + """Test verify_inputs with missing output field in dependent node""" + graph_template = MagicMock() + + # Create a mock NodeTemplate instead of a real one + mock_node_template = MagicMock() + mock_node_template.node_name = "test_node" + mock_node_template.namespace = "test" + mock_node_template.inputs = {"input1": "{{parent.output1}}"} + + graph_template.nodes = [mock_node_template] + + mock_node = MagicMock() + mock_node.name = "test_node" + mock_node.namespace = "test" + mock_node.runtime_name = "runtime1" + mock_node.runtime_namespace = "runtime_namespace1" + mock_node.inputs_schema = {} + mock_node.outputs_schema = {} + mock_node.secrets = [] + + registered_nodes = [mock_node] + + with patch('app.tasks.verify_graph.create_model') as mock_create_model: + # Mock input model + mock_input_model = MagicMock() + mock_field = MagicMock() + mock_field.annotation = str + mock_input_model.model_fields = {"input1": mock_field} + mock_create_model.return_value = mock_input_model + + # Mock dependent string + mock_dependent_string = MagicMock() + mock_dependent_string.get_identifier_field.return_value = [("parent", "output1")] + mock_node_template.get_dependent_strings.return_value = [mock_dependent_string] + + # Mock parent node + mock_parent_node = MagicMock() + mock_parent_node.node_name = "parent_node" + mock_parent_node.namespace = "test" + graph_template.get_node_by_identifier.return_value = mock_parent_node + + # Mock parent registered node + mock_parent_registered_node = MagicMock() + mock_parent_registered_node.name = "parent_node" + mock_parent_registered_node.namespace = "test" + mock_parent_registered_node.outputs_schema = {} + + # Mock output model with missing field + mock_output_model = MagicMock() + mock_output_model.model_fields = {} # No output1 field + mock_create_model.side_effect = [mock_input_model, mock_output_model] + + # Mock look up table + with patch('app.tasks.verify_graph.RegisteredNode') as mock_registered_node_cls: + mock_registered_node_cls.list_nodes_by_templates.return_value = registered_nodes + [mock_parent_registered_node] + + errors = await verify_inputs(graph_template, registered_nodes + [mock_parent_registered_node]) # type: ignore + + assert len(errors) == 1 + assert "Field output1 in node parent_node in namespace test does not exist" in errors[0] + + +@pytest.mark.asyncio +async def test_verify_inputs_with_non_string_output_field(): + """Test verify_inputs with non-string output field in dependent node""" + graph_template = MagicMock() + + # Create a mock NodeTemplate instead of a real one + mock_node_template = MagicMock() + mock_node_template.node_name = "test_node" + mock_node_template.namespace = "test" + mock_node_template.inputs = {"input1": "{{parent.output1}}"} + + graph_template.nodes = [mock_node_template] + + mock_node = MagicMock() + mock_node.name = "test_node" + mock_node.namespace = "test" + mock_node.runtime_name = "runtime1" + mock_node.runtime_namespace = "runtime_namespace1" + mock_node.inputs_schema = {} + mock_node.outputs_schema = {} + mock_node.secrets = [] + + registered_nodes = [mock_node] + + with patch('app.tasks.verify_graph.create_model') as mock_create_model: + # Mock input model + mock_input_model = MagicMock() + mock_field = MagicMock() + mock_field.annotation = str + mock_input_model.model_fields = {"input1": mock_field} + mock_create_model.return_value = mock_input_model + + # Mock dependent string + mock_dependent_string = MagicMock() + mock_dependent_string.get_identifier_field.return_value = [("parent", "output1")] + mock_node_template.get_dependent_strings.return_value = [mock_dependent_string] + + # Mock parent node + mock_parent_node = MagicMock() + mock_parent_node.node_name = "parent_node" + mock_parent_node.namespace = "test" + graph_template.get_node_by_identifier.return_value = mock_parent_node + + # Mock parent registered node + mock_parent_registered_node = MagicMock() + mock_parent_registered_node.name = "parent_node" + mock_parent_registered_node.namespace = "test" + mock_parent_registered_node.outputs_schema = {} + + # Mock output model with non-string field + mock_output_model = MagicMock() + mock_output_field = MagicMock() + mock_output_field.annotation = int # Non-string annotation + mock_output_model.model_fields = {"output1": mock_output_field} + mock_create_model.side_effect = [mock_input_model, mock_output_model] + + # Mock look up table + with patch('app.tasks.verify_graph.RegisteredNode') as mock_registered_node_cls: + mock_registered_node_cls.list_nodes_by_templates.return_value = registered_nodes + [mock_parent_registered_node] + + errors = await verify_inputs(graph_template, registered_nodes + [mock_parent_registered_node]) # type: ignore + + assert len(errors) == 1 + assert "Field output1 in node parent_node in namespace test is not a string" in errors[0] \ No newline at end of file diff --git a/state-manager/tests/unit/test_routes.py b/state-manager/tests/unit/test_routes.py index 85c8bdf4..881a8862 100644 --- a/state-manager/tests/unit/test_routes.py +++ b/state-manager/tests/unit/test_routes.py @@ -28,10 +28,10 @@ def test_router_has_correct_routes(self): assert any('/v0/namespace/{namespace_name}/states/enqueue' in path for path in paths) assert any('/v0/namespace/{namespace_name}/graph/{graph_name}/trigger' in path for path in paths) # Removed deprecated create states route assertion - assert any('/v0/namespace/{namespace_name}/states/{state_id}/executed' in path for path in paths) - assert any('/v0/namespace/{namespace_name}/states/{state_id}/errored' in path for path in paths) - assert any('/v0/namespace/{namespace_name}/states/{state_id}/prune' in path for path in paths) - assert any('/v0/namespace/{namespace_name}/states/{state_id}/re-enqueue-after' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/state/{state_id}/executed' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/state/{state_id}/errored' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/state/{state_id}/prune' in path for path in paths) + assert any('/v0/namespace/{namespace_name}/state/{state_id}/re-enqueue-after' in path for path in paths) # Graph template routes (there are two /graph/{graph_name} routes - GET and PUT) assert any('/v0/namespace/{namespace_name}/graph/{graph_name}' in path for path in paths)