diff --git a/docs/docs/exosphere/create-graph.md b/docs/docs/exosphere/create-graph.md index 178debad..502c4db7 100644 --- a/docs/docs/exosphere/create-graph.md +++ b/docs/docs/exosphere/create-graph.md @@ -176,7 +176,7 @@ async def create_graph_template(): } try: - # Create or update the graph template + # Create or update the graph template (with optional store, beta) result = await state_manager.upsert_graph( graph_name="my-workflow", graph_nodes=graph_nodes, @@ -186,6 +186,9 @@ async def create_graph_template(): "strategy": "EXPONENTIAL", "backoff_factor": 2000, "exponent": 2 + }, + store_config={ # beta + "ttl": 7200 # seconds to keep key/values } ) print("Graph template created successfully!") diff --git a/docs/docs/exosphere/trigger-graph.md b/docs/docs/exosphere/trigger-graph.md index cc217ec0..4ff2cb83 100644 --- a/docs/docs/exosphere/trigger-graph.md +++ b/docs/docs/exosphere/trigger-graph.md @@ -16,20 +16,14 @@ The recommended way to trigger graphs is using the Exosphere Python SDK, which p state_manager_uri=EXOSPHERE_STATE_MANAGER_URI, key=EXOSPHERE_API_KEY ) - - # Create trigger state - trigger_state = TriggerState( - identifier="data_loader", # Must match a node identifier in your graph - inputs={ - "source": "/path/to/data.csv", - "format": "csv", - "batch_size": "1000" - } - ) - + try: - # Trigger the graph - result = await state_manager.trigger("my-graph", state=trigger_state) + # Trigger the graph with optional store (beta) + result = await state_manager.trigger( + "my-graph", + inputs={"user_id": "123"}, + store={"cursor": "0"} # persisted across nodes (beta) + ) print(f"Graph triggered successfully!") print(f"Run ID: {result['run_id']}") return result diff --git a/python-sdk/README.md b/python-sdk/README.md index 7df6c3c2..40b6e23b 100644 --- a/python-sdk/README.md +++ b/python-sdk/README.md @@ -77,6 +77,7 @@ export EXOSPHERE_API_KEY="your-api-key" - **Async Support**: Native async/await support for high-performance operations - **Error Handling**: Built-in retry mechanisms and error recovery - **Scalability**: Designed for high-volume batch processing and workflows +- **Graph Store (beta)**: Strings-only key-value store with per-run scope for sharing data across nodes (not durable across separate runs or clusters) ## Architecture @@ -241,95 +242,29 @@ trigger_state = TriggerState( } ) -# Trigger a single state -result = await state_manager.trigger("my-graph", state=trigger_state) - -# Or trigger multiple states -trigger_states = [ - TriggerState(identifier="trigger1", inputs={"key1": "value1"}), - TriggerState(identifier="trigger2", inputs={"key2": "value2"}) -] - -result = await state_manager.trigger("my-graph", states=trigger_states) -``` - -**Parameters:** -- `graph_name` (str): The name of the graph to trigger -- `state` (TriggerState, optional): A single trigger state -- `states` (list[TriggerState], optional): A list of trigger states - -**Returns:** -- `dict`: The JSON response from the state manager API - -**Raises:** -- `ValueError`: If neither `state` nor `states` is provided, if both are provided, or if `states` is an empty list -- `Exception`: If the API request fails with a non-200 status code - -### TriggerState Class - -The `TriggerState` class represents a trigger state for graph execution. It contains an identifier and a set of input parameters that will be passed to the graph when it is triggered. - -#### Creating Trigger States - -```python -from exospherehost import TriggerState - -# Basic trigger state -trigger_state = TriggerState( - identifier="data-processing", - inputs={ - "file_path": "/path/to/data.csv", - "batch_size": "1000", - "priority": "high" - } -) - -# Trigger state with complex data (serialized as JSON) -import json - -complex_data = { - "filters": ["active", "verified"], - "date_range": {"start": "2024-01-01", "end": "2024-01-31"}, - "options": {"include_metadata": True, "format": "json"} -} - -trigger_state = TriggerState( - identifier="complex-processing", +# Trigger the graph (beta store support) +result = await state_manager.trigger( + "my-graph", inputs={ - "config": json.dumps(complex_data), - "user_id": "12345" + "user_id": "12345", + "session_token": "abc123def456" + }, + store={ + "cursor": "0" # persisted across nodes (beta) } ) ``` -**Attributes:** -- `identifier` (str): A unique identifier for this trigger state. Used to distinguish between different trigger states and may be used by the graph to determine how to process the trigger -- `inputs` (dict[str, str]): A dictionary of input parameters that will be passed to the graph. The keys are parameter names and values are parameter values, both as strings - -## Integration with ExosphereHost Platform - -The Python SDK integrates seamlessly with the ExosphereHost platform, providing: - -- **Performance**: Optimized execution with intelligent resource allocation and parallel processing -- **Reliability**: Built-in fault tolerance, automatic recovery, and failover capabilities -- **Scalability**: Automatic scaling based on workload demands -- **Monitoring**: Integrated logging and monitoring capabilities - -## Documentation - -For more detailed information, visit our [documentation](https://docs.exosphere.host). - -## Contributing +**Parameters:** -We welcome contributions! Please see our [contributing guidelines](https://github.com/exospherehost/exospherehost/blob/main/CONTRIBUTING.md) for details. +- `graph_name` (str): Name of the graph to execute +- `inputs` (dict[str, str] | None): Key/value inputs for the first node (strings only) +- `store` (dict[str, str] | None): Graph-level key/value store (beta) persisted across nodes -## Support +**Returns:** -For support and questions: -- **Email**: [nivedit@exosphere.host](mailto:nivedit@exosphere.host) -- **Documentation**: [https://docs.exosphere.host](https://docs.exosphere.host) -- **GitHub Issues**: [https://github.com/exospherehost/exospherehost/issues](https://github.com/exospherehost/exospherehost/issues) +- `dict`: JSON payload from the state manager -## License +**Raises:** -This Python SDK is licensed under the MIT License. The main ExosphereHost project is licensed under the Elastic License 2.0. \ No newline at end of file +- `Exception`: If the HTTP request fails \ No newline at end of file diff --git a/python-sdk/exospherehost/_version.py b/python-sdk/exospherehost/_version.py index 9a836ee0..d6c24cb2 100644 --- a/python-sdk/exospherehost/_version.py +++ b/python-sdk/exospherehost/_version.py @@ -1 +1 @@ -version = "0.0.2b2" +version = "0.0.2b3" diff --git a/python-sdk/exospherehost/statemanager.py b/python-sdk/exospherehost/statemanager.py index c6de7eb1..7940831c 100644 --- a/python-sdk/exospherehost/statemanager.py +++ b/python-sdk/exospherehost/statemanager.py @@ -67,60 +67,49 @@ def _get_upsert_graph_endpoint(self, graph_name: str): def _get_get_graph_endpoint(self, graph_name: str): return f"{self._state_manager_uri}/{self._state_manager_version}/namespace/{self._namespace}/graph/{graph_name}" - async def trigger(self, graph_name: str, state: TriggerState | None = None, states: list[TriggerState] | None = None): + async def trigger(self, graph_name: str, inputs: dict[str, str] | None = None, store: dict[str, str] | None = None): """ - Trigger a graph execution with one or more trigger states. + Trigger execution of a graph. - This method sends trigger states to the specified graph endpoint to initiate - graph execution. It accepts either a single trigger state or a list of trigger - states, but not both simultaneously. + Beta: This method now supports an optional **store** parameter that lets you + pass a key-value map that is persisted for the lifetime of the graph run. All + keys **and** values must be strings in the current beta release – the schema + may change in future versions. Args: - graph_name (str): The name of the graph to trigger execution for. - state (TriggerState | None, optional): A single trigger state to send. - Must be provided if `states` is None. - states (list[TriggerState] | None, optional): A list of trigger states to send. - Must be provided if `state` is None. Cannot be an empty list. + graph_name (str): Name of the graph you want to run. + inputs (dict[str, str] | None): Optional inputs for the first node in the + graph. Strings only. + store (dict[str, str] | None): Optional key-value store that will be merged + into the graph-level store before execution (beta). Returns: - dict: The JSON response from the state manager API containing the - result of the trigger operation. + dict: JSON payload returned by the state-manager API. Raises: - ValueError: If neither `state` nor `states` is provided, if both are provided, - or if `states` is an empty list. - Exception: If the API request fails with a non-200 status code. The exception - message includes the HTTP status code and response text for debugging. + Exception: If the request fails. Example: ```python - # Trigger with a single state - state = TriggerState(identifier="my-trigger", inputs={"key": "value"}) - result = await state_manager.trigger("my-graph", state=state) - - # Trigger with multiple states - states = [ - TriggerState(identifier="trigger1", inputs={"key1": "value1"}), - TriggerState(identifier="trigger2", inputs={"key2": "value2"}) - ] - result = await state_manager.trigger("my-graph", states=states) + # Trigger with inputs only + await state_manager.trigger("my-graph", inputs={"user_id": "123"}) + + # Trigger with inputs **and** a beta store + await state_manager.trigger( + "my-graph", + inputs={"user_id": "123"}, + store={"cursor": "0"} # beta + ) ``` """ - if state is None and states is None: - raise ValueError("Either state or states must be provided") - if state is not None and states is not None: - raise ValueError("Only one of state or states must be provided") - if states is not None and len(states) == 0: - raise ValueError("States must be a non-empty list") + if inputs is None: + inputs = {} + if store is None: + store = {} - states_list = [] - if state is not None: - states_list.append(state) - if states is not None: - states_list.extend(states) - body = { - "states": [state.model_dump() for state in states_list] + "inputs": inputs, + "store": store } headers = { "x-api-key": self._key @@ -167,35 +156,32 @@ async def get_graph(self, graph_name: str): raise Exception(f"Failed to get graph: {response.status} {await response.text()}") return await response.json() - async def upsert_graph(self, graph_name: str, graph_nodes: list[dict[str, Any]], secrets: dict[str, str], validation_timeout: int = 60, polling_interval: int = 1): + async def upsert_graph(self, graph_name: str, graph_nodes: list[dict[str, Any]], secrets: dict[str, str], retry_policy: dict[str, Any] | None = None, store_config: dict[str, Any] | None = None, validation_timeout: int = 60, polling_interval: int = 1): """ - Create or update a graph in the state manager with validation. + Create or update a graph definition. + + Beta: `store_config` is a new field that allows you to configure a + namespaced key-value store that lives for the duration of a graph run. The + feature is in beta and the shape of `store_config` may change. - This method sends a graph definition to the state manager API for creation - or update. After submission, it polls the API to wait for graph validation - to complete, ensuring the graph is properly configured before returning. + After submitting the graph, this helper polls the state-manager until the + graph has been validated (or the timeout is hit). Args: - graph_name (str): The name of the graph to create or update. - graph_nodes (list[dict[str, Any]]): A list of node definitions that make up - the graph. Each node should contain the necessary configuration for - the graph execution engine. - secrets (dict[str, str]): A dictionary of secret values that will be - available to the graph during execution. Keys are secret names and - values are the secret values. - validation_timeout (int, optional): Maximum time in seconds to wait for - graph validation to complete. Defaults to 60. - polling_interval (int, optional): Time in seconds between validation - status checks. Defaults to 1. - + graph_name (str): Graph identifier. + graph_nodes (list[dict[str, Any]]): Graph node list. + secrets (dict[str, str]): Secrets available to all nodes. + retry_policy (dict[str, Any] | None): Optional per-node retry policy. + store_config (dict[str, Any] | None): Beta configuration for the + graph-level store (schema is subject to change). + validation_timeout (int): Seconds to wait for validation (default 60). + polling_interval (int): Polling interval in seconds (default 1). + Returns: - dict: The JSON response from the state manager API containing the - validated graph information. - + dict: Validated graph object returned by the API. + Raises: - Exception: If the API request fails with a non-201 status code, if graph - validation times out, or if validation fails. The exception message - includes relevant error details for debugging. + Exception: If validation fails or times out. """ endpoint = self._get_upsert_graph_endpoint(graph_name) headers = { @@ -205,6 +191,12 @@ async def upsert_graph(self, graph_name: str, graph_nodes: list[dict[str, Any]], "secrets": secrets, "nodes": graph_nodes } + + if retry_policy is not None: + body["retry_policy"] = retry_policy + if store_config is not None: + body["store_config"] = store_config + async with aiohttp.ClientSession() as session: async with session.put(endpoint, json=body, headers=headers) as response: # type: ignore if response.status not in [200, 201]: diff --git a/python-sdk/tests/test_coverage_additions.py b/python-sdk/tests/test_coverage_additions.py new file mode 100644 index 00000000..04d97cd1 --- /dev/null +++ b/python-sdk/tests/test_coverage_additions.py @@ -0,0 +1,173 @@ +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from pydantic import BaseModel + +from exospherehost.statemanager import StateManager +from exospherehost.runtime import Runtime +from exospherehost.node.BaseNode import BaseNode +from exospherehost.signals import PruneSignal, ReQueueAfterSignal + + +def _make_mock_session_with_status(status: int): + mock_session = MagicMock() + mock_resp = MagicMock() + mock_resp.status = status + mock_resp.json = AsyncMock(return_value={}) + + mock_ctx = MagicMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp) + mock_ctx.__aexit__ = AsyncMock(return_value=None) + + mock_session.post.return_value = mock_ctx + mock_session.get.return_value = mock_ctx + mock_session.put.return_value = mock_ctx + + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=None) + return mock_session, mock_resp + + +@pytest.mark.asyncio +async def test_statemanager_trigger_defaults(monkeypatch): + monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") + monkeypatch.setenv("EXOSPHERE_API_KEY", "k") + + sm = StateManager(namespace="ns") + + mock_session, _ = _make_mock_session_with_status(200) + + with patch('exospherehost.statemanager.aiohttp.ClientSession', return_value=mock_session): + await sm.trigger("g") + + # Verify it sent empty inputs/store when omitted + mock_session.post.assert_called_once() + _, kwargs = mock_session.post.call_args + assert kwargs["json"] == {"inputs": {}, "store": {}} + + +class _DummyNode(BaseNode): + class Inputs(BaseModel): + x: str = "" + class Outputs(BaseModel): + y: str + class Secrets(BaseModel): + pass + async def execute(self): # type: ignore + return self.Outputs(y="ok") + + +@pytest.mark.asyncio +async def test_runtime_enqueue_puts_states_and_sleeps(monkeypatch): + monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") + monkeypatch.setenv("EXOSPHERE_API_KEY", "k") + + rt = Runtime(namespace="ns", name="rt", nodes=[_DummyNode], batch_size=2, workers=1) + + with patch.object(rt, "_enqueue_call", new=AsyncMock(side_effect=[{"states": [{"state_id": "s1", "node_name": _DummyNode.__name__, "inputs": {}}]}, asyncio.CancelledError()])): + task = asyncio.create_task(rt._enqueue()) + await asyncio.sleep(0.01) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + assert rt._state_queue.qsize() >= 1 + + +def test_runtime_validate_nodes_not_subclass(monkeypatch): + monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") + monkeypatch.setenv("EXOSPHERE_API_KEY", "k") + + class NotNode: + pass + + with pytest.raises(ValueError) as e: + Runtime(namespace="ns", name="rt", nodes=[NotNode]) # type: ignore + msg = str(e.value) + # Expect multiple validation messages + assert "does not inherit" in msg + assert "does not have an Inputs class" in msg + assert "does not have an Outputs class" in msg + assert "does not have an Secrets class" in msg + + +class _PruneNode(BaseNode): + class Inputs(BaseModel): + a: str + class Outputs(BaseModel): + b: str + class Secrets(BaseModel): + pass + async def execute(self): # type: ignore + raise PruneSignal({"reason": "test"}) + + +class _RequeueNode(BaseNode): + class Inputs(BaseModel): + a: str + class Outputs(BaseModel): + b: str + class Secrets(BaseModel): + pass + async def execute(self): # type: ignore + from datetime import timedelta + raise ReQueueAfterSignal(timedelta(seconds=1)) + + +@pytest.mark.asyncio +async def test_worker_handles_prune_signal(monkeypatch): + monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") + monkeypatch.setenv("EXOSPHERE_API_KEY", "k") + + rt = Runtime(namespace="ns", name="rt", nodes=[_PruneNode], workers=1) + + with patch('exospherehost.signals.PruneSignal.send', new=AsyncMock(return_value=None)) as send_mock: + await rt._state_queue.put({"state_id": "s1", "node_name": _PruneNode.__name__, "inputs": {"a": "1"}}) + worker = asyncio.create_task(rt._worker(1)) + await asyncio.sleep(0.02) + worker.cancel() + try: + await worker + except asyncio.CancelledError: + pass + send_mock.assert_awaited() + + +@pytest.mark.asyncio +async def test_worker_handles_requeue_signal(monkeypatch): + monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") + monkeypatch.setenv("EXOSPHERE_API_KEY", "k") + + rt = Runtime(namespace="ns", name="rt", nodes=[_RequeueNode], workers=1) + + with patch('exospherehost.signals.ReQueueAfterSignal.send', new=AsyncMock(return_value=None)) as send_mock: + await rt._state_queue.put({"state_id": "s2", "node_name": _RequeueNode.__name__, "inputs": {"a": "1"}}) + worker = asyncio.create_task(rt._worker(2)) + await asyncio.sleep(0.02) + worker.cancel() + try: + await worker + except asyncio.CancelledError: + pass + send_mock.assert_awaited() + + +@pytest.mark.asyncio +async def test_runtime_start_creates_tasks(monkeypatch): + monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") + monkeypatch.setenv("EXOSPHERE_API_KEY", "k") + + rt = Runtime(namespace="ns", name="rt", nodes=[_DummyNode], workers=1) + + with patch.object(rt, "_register", new=AsyncMock(return_value=None)): + with patch.object(rt, "_enqueue", new=AsyncMock(side_effect=asyncio.CancelledError())): + with patch.object(rt, "_worker", new=AsyncMock(side_effect=asyncio.CancelledError())): + t = asyncio.create_task(rt._start()) + await asyncio.sleep(0.01) + t.cancel() + try: + await t + except asyncio.CancelledError: + pass \ No newline at end of file diff --git a/python-sdk/tests/test_integration.py b/python-sdk/tests/test_integration.py index 157f038f..33e87500 100644 --- a/python-sdk/tests/test_integration.py +++ b/python-sdk/tests/test_integration.py @@ -219,7 +219,7 @@ async def test_state_manager_graph_lifecycle(self, mock_env_vars): inputs={"user_id": "123", "action": "login"} ) - trigger_result = await sm.trigger("test_graph", state=trigger_state) + trigger_result = await sm.trigger("test_graph", inputs=trigger_state.inputs) assert trigger_result == {"status": "triggered"} @@ -451,7 +451,7 @@ async def test_state_manager_error_propagation(self, mock_env_vars): trigger_state = TriggerState(identifier="test", inputs={"key": "value"}) with pytest.raises(Exception, match="Failed to trigger state: 404 Graph not found"): - await sm.trigger("nonexistent_graph", state=trigger_state) + await sm.trigger("nonexistent_graph", inputs=trigger_state.inputs) class TestConcurrencyIntegration: diff --git a/python-sdk/tests/test_runtime_validation.py b/python-sdk/tests/test_runtime_validation.py index d2f04a15..7b0f61a5 100644 --- a/python-sdk/tests/test_runtime_validation.py +++ b/python-sdk/tests/test_runtime_validation.py @@ -50,6 +50,7 @@ def test_runtime_missing_config_raises(monkeypatch): Runtime(namespace="ns", name="rt", nodes=[GoodNode]) +@pytest.mark.filterwarnings("ignore:.*coroutine.*was never awaited.*:RuntimeWarning") def test_runtime_with_env_ok(monkeypatch): monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") monkeypatch.setenv("EXOSPHERE_API_KEY", "k") @@ -57,6 +58,7 @@ def test_runtime_with_env_ok(monkeypatch): assert rt is not None +@pytest.mark.filterwarnings("ignore:.*coroutine.*was never awaited.*:RuntimeWarning") def test_runtime_invalid_params_raises(monkeypatch): monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") monkeypatch.setenv("EXOSPHERE_API_KEY", "k") @@ -79,6 +81,8 @@ def test_node_validation_errors(monkeypatch): assert "Inputs field" in msg and "Outputs field" in msg and "Secrets field" in msg +@pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") +@pytest.mark.filterwarnings("ignore:.*coroutine.*was never awaited.*:RuntimeWarning") def test_duplicate_node_names_raise(monkeypatch): monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") monkeypatch.setenv("EXOSPHERE_API_KEY", "k") @@ -107,8 +111,16 @@ async def execute(self): # Use the same name for both classes GoodNode2.__name__ = "GoodNode1" - # Suppress the RuntimeWarning about unawaited coroutines + # Suppress warnings about unawaited coroutines and pytest unraisable exceptions (test-only) with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message=".*coroutine.*was never awaited.*", category=RuntimeWarning) + warnings.filterwarnings( + "ignore", + message=".*coroutine.*was never awaited.*", + category=RuntimeWarning + ) + warnings.filterwarnings( + "ignore", + category=pytest.PytestUnraisableExceptionWarning + ) with pytest.raises(ValueError): Runtime(namespace="ns", name="rt", nodes=[GoodNode1, GoodNode2]) \ No newline at end of file diff --git a/python-sdk/tests/test_signals_and_runtime_functions.py b/python-sdk/tests/test_signals_and_runtime_functions.py index e6a2222d..c2659929 100644 --- a/python-sdk/tests/test_signals_and_runtime_functions.py +++ b/python-sdk/tests/test_signals_and_runtime_functions.py @@ -105,10 +105,29 @@ async def test_prune_signal_send_failure(self): data = {"reason": "test_prune"} signal = PruneSignal(data) - mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() - mock_post_response.status = 500 - - with patch('exospherehost.signals.ClientSession', return_value=mock_session): + class _FakeResponse: + def __init__(self, status: int): + self.status = status + + class _FakePostCtx: + def __init__(self, status: int): + self._status = status + async def __aenter__(self): + return _FakeResponse(self._status) + async def __aexit__(self, exc_type, exc, tb): + return None + + class _FakeSession: + def __init__(self, status: int): + self._status = status + def post(self, *args, **kwargs): + return _FakePostCtx(self._status) + async def __aenter__(self): + return self + async def __aexit__(self, exc_type, exc, tb): + return None + + with patch('exospherehost.signals.ClientSession', return_value=_FakeSession(500)): with pytest.raises(Exception, match="Failed to send prune signal"): await signal.send("http://test-endpoint/prune", "test-api-key") @@ -177,10 +196,29 @@ async def test_requeue_signal_send_failure(self): delta = timedelta(seconds=30) signal = ReQueueAfterSignal(delta) - mock_session, mock_post_response, _, _ = create_mock_aiohttp_session() - mock_post_response.status = 400 - - with patch('exospherehost.signals.ClientSession', return_value=mock_session): + class _FakeResponse: + def __init__(self, status: int): + self.status = status + + class _FakePostCtx: + def __init__(self, status: int): + self._status = status + async def __aenter__(self): + return _FakeResponse(self._status) + async def __aexit__(self, exc_type, exc, tb): + return None + + class _FakeSession: + def __init__(self, status: int): + self._status = status + def post(self, *args, **kwargs): + return _FakePostCtx(self._status) + async def __aenter__(self): + return self + async def __aexit__(self, exc_type, exc, tb): + return None + + with patch('exospherehost.signals.ClientSession', return_value=_FakeSession(400)): with pytest.raises(Exception, match="Failed to send requeue after signal"): await signal.send("http://test-endpoint/requeue", "test-api-key") @@ -209,6 +247,8 @@ def test_runtime_endpoint_construction(self): assert requeue_endpoint == expected_requeue @pytest.mark.asyncio + @pytest.mark.filterwarnings("ignore:.*coroutine.*was never awaited.*:RuntimeWarning") + @pytest.mark.filterwarnings("ignore::pytest.PytestUnraisableExceptionWarning") async def test_signal_handling_direct(self): """Test signal handling by directly calling signal.send() with runtime endpoints.""" runtime = Runtime( diff --git a/python-sdk/tests/test_state_manager.py b/python-sdk/tests/test_state_manager.py deleted file mode 100644 index ba8f5ccf..00000000 --- a/python-sdk/tests/test_state_manager.py +++ /dev/null @@ -1,28 +0,0 @@ -import pytest -import asyncio -from exospherehost.statemanager import StateManager, TriggerState - - -def test_trigger_requires_either_state_or_states(monkeypatch): - monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") - monkeypatch.setenv("EXOSPHERE_API_KEY", "k") - sm = StateManager(namespace="ns") - with pytest.raises(ValueError): - asyncio.run(sm.trigger("g")) - - -def test_trigger_rejects_both_state_and_states(monkeypatch): - monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") - monkeypatch.setenv("EXOSPHERE_API_KEY", "k") - sm = StateManager(namespace="ns") - state = TriggerState(identifier="id", inputs={}) - with pytest.raises(ValueError): - asyncio.run(sm.trigger("g", state=state, states=[state])) - - -def test_trigger_rejects_empty_states_list(monkeypatch): - monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") - monkeypatch.setenv("EXOSPHERE_API_KEY", "k") - sm = StateManager(namespace="ns") - with pytest.raises(ValueError): - asyncio.run(sm.trigger("g", states=[])) \ No newline at end of file diff --git a/python-sdk/tests/test_statemanager_comprehensive.py b/python-sdk/tests/test_statemanager_comprehensive.py index 9a763d90..c8b1c7a2 100644 --- a/python-sdk/tests/test_statemanager_comprehensive.py +++ b/python-sdk/tests/test_statemanager_comprehensive.py @@ -1,5 +1,4 @@ import pytest -import asyncio from unittest.mock import AsyncMock, patch, MagicMock from exospherehost.statemanager import StateManager, TriggerState @@ -127,7 +126,7 @@ async def test_trigger_single_state_success(self, state_manager_config): sm = StateManager(**state_manager_config) state = TriggerState(identifier="test", inputs={"key": "value"}) - result = await sm.trigger("test_graph", state=state) + result = await sm.trigger("test_graph", inputs=state.inputs) assert result == {"status": "success"} @@ -147,7 +146,8 @@ async def test_trigger_multiple_states_success(self, state_manager_config): TriggerState(identifier="test2", inputs={"key2": "value2"}) ] - result = await sm.trigger("test_graph", states=states) + merged_inputs = {**states[0].inputs, **states[1].inputs} + result = await sm.trigger("test_graph", inputs=merged_inputs) assert result == {"status": "success"} @@ -165,26 +165,7 @@ async def test_trigger_failure(self, state_manager_config): state = TriggerState(identifier="test", inputs={"key": "value"}) with pytest.raises(Exception, match="Failed to trigger state: 400 Bad request"): - await sm.trigger("test_graph", state=state) - - def test_trigger_validation_no_state_or_states(self, state_manager_config): - sm = StateManager(**state_manager_config) - - with pytest.raises(ValueError, match="Either state or states must be provided"): - asyncio.run(sm.trigger("test_graph")) - - def test_trigger_validation_both_state_and_states(self, state_manager_config): - sm = StateManager(**state_manager_config) - state = TriggerState(identifier="test", inputs={"key": "value"}) - - with pytest.raises(ValueError, match="Only one of state or states must be provided"): - asyncio.run(sm.trigger("test_graph", state=state, states=[state])) - - def test_trigger_validation_empty_states_list(self, state_manager_config): - sm = StateManager(**state_manager_config) - - with pytest.raises(ValueError, match="States must be a non-empty list"): - asyncio.run(sm.trigger("test_graph", states=[])) + await sm.trigger("test_graph", inputs=state.inputs) class TestStateManagerGetGraph: diff --git a/state-manager/.coverage b/state-manager/.coverage index c086d6eb..e0975eb1 100644 Binary files a/state-manager/.coverage and b/state-manager/.coverage differ diff --git a/state-manager/app/controller/create_states.py b/state-manager/app/controller/create_states.py deleted file mode 100644 index 5a2362fd..00000000 --- a/state-manager/app/controller/create_states.py +++ /dev/null @@ -1,93 +0,0 @@ -from fastapi import HTTPException - -from app.singletons.logs_manager import LogsManager -from app.models.create_models import CreateRequestModel, CreateResponseModel, ResponseStateModel, TriggerGraphRequestModel, TriggerGraphResponseModel -from app.models.state_status_enum import StateStatusEnum -from app.models.db.state import State -from app.models.db.graph_template_model import GraphTemplate -from app.models.node_template_model import NodeTemplate - -from beanie.operators import In -from beanie import PydanticObjectId -import uuid - -logger = LogsManager().get_logger() - - -def get_node_template(graph_template: GraphTemplate, identifier: str) -> NodeTemplate: - node = graph_template.get_node_by_identifier(identifier) - if not node: - raise HTTPException(status_code=404, detail="Node template not found") - return node - - -async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraphRequestModel, x_exosphere_request_id: str) -> TriggerGraphResponseModel: - try: - # Generate a new run ID for this graph execution - run_id = str(uuid.uuid4()) - logger.info(f"Triggering graph {graph_name} with run_id {run_id}", x_exosphere_request_id=x_exosphere_request_id) - - # Create a CreateRequestModel with the generated run_id - create_request = CreateRequestModel( - run_id=run_id, - states=body.states - ) - - # Call the existing create_states function - create_response = await create_states(namespace_name, graph_name, create_request, x_exosphere_request_id) - - # Return the trigger response with the generated run_id - return TriggerGraphResponseModel( - run_id=run_id, - status=create_response.status, - states=create_response.states - ) - - except Exception as e: - logger.error(f"Error triggering graph {graph_name} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - raise e - - -async def create_states(namespace_name: str, graph_name: str, body: CreateRequestModel, x_exosphere_request_id: str) -> CreateResponseModel: - try: - states = [] - logger.info(f"Creating states for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - - graph_template = await GraphTemplate.find_one(GraphTemplate.name == graph_name, GraphTemplate.namespace == namespace_name) - if not graph_template: - raise HTTPException(status_code=404, detail="Graph template not found") - - for state in body.states: - - node_template = get_node_template(graph_template, state.identifier) - - states.append( - State( - identifier=state.identifier, - node_name=node_template.node_name, - namespace_name=node_template.namespace, - graph_name=graph_name, - run_id=body.run_id, - status=StateStatusEnum.CREATED, - inputs=state.inputs, - outputs={}, - error=None - ) - ) - - inserted_states = await State.insert_many(states) - - logger.info(f"Created states: {inserted_states.inserted_ids}", x_exosphere_request_id=x_exosphere_request_id) - - newStates = await State.find( - In(State.id, [PydanticObjectId(id) for id in inserted_states.inserted_ids]) - ).to_list() - - return CreateResponseModel( - status=StateStatusEnum.CREATED, - states=[ResponseStateModel(state_id=str(state.id), identifier=state.identifier, node_name=state.node_name, graph_name=state.graph_name, run_id=state.run_id, inputs=state.inputs, created_at=state.created_at) for state in newStates] - ) - - except Exception as e: - logger.error(f"Error creating states for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - raise e \ No newline at end of file diff --git a/state-manager/app/controller/trigger_graph.py b/state-manager/app/controller/trigger_graph.py new file mode 100644 index 00000000..45aa1139 --- /dev/null +++ b/state-manager/app/controller/trigger_graph.py @@ -0,0 +1,80 @@ +from fastapi import HTTPException + +from app.singletons.logs_manager import LogsManager +from app.models.trigger_model import TriggerGraphRequestModel, TriggerGraphResponseModel +from app.models.state_status_enum import StateStatusEnum +from app.models.db.state import State +from app.models.db.store import Store +from app.models.db.graph_template_model import GraphTemplate +from app.models.node_template_model import NodeTemplate +import uuid + +logger = LogsManager().get_logger() + +def check_required_store_keys(graph_template: GraphTemplate, store: dict[str, str]) -> None: + required_keys = set(graph_template.store_config.required_keys) + provided_keys = set(store.keys()) + + missing_keys = required_keys - provided_keys + if missing_keys: + raise HTTPException(status_code=400, detail=f"Missing store keys: {missing_keys}") + + +def construct_inputs(node: NodeTemplate, inputs: dict[str, str]) -> dict[str, str]: + return {key: inputs.get(key, value) for key, value in node.inputs.items()} + + +async def trigger_graph(namespace_name: str, graph_name: str, body: TriggerGraphRequestModel, x_exosphere_request_id: str) -> TriggerGraphResponseModel: + try: + run_id = str(uuid.uuid4()) + logger.info(f"Triggering graph {graph_name} with run_id {run_id}", x_exosphere_request_id=x_exosphere_request_id) + + try: + graph_template = await GraphTemplate.get(namespace_name, graph_name) + except ValueError as e: + logger.error(f"Graph template not found for namespace {namespace_name} and graph {graph_name}", x_exosphere_request_id=x_exosphere_request_id) + if "Graph template not found" in str(e): + raise HTTPException(status_code=404, detail=f"Graph template not found for namespace {namespace_name} and graph {graph_name}") + else: + raise e + + if not graph_template.is_valid(): + raise HTTPException(status_code=400, detail="Graph template is not valid") + + check_required_store_keys(graph_template, body.store) + + new_stores = [ + Store( + run_id=run_id, + namespace=namespace_name, + graph_name=graph_name, + key=key, + value=value + ) for key, value in body.store.items() + ] + + await Store.insert_many(new_stores) + + root = graph_template.get_root_node() + + new_state = State( + node_name=root.node_name, + namespace_name=namespace_name, + identifier=root.identifier, + graph_name=graph_name, + run_id=run_id, + status=StateStatusEnum.CREATED, + inputs=construct_inputs(root, body.inputs), + outputs={}, + error=None + ) + await new_state.insert() + + return TriggerGraphResponseModel( + status=StateStatusEnum.CREATED, + run_id=run_id + ) + + except Exception as e: + logger.error(f"Error triggering graph {graph_name} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise e diff --git a/state-manager/app/main.py b/state-manager/app/main.py index a6eb2795..2a8ed8c1 100644 --- a/state-manager/app/main.py +++ b/state-manager/app/main.py @@ -20,6 +20,7 @@ from .models.db.state import State from .models.db.graph_template_model import GraphTemplate from .models.db.registered_node import RegisteredNode +from .models.db.store import Store # injecting routes from .routes import router @@ -41,7 +42,7 @@ async def lifespan(app: FastAPI): # initializing beanie client = AsyncMongoClient(settings.mongo_uri) db = client[settings.mongo_database_name] - await init_beanie(db, document_models=[State, GraphTemplate, RegisteredNode]) + await init_beanie(db, document_models=[State, GraphTemplate, RegisteredNode, Store]) logger.info("beanie dbs initialized") # initialize secret diff --git a/state-manager/app/models/create_models.py b/state-manager/app/models/create_models.py deleted file mode 100644 index c34d490c..00000000 --- a/state-manager/app/models/create_models.py +++ /dev/null @@ -1,39 +0,0 @@ -from pydantic import BaseModel, Field -from typing import Any -from .state_status_enum import StateStatusEnum -from datetime import datetime - - -class RequestStateModel(BaseModel): - identifier: str = Field(..., description="Unique identifier of the node template within the graph template") - inputs: dict[str, Any] = Field(..., description="Inputs of the state") - - -class ResponseStateModel(BaseModel): - state_id: str = Field(..., description="ID of the state") - node_name: str = Field(..., description="Name of the node of the state") - identifier: str = Field(..., description="Identifier of the node for which state is created") - graph_name: str = Field(..., description="Name of the graph template for this state") - run_id: str = Field(..., description="Unique run ID for grouping states from the same graph execution") - inputs: dict[str, Any] = Field(..., description="Inputs of the state") - created_at: datetime = Field(..., description="Date and time when the state was created") - - -class CreateRequestModel(BaseModel): - run_id: str = Field(..., description="Unique run ID for grouping states from the same graph execution") - states: list[RequestStateModel] = Field(..., description="List of states") - - -class CreateResponseModel(BaseModel): - status: StateStatusEnum = Field(..., description="Status of the state") - states: list[ResponseStateModel] = Field(..., description="List of states") - - -class TriggerGraphRequestModel(BaseModel): - states: list[RequestStateModel] = Field(..., description="List of states to create for the graph execution") - - -class TriggerGraphResponseModel(BaseModel): - run_id: str = Field(..., description="Unique run ID generated for this graph execution") - status: StateStatusEnum = Field(..., description="Status of the states") - states: list[ResponseStateModel] = Field(..., description="List of created states") \ No newline at end of file diff --git a/state-manager/app/models/db/graph_template_model.py b/state-manager/app/models/db/graph_template_model.py index 4b3a4731..a6d05df9 100644 --- a/state-manager/app/models/db/graph_template_model.py +++ b/state-manager/app/models/db/graph_template_model.py @@ -12,6 +12,7 @@ from app.utils.encrypter import get_encrypter from app.models.dependent_string import DependentString from app.models.retry_policy_model import RetryPolicyModel +from app.models.store_config_model import StoreConfig class GraphTemplate(BaseDatabaseModel): name: str = Field(..., description="Name of the graph") @@ -21,6 +22,7 @@ class GraphTemplate(BaseDatabaseModel): validation_errors: List[str] = Field(default_factory=list, description="Validation errors of the graph") secrets: Dict[str, str] = Field(default_factory=dict, description="Secrets of the graph") retry_policy: RetryPolicyModel = Field(default_factory=RetryPolicyModel, description="Retry policy of the graph") + store_config: StoreConfig = Field(default_factory=StoreConfig, description="Store config of the graph") _node_by_identifier: Dict[str, NodeTemplate] | None = PrivateAttr(default=None) _parents_by_identifier: Dict[str, set[str]] | None = PrivateAttr(default=None) # type: ignore @@ -119,16 +121,18 @@ def dfs(node_identifier: str, parents: set[str], path: set[str]) -> None: @field_validator('name') @classmethod def validate_name(cls, v: str) -> str: - if v == "" or v is None: + trimmed_v = v.strip() + if trimmed_v == "" or trimmed_v is None: raise ValueError("Name cannot be empty") - return v + return trimmed_v @field_validator('namespace') @classmethod def validate_namespace(cls, v: str) -> str: - if v == "" or v is None: + trimmed_v = v.strip() + if trimmed_v == "" or trimmed_v is None: raise ValueError("Namespace cannot be empty") - return v + return trimmed_v @field_validator('secrets') @classmethod @@ -137,7 +141,6 @@ def validate_secrets(cls, v: Dict[str, str]) -> Dict[str, str]: if not secret_name or not secret_value: raise ValueError("Secrets cannot be empty") cls._validate_secret_value(secret_value) - return v @field_validator('nodes') @@ -236,12 +239,23 @@ def verify_input_dependencies(self) -> Self: continue dependent_string = DependentString.create_dependent_string(input_value) - dependent_identifiers = set([identifier for identifier, _ in dependent_string.get_identifier_field()]) + dependent_identifiers = set() + store_fields = set() + + for key, field in dependent_string.get_identifier_field(): + if key == "store": + store_fields.add(field) + else: + dependent_identifiers.add(key) for identifier in dependent_identifiers: if identifier not in self.get_parents_by_identifier(node.identifier): errors.append(f"Input {input_value} depends on {identifier} but {identifier} is not a parent of {node.identifier}") + for field in store_fields: + if field not in self.store_config.required_keys and field not in self.store_config.default_values: + errors.append(f"Input {input_value} depends on {field} but {field} is not a required key or a default value") + except Exception as e: errors.append(f"Error creating dependent string for input {input_value} check syntax string: {str(e)}") if errors: diff --git a/state-manager/app/models/db/store.py b/state-manager/app/models/db/store.py new file mode 100644 index 00000000..890b2a32 --- /dev/null +++ b/state-manager/app/models/db/store.py @@ -0,0 +1,36 @@ +from beanie import Document +from pydantic import Field +from pymongo import IndexModel + +class Store(Document): + run_id: str = Field(..., description="Run ID of the corresponding graph execution") + namespace: str = Field(..., description="Namespace of the graph") + graph_name: str = Field(..., description="Name of the graph") + key: str = Field(..., description="Key of the store") + value: str = Field(..., description="Value of the store") + + class Settings: + indexes = [ + IndexModel( + [ + ("run_id", 1), + ("namespace", 1), + ("graph_name", 1), + ("key", 1), + ], + unique=True, + name="uniq_run_id_namespace_graph_name_key", + ) + ] + + @staticmethod + async def get_value(run_id: str, namespace: str, graph_name: str, key: str) -> str | None: + store = await Store.find_one( + Store.run_id == run_id, + Store.namespace == namespace, + Store.graph_name == graph_name, + Store.key == key, + ) + if store is None: + return None + return store.value diff --git a/state-manager/app/models/dependent_string.py b/state-manager/app/models/dependent_string.py index 1e8da4a0..11c19a3d 100644 --- a/state-manager/app/models/dependent_string.py +++ b/state-manager/app/models/dependent_string.py @@ -34,11 +34,14 @@ def create_dependent_string(syntax_string: str) -> "DependentString": placeholder_content, tail = split.split("}}", 1) parts = [p.strip() for p in placeholder_content.split(".")] - if len(parts) != 3 or parts[1] != "outputs": + + if len(parts) == 3 and parts[1] == "outputs": + dependent_string.dependents[order] = Dependent(identifier=parts[0], field=parts[2], tail=tail) + elif len(parts) == 2 and parts[0] == "store": + dependent_string.dependents[order] = Dependent(identifier=parts[0], field=parts[1], tail=tail) + else: raise ValueError(f"Invalid syntax string placeholder {placeholder_content} for: {syntax_string}") - dependent_string.dependents[order] = Dependent(identifier=parts[0], field=parts[2], tail=tail) - return dependent_string def _build_mapping_key_to_dependent(self): diff --git a/state-manager/app/models/graph_models.py b/state-manager/app/models/graph_models.py index 8e67cb2d..1c9a5e91 100644 --- a/state-manager/app/models/graph_models.py +++ b/state-manager/app/models/graph_models.py @@ -4,18 +4,21 @@ from datetime import datetime from .graph_template_validation_status import GraphTemplateValidationStatus from .retry_policy_model import RetryPolicyModel +from .store_config_model import StoreConfig class UpsertGraphTemplateRequest(BaseModel): secrets: Dict[str, str] = Field(..., description="Dictionary of secrets that are used while graph execution") nodes: List[NodeTemplate] = Field(..., description="List of node templates that define the graph structure") retry_policy: RetryPolicyModel = Field(default_factory=RetryPolicyModel, description="Retry policy of the graph") + store_config: StoreConfig = Field(default_factory=StoreConfig, description="Store config of the graph") class UpsertGraphTemplateResponse(BaseModel): nodes: List[NodeTemplate] = Field(..., description="List of node templates that define the graph structure") secrets: Dict[str, bool] = Field(..., description="Dictionary of secrets that are used while graph execution") retry_policy: RetryPolicyModel = Field(default_factory=RetryPolicyModel, description="Retry policy of the graph") + store_config: StoreConfig = Field(default_factory=StoreConfig, description="Store config of the graph") created_at: datetime = Field(..., description="Timestamp when the graph template was created") updated_at: datetime = Field(..., description="Timestamp when the graph template was last updated") validation_status: GraphTemplateValidationStatus = Field(..., description="Current validation status of the graph template") diff --git a/state-manager/app/models/node_template_model.py b/state-manager/app/models/node_template_model.py index eb2c43a4..b9b54fe3 100644 --- a/state-manager/app/models/node_template_model.py +++ b/state-manager/app/models/node_template_model.py @@ -25,45 +25,55 @@ class NodeTemplate(BaseModel): @field_validator('node_name') @classmethod def validate_node_name(cls, v: str) -> str: - if v == "" or v is None: + trimmed_v = v.strip() + if trimmed_v == "" or trimmed_v is None: raise ValueError("Node name cannot be empty") - return v + return trimmed_v @field_validator('identifier') @classmethod def validate_identifier(cls, v: str) -> str: - if v == "" or v is None: + trimmed_v = v.strip() + if trimmed_v == "" or trimmed_v is None: raise ValueError("Node identifier cannot be empty") - return v + elif trimmed_v == "store": + raise ValueError("Node identifier cannot be reserved word 'store'") + return trimmed_v @field_validator('next_nodes') @classmethod def validate_next_nodes(cls, v: Optional[List[str]]) -> Optional[List[str]]: identifiers = set() errors = [] + trimmed_v = [] + if v is not None: for next_node_identifier in v: + trimmed_next_node_identifier = next_node_identifier.strip() - if next_node_identifier == "" or next_node_identifier is None: + if trimmed_next_node_identifier == "" or trimmed_next_node_identifier is None: errors.append("Next node identifier cannot be empty") continue - if next_node_identifier in identifiers: - errors.append(f"Next node identifier {next_node_identifier} is not unique") + if trimmed_next_node_identifier in identifiers: + errors.append(f"Next node identifier {trimmed_next_node_identifier} is not unique") continue - identifiers.add(next_node_identifier) + identifiers.add(trimmed_next_node_identifier) + trimmed_v.append(trimmed_next_node_identifier) if errors: raise ValueError("\n".join(errors)) - return v + return trimmed_v @field_validator('unites') @classmethod def validate_unites(cls, v: Optional[Unites]) -> Optional[Unites]: + trimmed_v = v if v is not None: - if v.identifier == "" or v.identifier is None: + trimmed_v = Unites(identifier=v.identifier.strip(), strategy=v.strategy) + if trimmed_v.identifier == "" or trimmed_v.identifier is None: raise ValueError("Unites identifier cannot be empty") - return v + return trimmed_v def get_dependent_strings(self) -> list[DependentString]: dependent_strings = [] diff --git a/state-manager/app/models/store_config_model.py b/state-manager/app/models/store_config_model.py new file mode 100644 index 00000000..16d220f9 --- /dev/null +++ b/state-manager/app/models/store_config_model.py @@ -0,0 +1,61 @@ +from pydantic import BaseModel, Field, field_validator + +class StoreConfig(BaseModel): + required_keys: list[str] = Field(default_factory=list, description="Required keys of the store") + default_values: dict[str, str] = Field(default_factory=dict, description="Default values of the store") + + @field_validator("required_keys") + def validate_required_keys(cls, v: list[str]) -> list[str]: + errors = [] + keys = set() + trimmed_keys = [] + + for key in v: + trimmed_key = key.strip() if key is not None else "" + + if trimmed_key == "": + errors.append("Key cannot be empty or contain only whitespace") + continue + + if '.' in trimmed_key: + errors.append(f"Key '{trimmed_key}' cannot contain '.' character") + continue + + if trimmed_key in keys: + errors.append(f"Key '{trimmed_key}' is duplicated") + continue + + keys.add(trimmed_key) + trimmed_keys.append(trimmed_key) + + if len(errors) > 0: + raise ValueError("\n".join(errors)) + return trimmed_keys + + @field_validator("default_values") + def validate_default_values(cls, v: dict[str, str]) -> dict[str, str]: + errors = [] + keys = set() + normalized_dict = {} + + for key, value in v.items(): + trimmed_key = key.strip() if key is not None else "" + + if trimmed_key == "": + errors.append("Key cannot be empty or contain only whitespace") + continue + + if '.' in trimmed_key: + errors.append(f"Key '{trimmed_key}' cannot contain '.' character") + continue + + if trimmed_key in keys: + errors.append(f"Key '{trimmed_key}' is duplicated") + continue + + keys.add(trimmed_key) + normalized_dict[trimmed_key] = str(value) + + if len(errors) > 0: + raise ValueError("\n".join(errors)) + return normalized_dict \ No newline at end of file diff --git a/state-manager/app/models/trigger_model.py b/state-manager/app/models/trigger_model.py new file mode 100644 index 00000000..a61ffceb --- /dev/null +++ b/state-manager/app/models/trigger_model.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field +from .state_status_enum import StateStatusEnum + +class TriggerGraphRequestModel(BaseModel): + store: dict[str, str] = Field(default_factory=dict, description="Store for the runtime") + inputs: dict[str, str] = Field(default_factory=dict, description="Inputs for the graph execution") + +class TriggerGraphResponseModel(BaseModel): + status: StateStatusEnum = Field(..., description="Status of the states") + run_id: str = Field(..., description="Unique run ID generated for this graph execution") \ No newline at end of file diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index ee219bb1..8143d574 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -9,8 +9,8 @@ from .models.enqueue_request import EnqueueRequestModel from .controller.enqueue_states import enqueue_states -from .models.create_models import CreateRequestModel, CreateResponseModel, TriggerGraphRequestModel, TriggerGraphResponseModel -from .controller.create_states import create_states, trigger_graph +from .models.trigger_model import TriggerGraphRequestModel, TriggerGraphResponseModel +from .controller.trigger_graph import trigger_graph from .models.executed_models import ExecutedRequestModel, ExecutedResponseModel from .controller.executed_state import executed_state @@ -92,27 +92,6 @@ async def trigger_graph_route(namespace_name: str, graph_name: str, body: Trigge return await trigger_graph(namespace_name, graph_name, body, x_exosphere_request_id) - -@router.post( - "/graph/{graph_name}/states/create", - response_model=CreateResponseModel, - status_code=status.HTTP_200_OK, - response_description="States created successfully", - tags=["state"] -) -async def create_state(namespace_name: str, graph_name: str, body: CreateRequestModel, request: Request, api_key: str = Depends(check_api_key)): - - x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) - - if api_key: - logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - else: - logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") - - return await create_states(namespace_name, graph_name, body, x_exosphere_request_id) - - @router.post( "/states/{state_id}/executed", response_model=ExecutedResponseModel, diff --git a/state-manager/app/tasks/create_next_states.py b/state-manager/app/tasks/create_next_states.py index 659dddfb..f51c9c69 100644 --- a/state-manager/app/tasks/create_next_states.py +++ b/state-manager/app/tasks/create_next_states.py @@ -7,11 +7,13 @@ from app.models.state_status_enum import StateStatusEnum from app.models.node_template_model import NodeTemplate from app.models.db.registered_node import RegisteredNode +from app.models.db.store import Store from app.models.dependent_string import DependentString from app.models.node_template_model import UnitesStrategyEnum from json_schema_to_pydantic import create_model from pydantic import BaseModel from typing import Type +import asyncio logger = LogsManager().get_logger() @@ -82,42 +84,6 @@ def validate_dependencies(next_state_node_template: NodeTemplate, next_state_inp raise AttributeError(f"Output field '{dependent.field}' not found on state '{dependent.identifier}' for template '{next_state_node_template.identifier}'") -def generate_next_state(next_state_input_model: Type[BaseModel], next_state_node_template: NodeTemplate, parents: dict[str, State], current_state: State) -> State: - next_state_input_data = {} - - for field_name, _ in next_state_input_model.model_fields.items(): - dependency_string = DependentString.create_dependent_string(next_state_node_template.inputs[field_name]) - - for identifier, field in dependency_string.get_identifier_field(): - if identifier == current_state.identifier: - if field not in current_state.outputs: - raise AttributeError(f"Output field '{field}' not found on current state '{current_state.identifier}' for template '{next_state_node_template.identifier}'") - dependency_string.set_value(identifier, field, current_state.outputs[field]) - else: - dependency_string.set_value(identifier, field, parents[identifier].outputs[field]) - - next_state_input_data[field_name] = dependency_string.generate_string() - - new_parents = { - **current_state.parents, - current_state.identifier: current_state.id - } - - return State( - node_name=next_state_node_template.node_name, - identifier=next_state_node_template.identifier, - namespace_name=next_state_node_template.namespace, - graph_name=current_state.graph_name, - status=StateStatusEnum.CREATED, - parents=new_parents, - inputs=next_state_input_data, - outputs={}, - does_unites=next_state_node_template.unites is not None, - run_id=current_state.run_id, - error=None - ) - - async def create_next_states(state_ids: list[PydanticObjectId], identifier: str, namespace: str, graph_name: str, parents_ids: dict[str, PydanticObjectId]): try: @@ -137,7 +103,8 @@ async def create_next_states(state_ids: list[PydanticObjectId], identifier: str, cached_registered_nodes: dict[tuple[str, str], RegisteredNode] = {} cached_input_models: dict[tuple[str, str], Type[BaseModel]] = {} - new_states = [] + cached_store_values: dict[tuple[str, str], str] = {} + new_states_coroutines = [] async def get_registered_node(node_template: NodeTemplate) -> RegisteredNode: key = (node_template.namespace, node_template.node_name) @@ -153,6 +120,59 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: if key not in cached_input_models: cached_input_models[key] = create_model((await get_registered_node(node_template)).inputs_schema) return cached_input_models[key] + + async def get_store_value(run_id: str, field: str) -> str: + key = (run_id, field) + if key not in cached_store_values: + store_value = await Store.get_value(run_id, namespace, graph_name, field) + + if store_value is None: + store_value = graph_template.store_config.default_values.get(field) + if store_value is None: + raise ValueError(f"Store value not found for field '{field}' in namespace '{namespace}' and graph '{graph_name}'") + + cached_store_values[key] = store_value + return cached_store_values[key] + + async def generate_next_state(next_state_input_model: Type[BaseModel], next_state_node_template: NodeTemplate, parents: dict[str, State], current_state: State) -> State: + next_state_input_data = {} + + for field_name, _ in next_state_input_model.model_fields.items(): + dependency_string = DependentString.create_dependent_string(next_state_node_template.inputs[field_name]) + + for identifier, field in dependency_string.get_identifier_field(): + + if identifier == "store": + dependency_string.set_value(identifier, field, await get_store_value(current_state.run_id, field)) + + elif identifier == current_state.identifier: + if field not in current_state.outputs: + raise AttributeError(f"Output field '{field}' not found on current state '{current_state.identifier}' for template '{next_state_node_template.identifier}'") + dependency_string.set_value(identifier, field, current_state.outputs[field]) + + else: + dependency_string.set_value(identifier, field, parents[identifier].outputs[field]) + + next_state_input_data[field_name] = dependency_string.generate_string() + + new_parents = { + **current_state.parents, + current_state.identifier: current_state.id + } + + return State( + node_name=next_state_node_template.node_name, + identifier=next_state_node_template.identifier, + namespace_name=next_state_node_template.namespace, + graph_name=current_state.graph_name, + status=StateStatusEnum.CREATED, + parents=new_parents, + inputs=next_state_input_data, + outputs={}, + does_unites=next_state_node_template.unites is not None, + run_id=current_state.run_id, + error=None + ) current_states = await State.find( In(State.id, state_ids) @@ -184,14 +204,14 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: validate_dependencies(next_state_node_template, next_state_input_model, identifier, parents) for current_state in current_states: - new_states.append(generate_next_state(next_state_input_model, next_state_node_template, parents, current_state)) + new_states_coroutines.append(generate_next_state(next_state_input_model, next_state_node_template, parents, current_state)) - if len(new_states) > 0: - await State.insert_many(new_states) + if len(new_states_coroutines) > 0: + await State.insert_many(await asyncio.gather(*new_states_coroutines)) await mark_success_states(state_ids) # handle unites - new_unit_states = [] + new_unit_states_coroutines = [] for pending_unites_identifier in pending_unites: next_state_node_template = graph_template.get_node_by_identifier(pending_unites_identifier) if not next_state_node_template: @@ -206,16 +226,16 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]: assert next_state_node_template.unites is not None parent_state = parents[next_state_node_template.unites.identifier] - new_unit_states.append(generate_next_state(next_state_input_model, next_state_node_template, parents, parent_state)) + new_unit_states_coroutines.append(generate_next_state(next_state_input_model, next_state_node_template, parents, parent_state)) try: - if len(new_unit_states) > 0: - await State.insert_many(new_unit_states) + if len(new_unit_states_coroutines) > 0: + await State.insert_many(await asyncio.gather(*new_unit_states_coroutines)) except (DuplicateKeyError, BulkWriteError): logger.warning( f"Caught duplicate key error for new unit states in namespace={namespace}, " f"graph={graph_name}, likely due to a race condition. " - f"Attempted to insert {len(new_unit_states)} states" + f"Attempted to insert {len(new_unit_states_coroutines)} states" ) except Exception as e: diff --git a/state-manager/tests/unit/controller/test_create_states.py b/state-manager/tests/unit/controller/test_create_states.py deleted file mode 100644 index 50bb4776..00000000 --- a/state-manager/tests/unit/controller/test_create_states.py +++ /dev/null @@ -1,301 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, MagicMock, patch -from fastapi import HTTPException -from beanie import PydanticObjectId -from datetime import datetime - -from app.controller.create_states import create_states, get_node_template -from app.models.create_models import CreateRequestModel, RequestStateModel -from app.models.state_status_enum import StateStatusEnum -from app.models.node_template_model import NodeTemplate - - -class TestGetNodeTemplate: - """Test cases for get_node_template function""" - - def test_get_node_template_success(self): - """Test successful retrieval of node template""" - # Arrange - mock_node = NodeTemplate( - node_name="test_node", - namespace="test_namespace", - identifier="test_identifier", - inputs={}, - next_nodes=[], - unites=None - ) - mock_graph_template = MagicMock() - mock_graph_template.get_node_by_identifier.return_value = mock_node - - # Act - result = get_node_template(mock_graph_template, "test_identifier") - - # Assert - assert result == mock_node - mock_graph_template.get_node_by_identifier.assert_called_once_with("test_identifier") - - def test_get_node_template_not_found(self): - """Test when node template is not found""" - # Arrange - mock_graph_template = MagicMock() - mock_graph_template.get_node_by_identifier.return_value = None - - # Act & Assert - with pytest.raises(HTTPException) as exc_info: - get_node_template(mock_graph_template, "non_existent_identifier") - - assert exc_info.value.status_code == 404 - assert exc_info.value.detail == "Node template not found" - - -class TestCreateStates: - """Test cases for create_states function""" - - @pytest.fixture - def mock_request_id(self): - return "test-request-id" - - @pytest.fixture - def mock_namespace(self): - return "test_namespace" - - @pytest.fixture - def mock_graph_name(self): - return "test_graph" - - @pytest.fixture - def mock_node_template(self): - return NodeTemplate( - node_name="test_node", - namespace="test_namespace", - identifier="test_identifier", - inputs={}, - next_nodes=[], - unites=None - ) - - @pytest.fixture - def mock_graph_template(self, mock_node_template, mock_graph_name, mock_namespace): - mock_template = MagicMock() - mock_template.name = mock_graph_name - mock_template.namespace = mock_namespace - mock_template.get_node_by_identifier.return_value = mock_node_template - return mock_template - - @pytest.fixture - def mock_create_request(self): - return CreateRequestModel( - run_id="test_run_id", - states=[ - RequestStateModel( - identifier="test_identifier", - inputs={"key": "value"} - ) - ] - ) - - @pytest.fixture - def mock_state(self): - state = MagicMock() - state.id = PydanticObjectId() - state.identifier = "test_identifier" - state.node_name = "test_node" - state.run_id = "test_run_id" - state.graph_name = "test_graph" - state.inputs = {"key": "value"} - state.created_at = datetime.now() - return state - - @patch('app.controller.create_states.GraphTemplate') - @patch('app.controller.create_states.State') - async def test_create_states_success( - self, - mock_state_class, - mock_graph_template_class, - mock_namespace, - mock_graph_name, - mock_create_request, - mock_graph_template, - mock_state, - mock_request_id - ): - """Test successful creation of states""" - # Arrange - # Mock the GraphTemplate class and its find_one method - mock_graph_template_class.find_one = AsyncMock(return_value=mock_graph_template) - - # Mock State.insert_many - mock_insert_result = MagicMock() - mock_insert_result.inserted_ids = [PydanticObjectId()] - mock_state_class.insert_many = AsyncMock(return_value=mock_insert_result) - - # Mock State.find().to_list() - mock_state_find = MagicMock() - mock_state_find.to_list = AsyncMock(return_value=[mock_state]) - mock_state_class.find = MagicMock(return_value=mock_state_find) - - # Act - result = await create_states( - mock_namespace, - mock_graph_name, - mock_create_request, - mock_request_id - ) - - # Assert - assert result.status == StateStatusEnum.CREATED - assert len(result.states) == 1 - assert result.states[0].identifier == "test_identifier" - assert result.states[0].node_name == "test_node" - assert result.states[0].inputs == {"key": "value"} - - # Verify find_one was called (with any arguments) - assert mock_graph_template_class.find_one.called - mock_state_class.insert_many.assert_called_once() - mock_state_class.find.assert_called_once() - - @patch('app.controller.create_states.GraphTemplate') - async def test_create_states_graph_template_not_found( - self, - mock_graph_template_class, - mock_namespace, - mock_graph_name, - mock_create_request, - mock_request_id - ): - """Test when graph template is not found""" - # Arrange - mock_graph_template_class.find_one = AsyncMock(return_value=None) - - # Act & Assert - with pytest.raises(HTTPException) as exc_info: - await create_states( - mock_namespace, - mock_graph_name, - mock_create_request, - mock_request_id - ) - - assert exc_info.value.status_code == 404 - assert exc_info.value.detail == "Graph template not found" - assert mock_graph_template_class.find_one.called - - @patch('app.controller.create_states.GraphTemplate') - async def test_create_states_node_template_not_found( - self, - mock_graph_template_class, - mock_namespace, - mock_graph_name, - mock_create_request, - mock_request_id - ): - """Test when node template is not found in graph template""" - # Arrange - mock_graph_template = MagicMock() - mock_graph_template.get_node_by_identifier.return_value = None - mock_graph_template_class.find_one = AsyncMock(return_value=mock_graph_template) - - # Act & Assert - with pytest.raises(HTTPException) as exc_info: - await create_states( - mock_namespace, - mock_graph_name, - mock_create_request, - mock_request_id - ) - - assert exc_info.value.status_code == 404 - assert exc_info.value.detail == "Node template not found" - assert mock_graph_template_class.find_one.called - - @patch('app.controller.create_states.GraphTemplate') - async def test_create_states_database_error( - self, - mock_graph_template_class, - mock_namespace, - mock_graph_name, - mock_create_request, - mock_graph_template, - mock_request_id - ): - """Test handling of database errors""" - # Arrange - mock_graph_template_class.find_one = AsyncMock(side_effect=Exception("Database error")) - - # Act & Assert - with pytest.raises(Exception) as exc_info: - await create_states( - mock_namespace, - mock_graph_name, - mock_create_request, - mock_request_id - ) - - assert str(exc_info.value) == "Database error" - assert mock_graph_template_class.find_one.called - - @patch('app.controller.create_states.GraphTemplate') - @patch('app.controller.create_states.State') - async def test_create_states_multiple_states( - self, - mock_state_class, - mock_graph_template_class, - mock_namespace, - mock_graph_name, - mock_graph_template, - mock_request_id - ): - """Test creation of multiple states""" - # Arrange - mock_graph_template_class.find_one = AsyncMock(return_value=mock_graph_template) - - mock_insert_result = MagicMock() - mock_insert_result.inserted_ids = [PydanticObjectId(), PydanticObjectId()] - mock_state_class.insert_many = AsyncMock(return_value=mock_insert_result) - - # Mock State.find().to_list() for multiple states - mock_state1 = MagicMock() - mock_state1.id = PydanticObjectId() - mock_state1.identifier = "node1" - mock_state1.node_name = "test_node" - mock_state1.run_id = "test_run_id" - mock_state1.graph_name = "test_graph" - mock_state1.inputs = {"input1": "value1"} - mock_state1.created_at = datetime.now() - - mock_state2 = MagicMock() - mock_state2.id = PydanticObjectId() - mock_state2.identifier = "node2" - mock_state2.node_name = "test_node" - mock_state2.run_id = "test_run_id" - mock_state2.graph_name = "test_graph" - mock_state2.inputs = {"input2": "value2"} - mock_state2.created_at = datetime.now() - - mock_state_find = MagicMock() - mock_state_find.to_list = AsyncMock(return_value=[mock_state1, mock_state2]) - mock_state_class.find = MagicMock(return_value=mock_state_find) - - create_request = CreateRequestModel( - run_id="test_run_id", - states=[ - RequestStateModel(identifier="node1", inputs={"input1": "value1"}), - RequestStateModel(identifier="node2", inputs={"input2": "value2"}) - ] - ) - - # Act - result = await create_states( - mock_namespace, - mock_graph_name, - create_request, - mock_request_id - ) - - # Assert - assert result.status == StateStatusEnum.CREATED - assert mock_graph_template_class.find_one.called - mock_state_class.insert_many.assert_called_once() - # Verify that insert_many was called with 2 states - call_args = mock_state_class.insert_many.call_args[0][0] - assert len(call_args) == 2 diff --git a/state-manager/tests/unit/controller/test_trigger_graph.py b/state-manager/tests/unit/controller/test_trigger_graph.py index a33c961d..a7900993 100644 --- a/state-manager/tests/unit/controller/test_trigger_graph.py +++ b/state-manager/tests/unit/controller/test_trigger_graph.py @@ -1,99 +1,123 @@ -from datetime import datetime import pytest -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock from fastapi import HTTPException -from app.controller.create_states import trigger_graph -from app.models.create_models import TriggerGraphRequestModel, RequestStateModel, ResponseStateModel +from app.controller.trigger_graph import trigger_graph +from app.models.trigger_model import TriggerGraphRequestModel from app.models.state_status_enum import StateStatusEnum @pytest.fixture def mock_request(): return TriggerGraphRequestModel( - states=[ - RequestStateModel( - identifier="test_node_1", - inputs={"input1": "value1"} - ), - RequestStateModel( - identifier="test_node_2", - inputs={"input2": "value2"} - ) - ] + store={"k1": "v1"}, + inputs={"input1": "value1"} ) @pytest.mark.asyncio async def test_trigger_graph_success(mock_request): - """Test successful graph triggering""" namespace_name = "test_namespace" graph_name = "test_graph" x_exosphere_request_id = "test_request_id" - - # Mock the create_states function - with patch('app.controller.create_states.create_states') as mock_create_states: - mock_response = MagicMock() - mock_response.status = StateStatusEnum.CREATED - mock_response.states = [ - ResponseStateModel( - state_id="state_1", - identifier="test_node_1", - node_name="TestNode1", - graph_name=graph_name, - run_id="generated_run_id", - inputs={"input1": "value1"}, - created_at=datetime(2024, 1, 1, 0, 0, 0) - ), - ResponseStateModel( - state_id="state_2", - identifier="test_node_2", - node_name="TestNode2", - graph_name=graph_name, - run_id="generated_run_id", - inputs={"input2": "value2"}, - created_at=datetime(2024, 1, 1, 0, 0, 0) - ) - ] - mock_create_states.return_value = mock_response - - # Call the function + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls, \ + patch('app.controller.trigger_graph.Store') as mock_store_cls, \ + patch('app.controller.trigger_graph.State') as mock_state_cls: + + mock_graph_template = MagicMock() + mock_graph_template.is_valid.return_value = True + mock_root_node = MagicMock() + mock_root_node.node_name = "root_node" + mock_root_node.identifier = "root_id" + mock_root_node.inputs = {"input1": "default"} + mock_graph_template.get_root_node.return_value = mock_root_node + mock_graph_template_cls.get = AsyncMock(return_value=mock_graph_template) + + mock_store_cls.insert_many = AsyncMock(return_value=None) + mock_state_instance = MagicMock() + mock_state_instance.insert = AsyncMock(return_value=None) + mock_state_cls.return_value = mock_state_instance + result = await trigger_graph(namespace_name, graph_name, mock_request, x_exosphere_request_id) - - # Verify the result - assert result.run_id is not None + assert result.status == StateStatusEnum.CREATED - assert len(result.states) == 2 - assert result.states[0].identifier == "test_node_1" - assert result.states[1].identifier == "test_node_2" - - # Verify create_states was called with the correct parameters - mock_create_states.assert_called_once() - call_args = mock_create_states.call_args - assert call_args[0][0] == namespace_name # namespace_name - assert call_args[0][1] == graph_name # graph_name - assert call_args[0][3] == x_exosphere_request_id # x_exosphere_request_id - - # Verify the CreateRequestModel was created with a generated run_id - create_request = call_args[0][2] # body parameter - assert create_request.run_id is not None - assert create_request.states == mock_request.states + assert isinstance(result.run_id, str) and len(result.run_id) > 0 + + mock_graph_template_cls.get.assert_awaited_once_with(namespace_name, graph_name) + mock_store_cls.insert_many.assert_awaited_once() + mock_state_instance.insert.assert_awaited_once() @pytest.mark.asyncio -async def test_trigger_graph_create_states_error(mock_request): - """Test error handling when create_states fails""" +async def test_trigger_graph_graph_template_not_found(mock_request): namespace_name = "test_namespace" graph_name = "test_graph" x_exosphere_request_id = "test_request_id" - - # Mock create_states to raise an exception - with patch('app.controller.create_states.create_states') as mock_create_states: - mock_create_states.side_effect = HTTPException(status_code=404, detail="Graph template not found") - - # Call the function and expect it to raise the same exception + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls: + mock_graph_template_cls.get = AsyncMock(side_effect=ValueError("Graph template not found")) + with pytest.raises(HTTPException) as exc_info: await trigger_graph(namespace_name, graph_name, mock_request, x_exosphere_request_id) - + assert exc_info.value.status_code == 404 - assert exc_info.value.detail == "Graph template not found" + assert "Graph template not found" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_trigger_graph_invalid_graph_template(mock_request): + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls: + mock_graph_template = MagicMock() + mock_graph_template.is_valid.return_value = False + mock_graph_template_cls.get = AsyncMock(return_value=mock_graph_template) + + with pytest.raises(HTTPException) as exc_info: + await trigger_graph(namespace_name, graph_name, mock_request, x_exosphere_request_id) + + assert exc_info.value.status_code == 400 + assert exc_info.value.detail == "Graph template is not valid" + + +@pytest.mark.asyncio +async def test_trigger_graph_missing_store_keys(): + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + req = TriggerGraphRequestModel(store={}, inputs={}) + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls: + mock_graph_template = MagicMock() + mock_graph_template.is_valid.return_value = True + mock_graph_template.store_config.required_keys = ["k1"] + mock_root_node = MagicMock() + mock_root_node.node_name = "root_node" + mock_root_node.identifier = "root_id" + mock_graph_template.get_root_node.return_value = mock_root_node + mock_graph_template_cls.get = AsyncMock(return_value=mock_graph_template) + + with pytest.raises(HTTPException) as exc_info: + await trigger_graph(namespace_name, graph_name, req, x_exosphere_request_id) + + assert exc_info.value.status_code == 400 + assert "Missing store keys" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_trigger_graph_value_error_not_graph_template_not_found(mock_request): + """Test trigger_graph handles ValueError that is not about graph template not found""" + namespace_name = "test_namespace" + graph_name = "test_graph" + x_exosphere_request_id = "test_request_id" + + with patch('app.controller.trigger_graph.GraphTemplate') as mock_graph_template_cls: + # Simulate a ValueError that doesn't contain "Graph template not found" + mock_graph_template_cls.get.side_effect = ValueError("Some other validation error") + + with pytest.raises(ValueError, match="Some other validation error"): + await trigger_graph(namespace_name, graph_name, mock_request, x_exosphere_request_id) diff --git a/state-manager/tests/unit/models/test_dependent_string.py b/state-manager/tests/unit/models/test_dependent_string.py new file mode 100644 index 00000000..b159cb8b --- /dev/null +++ b/state-manager/tests/unit/models/test_dependent_string.py @@ -0,0 +1,85 @@ +import pytest +from app.models.dependent_string import DependentString, Dependent + + +class TestDependentString: + """Additional test cases for DependentString model to improve coverage""" + + def test_generate_string_with_unset_dependent_value(self): + """Test generate_string method fails when dependent value is not set""" + dependent_string = DependentString( + head="prefix_", + dependents={ + 0: Dependent(identifier="node1", field="output1", tail="_suffix", value=None) + } + ) + + with pytest.raises(ValueError, match="Dependent value is not set for:"): + dependent_string.generate_string() + + def test_build_mapping_key_to_dependent_already_built(self): + """Test _build_mapping_key_to_dependent when mapping already exists""" + dependent_string = DependentString( + head="prefix_", + dependents={ + 0: Dependent(identifier="node1", field="output1", tail="_suffix") + } + ) + + # Build mapping first time + dependent_string._build_mapping_key_to_dependent() + original_mapping = dependent_string._mapping_key_to_dependent.copy() + + # Call again - should not rebuild + dependent_string._build_mapping_key_to_dependent() + assert dependent_string._mapping_key_to_dependent == original_mapping + + def test_set_value_multiple_dependents_same_key(self): + """Test set_value method with multiple dependents having same identifier and field""" + dependent1 = Dependent(identifier="node1", field="output1", tail="_suffix1") + dependent2 = Dependent(identifier="node1", field="output1", tail="_suffix2") + + dependent_string = DependentString( + head="prefix_", + dependents={0: dependent1, 1: dependent2} + ) + + dependent_string.set_value("node1", "output1", "test_value") + + assert dependent1.value == "test_value" + assert dependent2.value == "test_value" + + def test_get_identifier_field_multiple_mappings(self): + """Test get_identifier_field method with multiple identifier-field mappings""" + dependent_string = DependentString( + head="prefix_", + dependents={ + 0: Dependent(identifier="node1", field="output1", tail="_suffix1"), + 1: Dependent(identifier="node2", field="output2", tail="_suffix2"), + 2: Dependent(identifier="node1", field="output3", tail="_suffix3") + } + ) + + identifier_fields = dependent_string.get_identifier_field() + + # Should have 3 unique identifier-field pairs + expected_pairs = [("node1", "output1"), ("node2", "output2"), ("node1", "output3")] + assert len(identifier_fields) == 3 + assert set(identifier_fields) == set(expected_pairs) + + + def test_create_dependent_string_with_store_dependency(self): + """Test create_dependent_string method with store dependency""" + syntax_string = "prefix_${{store.config_key}}_suffix" + + dependent_string = DependentString.create_dependent_string(syntax_string) + + assert dependent_string.head == "prefix_" + assert len(dependent_string.dependents) == 1 + assert 0 in dependent_string.dependents + + dependent = dependent_string.dependents[0] + assert dependent.identifier == "store" + assert dependent.field == "config_key" + assert dependent.tail == "_suffix" + assert dependent.value is None diff --git a/state-manager/tests/unit/models/test_graph_template_model.py b/state-manager/tests/unit/models/test_graph_template_model.py index 55da8e94..db49004a 100644 --- a/state-manager/tests/unit/models/test_graph_template_model.py +++ b/state-manager/tests/unit/models/test_graph_template_model.py @@ -1,5 +1,5 @@ import pytest -from unittest.mock import patch +from unittest.mock import patch, MagicMock import base64 from app.models.db.graph_template_model import GraphTemplate @@ -165,4 +165,94 @@ def test_get_valid_timeout(self): def test_get_valid_exception_handling(self): """Test get_valid method exception handling""" # This test doesn't require GraphTemplate instantiation - assert GraphTemplate.get_valid.__name__ == "get_valid" \ No newline at end of file + assert GraphTemplate.get_valid.__name__ == "get_valid" + + @pytest.mark.asyncio + async def test_get_valid_negative_polling_interval(self): + """Test get_valid method with negative polling interval""" + with pytest.raises(ValueError, match="polling_interval must be positive"): + await GraphTemplate.get_valid("test_ns", "test_graph", polling_interval=-1.0) + + @pytest.mark.asyncio + async def test_get_valid_zero_polling_interval(self): + """Test get_valid method with zero polling interval""" + with pytest.raises(ValueError, match="polling_interval must be positive"): + await GraphTemplate.get_valid("test_ns", "test_graph", polling_interval=0.0) + + @pytest.mark.asyncio + async def test_get_valid_negative_timeout(self): + """Test get_valid method with negative timeout""" + with pytest.raises(ValueError, match="timeout must be positive"): + await GraphTemplate.get_valid("test_ns", "test_graph", timeout=-1.0) + + @pytest.mark.asyncio + async def test_get_valid_zero_timeout(self): + """Test get_valid method with zero timeout""" + with pytest.raises(ValueError, match="timeout must be positive"): + await GraphTemplate.get_valid("test_ns", "test_graph", timeout=0.0) + + @pytest.mark.asyncio + async def test_get_valid_coerces_small_polling_interval_mock(self): + """Test get_valid method coerces very small polling interval to 0.1""" + with patch.object(GraphTemplate, 'get') as mock_get, \ + patch('time.monotonic', side_effect=[0, 1, 2]), \ + patch('asyncio.sleep') as _: + + mock_template = MagicMock() + mock_template.is_valid.return_value = True + mock_get.return_value = mock_template + + result = await GraphTemplate.get_valid("test_ns", "test_graph", polling_interval=0.01) + + assert result == mock_template + # Should have coerced polling_interval to 0.1 + # (This is harder to test directly, but we can verify the function completed) + + @pytest.mark.asyncio + async def test_get_valid_coerces_small_polling_interval(self): + """Test get_valid method coerces very small polling interval to 0.1""" + from unittest.mock import MagicMock + + with patch.object(GraphTemplate, 'get') as mock_get, \ + patch('time.monotonic', side_effect=[0, 1, 2]), \ + patch('asyncio.sleep') as _: + + mock_template = MagicMock() + mock_template.is_valid.return_value = True + mock_get.return_value = mock_template + + result = await GraphTemplate.get_valid("test_ns", "test_graph", polling_interval=0.01) + + assert result == mock_template + + @pytest.mark.asyncio + async def test_get_valid_non_validating_state(self): + """Test get_valid method when graph template is in non-validating state""" + from unittest.mock import MagicMock + + with patch.object(GraphTemplate, 'get') as mock_get: + mock_template = MagicMock() + mock_template.is_valid.return_value = False + mock_template.is_validating.return_value = False + mock_template.validation_status.value = "INVALID" + mock_get.return_value = mock_template + + with pytest.raises(ValueError, match="Graph template is in a non-validating state: INVALID"): + await GraphTemplate.get_valid("test_ns", "test_graph") + + @pytest.mark.asyncio + async def test_get_valid_timeout_reached(self): + """Test get_valid method when timeout is reached""" + from unittest.mock import MagicMock + + with patch.object(GraphTemplate, 'get') as mock_get, \ + patch('time.monotonic', side_effect=[0, 0.5, 1.0, 1.5, 2.0]), \ + patch('asyncio.sleep') as _: + + mock_template = MagicMock() + mock_template.is_valid.return_value = False + mock_template.is_validating.return_value = True + mock_get.return_value = mock_template + + with pytest.raises(ValueError, match="Graph template is not valid for namespace: test_ns and graph name: test_graph after 1.0 seconds"): + await GraphTemplate.get_valid("test_ns", "test_graph", timeout=1.0) diff --git a/state-manager/tests/unit/models/test_node_template_model.py b/state-manager/tests/unit/models/test_node_template_model.py new file mode 100644 index 00000000..794d3643 --- /dev/null +++ b/state-manager/tests/unit/models/test_node_template_model.py @@ -0,0 +1,77 @@ +import pytest +from app.models.node_template_model import NodeTemplate, Unites, UnitesStrategyEnum +from app.models.dependent_string import DependentString + + +class TestNodeTemplate: + """Test cases for NodeTemplate model""" + + def test_validate_identifier_reserved_word_store(self): + """Test validation fails for reserved word 'store' as identifier""" + with pytest.raises(ValueError, match="Node identifier cannot be reserved word 'store'"): + NodeTemplate( + node_name="test_node", + namespace="test_ns", + identifier="store", + inputs={"input1": "value1"}, + next_nodes=[], + unites=None + ) + + def test_get_dependent_strings_with_non_string_input(self): + """Test get_dependent_strings method with non-string input""" + node = NodeTemplate( + node_name="test_node", + namespace="test_ns", + identifier="test_id", + inputs={"input1": "valid_string", "input2": 123}, + next_nodes=[], + unites=None + ) + + with pytest.raises(ValueError, match="Input 123 is not a string"): + node.get_dependent_strings() + + def test_get_dependent_strings_valid(self): + """Test get_dependent_strings method with valid string inputs""" + node = NodeTemplate( + node_name="test_node", + namespace="test_ns", + identifier="test_id", + inputs={ + "input1": "simple_string", + "input2": "${{node1.outputs.field1}}", + "input3": "prefix_${{store.key1}}_suffix" + }, + next_nodes=[], + unites=None + ) + + dependent_strings = node.get_dependent_strings() + assert len(dependent_strings) == 3 + assert all(isinstance(ds, DependentString) for ds in dependent_strings) + + +class TestUnites: + """Test cases for Unites model""" + + def test_unites_creation_default_strategy(self): + """Test creating Unites with default strategy""" + unites = Unites(identifier="test_id") + assert unites.identifier == "test_id" + assert unites.strategy == UnitesStrategyEnum.ALL_SUCCESS + + def test_unites_creation_custom_strategy(self): + """Test creating Unites with custom strategy""" + unites = Unites(identifier="test_id", strategy=UnitesStrategyEnum.ALL_DONE) + assert unites.identifier == "test_id" + assert unites.strategy == UnitesStrategyEnum.ALL_DONE + + +class TestUnitesStrategyEnum: + """Test cases for UnitesStrategyEnum""" + + def test_enum_values(self): + """Test enum values are correct""" + assert UnitesStrategyEnum.ALL_SUCCESS == "ALL_SUCCESS" + assert UnitesStrategyEnum.ALL_DONE == "ALL_DONE" diff --git a/state-manager/tests/unit/models/test_store.py b/state-manager/tests/unit/models/test_store.py new file mode 100644 index 00000000..20137c58 --- /dev/null +++ b/state-manager/tests/unit/models/test_store.py @@ -0,0 +1,64 @@ +import pytest +from unittest.mock import AsyncMock, patch, MagicMock +from app.models.db.store import Store + + +class TestStore: + """Test cases for Store model""" + + def test_store_settings_indexes(self): + """Test Store model has correct indexes defined""" + indexes = Store.Settings.indexes + assert len(indexes) == 1 + + index = indexes[0] + assert index.document["unique"] + assert index.document["name"] == "uniq_run_id_namespace_graph_name_key" + + @pytest.mark.asyncio + async def test_get_value_found(self): + """Test get_value method when store entry is found""" + # Create mock store instance + mock_store = MagicMock() + mock_store.value = "test_value" + + # Mock the entire Store class and its find_one method + with patch('app.models.db.store.Store') as mock_store_class: + mock_store_class.find_one = AsyncMock(return_value=mock_store) + + # Call the actual static method + result = await Store.get_value("test_run", "test_ns", "test_graph", "test_key") + + assert result == "test_value" + + @pytest.mark.asyncio + async def test_get_value_not_found(self): + """Test get_value method when store entry is not found""" + # Mock the entire Store class and its find_one method + with patch('app.models.db.store.Store') as mock_store_class: + mock_store_class.find_one = AsyncMock(return_value=None) + + # Call the actual static method + result = await Store.get_value("test_run", "test_ns", "test_graph", "nonexistent_key") + + assert result is None + + @pytest.mark.asyncio + async def test_get_value_with_different_parameters(self): + """Test get_value method with various parameter combinations""" + test_cases = [ + ("run1", "ns1", "graph1", "key1", "value1"), + ("run2", "ns2", "graph2", "key2", "value2"), + ("", "", "", "", ""), # Edge case with empty strings + ] + + for run_id, namespace, graph_name, key, expected_value in test_cases: + mock_store = MagicMock() + mock_store.value = expected_value + + with patch('app.models.db.store.Store') as mock_store_class: + mock_store_class.find_one = AsyncMock(return_value=mock_store) + + result = await Store.get_value(run_id, namespace, graph_name, key) + + assert result == expected_value diff --git a/state-manager/tests/unit/models/test_store_config_model.py b/state-manager/tests/unit/models/test_store_config_model.py new file mode 100644 index 00000000..3ca05a67 --- /dev/null +++ b/state-manager/tests/unit/models/test_store_config_model.py @@ -0,0 +1,150 @@ +import pytest +from app.models.store_config_model import StoreConfig + + +class TestStoreConfig: + """Test cases for StoreConfig model""" + + def test_store_config_creation_defaults(self): + """Test creating StoreConfig with default values""" + config = StoreConfig() + assert config.required_keys == [] + assert config.default_values == {} + + def test_store_config_creation_with_values(self): + """Test creating StoreConfig with provided values""" + config = StoreConfig( + required_keys=["key1", "key2"], + default_values={"default_key": "default_value"} + ) + assert config.required_keys == ["key1", "key2"] + assert config.default_values == {"default_key": "default_value"} + + def test_validate_required_keys_valid(self): + """Test validation of valid required keys""" + valid_keys = ["key1", "key2", "key3"] + result = StoreConfig.validate_required_keys(valid_keys) # type: ignore + assert result == valid_keys + + def test_validate_required_keys_with_whitespace(self): + """Test validation trims whitespace from keys""" + keys_with_whitespace = [" key1 ", " key2 ", "key3"] + result = StoreConfig.validate_required_keys(keys_with_whitespace) # type: ignore + assert result == ["key1", "key2", "key3"] + + def test_validate_required_keys_empty_string(self): + """Test validation fails for empty string keys""" + invalid_keys = ["key1", "", "key3"] + with pytest.raises(ValueError, match="Key cannot be empty or contain only whitespace"): + StoreConfig.validate_required_keys(invalid_keys) # type: ignore + + def test_validate_required_keys_whitespace_only(self): + """Test validation fails for whitespace-only keys""" + invalid_keys = ["key1", " ", "key3"] + with pytest.raises(ValueError, match="Key cannot be empty or contain only whitespace"): + StoreConfig.validate_required_keys(invalid_keys) # type: ignore + + def test_validate_required_keys_none_value(self): + """Test validation fails for None keys""" + invalid_keys = ["key1", None, "key3"] + with pytest.raises(ValueError, match="Key cannot be empty or contain only whitespace"): + StoreConfig.validate_required_keys(invalid_keys) # type: ignore + + def test_validate_required_keys_dot_character(self): + """Test validation fails for keys containing dot character""" + invalid_keys = ["key1", "key.with.dot", "key3"] + with pytest.raises(ValueError, match="Key 'key.with.dot' cannot contain '.' character"): + StoreConfig.validate_required_keys(invalid_keys) # type: ignore + + def test_validate_required_keys_duplicates(self): + """Test validation fails for duplicate keys""" + invalid_keys = ["key1", "key2", "key1"] + with pytest.raises(ValueError, match="Key 'key1' is duplicated"): + StoreConfig.validate_required_keys(invalid_keys) # type: ignore + + def test_validate_required_keys_duplicates_after_trim(self): + """Test validation fails for duplicate keys after trimming""" + invalid_keys = ["key1", " key1 ", "key2"] + with pytest.raises(ValueError, match="Key 'key1' is duplicated"): + StoreConfig.validate_required_keys(invalid_keys) # type: ignore + + def test_validate_required_keys_multiple_errors(self): + """Test validation collects multiple errors""" + invalid_keys = ["", "key.dot", "key1", "key1", " "] + with pytest.raises(ValueError) as exc_info: + StoreConfig.validate_required_keys(invalid_keys) # type: ignore + + error_message = str(exc_info.value) + assert "Key cannot be empty or contain only whitespace" in error_message + assert "Key 'key.dot' cannot contain '.' character" in error_message + assert "Key 'key1' is duplicated" in error_message + + def test_validate_default_values_valid(self): + """Test validation of valid default values""" + valid_values = {"key1": "value1", "key2": "value2"} + result = StoreConfig.validate_default_values(valid_values) # type: ignore + assert result == valid_values + + def test_validate_default_values_with_whitespace(self): + """Test validation trims whitespace from keys""" + values_with_whitespace = {" key1 ": "value1", " key2 ": "value2"} + result = StoreConfig.validate_default_values(values_with_whitespace) # type: ignore + assert result == {"key1": "value1", "key2": "value2"} + + def test_validate_default_values_empty_key(self): + """Test validation fails for empty string keys""" + invalid_values = {"key1": "value1", "": "value2"} + with pytest.raises(ValueError, match="Key cannot be empty or contain only whitespace"): + StoreConfig.validate_default_values(invalid_values) # type: ignore + + def test_validate_default_values_whitespace_only_key(self): + """Test validation fails for whitespace-only keys""" + invalid_values = {"key1": "value1", " ": "value2"} + with pytest.raises(ValueError, match="Key cannot be empty or contain only whitespace"): + StoreConfig.validate_default_values(invalid_values) # type: ignore + + def test_validate_default_values_none_key(self): + """Test validation fails for None keys""" + invalid_values = {"key1": "value1", None: "value2"} + with pytest.raises(ValueError, match="Key cannot be empty or contain only whitespace"): + StoreConfig.validate_default_values(invalid_values) # type: ignore + + def test_validate_default_values_dot_character(self): + """Test validation fails for keys containing dot character""" + invalid_values = {"key1": "value1", "key.with.dot": "value2"} + with pytest.raises(ValueError, match="Key 'key.with.dot' cannot contain '.' character"): + StoreConfig.validate_default_values(invalid_values) # type: ignore + + def test_validate_default_values_duplicates_after_trim(self): + """Test validation fails for duplicate keys after trimming""" + values_with_duplicates_after_trim = {" key1 ": "value1", "key1": "value2"} + with pytest.raises(ValueError, match="Key 'key1' is duplicated"): + StoreConfig.validate_default_values(values_with_duplicates_after_trim) # type: ignore + + def test_validate_default_values_multiple_errors(self): + """Test validation collects multiple errors""" + invalid_values = {"": "value1", "key.dot": "value2", " key1 ": "value3", "key1": "duplicate"} + with pytest.raises(ValueError) as exc_info: + StoreConfig.validate_default_values(invalid_values) # type: ignore + + error_message = str(exc_info.value) + assert "Key cannot be empty or contain only whitespace" in error_message + assert "Key 'key.dot' cannot contain '.' character" in error_message + assert "Key 'key1' is duplicated" in error_message + + def test_store_config_integration(self): + """Test creating StoreConfig with validation""" + # Test successful creation + config = StoreConfig( + required_keys=[" key1 ", "key2"], + default_values={" default1 ": "value1", "default2": "value2"} + ) + assert config.required_keys == ["key1", "key2"] + assert config.default_values == {"default1": "value1", "default2": "value2"} + + # Test failure case + with pytest.raises(ValueError): + StoreConfig( + required_keys=["key1", "key.invalid"], + default_values={"valid": "value"} + ) diff --git a/state-manager/tests/unit/tasks/test_create_next_states.py b/state-manager/tests/unit/tasks/test_create_next_states.py index f419100d..9fd46e06 100644 --- a/state-manager/tests/unit/tasks/test_create_next_states.py +++ b/state-manager/tests/unit/tasks/test_create_next_states.py @@ -10,6 +10,7 @@ from app.models.dependent_string import Dependent, DependentString from app.models.state_status_enum import StateStatusEnum from app.models.node_template_model import NodeTemplate, Unites, UnitesStrategyEnum +from app.models.store_config_model import StoreConfig from pydantic import BaseModel @@ -633,4 +634,612 @@ async def test_create_next_states_exception_handling(self): mock_find.set.assert_called_with({ "status": StateStatusEnum.NEXT_CREATED_ERROR, "error": "Graph template error" - }) \ No newline at end of file + }) + + +class TestGetStoreValue: + """Test cases for get_store_value function within create_next_states""" + + @pytest.mark.asyncio + async def test_get_store_value_from_cache(self): + """Test getting store value from cache within a single execution""" + # Test that multiple references to the same store field within one execution use cache + + with patch('app.tasks.create_next_states.GraphTemplate') as mock_graph_template, \ + patch('app.tasks.create_next_states.Store') as mock_store, \ + patch('app.tasks.create_next_states.State') as mock_state_class, \ + patch('app.tasks.create_next_states.validate_dependencies') as mock_validate: + + # Setup GraphTemplate mock + mock_template = MagicMock() + mock_template.store_config = StoreConfig(default_values={"default_field": "default_value"}) + current_node = NodeTemplate( + node_name="test_node", + identifier="current_id", + namespace="test", + inputs={}, + next_nodes=["next_node"], + unites=None + ) + # Create a node template that uses the same store field twice + next_node = NodeTemplate( + node_name="next_node", + identifier="next_node", + namespace="test", + inputs={ + "input1": "${{store.test_field}}", + "input2": "${{store.test_field}}_suffix" # Same field used twice + }, + next_nodes=None, + unites=None + ) + # Set up to handle multiple calls + def get_node_side_effect(identifier): + if identifier == "current_id": + return current_node + elif identifier == "next_node": + return next_node + return None + mock_template.get_node_by_identifier.side_effect = get_node_side_effect + mock_graph_template.get_valid = AsyncMock(return_value=mock_template) + + # Mock validate_dependencies to pass + mock_validate.return_value = None + + # Setup Store mock + mock_store.get_value = AsyncMock(return_value="store_value") + + # Setup State mock + mock_state_class.id = "id" + mock_current_state = MagicMock() + mock_current_state.run_id = "test_run" + mock_current_state.identifier = "current_id" + mock_current_state.outputs = {"field1": "output_value"} + mock_find = AsyncMock() + mock_find.to_list.return_value = [mock_current_state] + mock_find.set = AsyncMock() + mock_state_class.find.return_value = mock_find + mock_state_class.insert_many = AsyncMock() + + # Setup RegisteredNode mock + with patch('app.tasks.create_next_states.RegisteredNode') as mock_registered_node: + mock_registered_node_instance = MagicMock() + mock_registered_node_instance.inputs_schema = {"input1": {"type": "string"}, "input2": {"type": "string"}} + mock_registered_node.get_by_name_and_namespace = AsyncMock(return_value=mock_registered_node_instance) + + with patch('app.tasks.create_next_states.create_model') as mock_create_model: + mock_input_model = MagicMock() + mock_input_model.model_fields = { + "input1": MagicMock(annotation=str), + "input2": MagicMock(annotation=str) + } + mock_create_model.return_value = mock_input_model + + # Single call that should use the same store field twice + await create_next_states([PydanticObjectId()], "current_id", "test_namespace", "test_graph", {}) + + # Verify Store.get_value was called only once despite being used twice (cached) + mock_store.get_value.assert_called_once_with("test_run", "test_namespace", "test_graph", "test_field") + + @pytest.mark.asyncio + async def test_get_store_value_from_store(self): + """Test getting store value from Store when not cached""" + + with patch('app.tasks.create_next_states.GraphTemplate') as mock_graph_template, \ + patch('app.tasks.create_next_states.Store') as mock_store, \ + patch('app.tasks.create_next_states.State') as mock_state_class, \ + patch('app.tasks.create_next_states.validate_dependencies') as mock_validate: + + # Setup GraphTemplate mock + mock_template = MagicMock() + mock_template.store_config = StoreConfig(default_values={}) + current_node = NodeTemplate( + node_name="test_node", + identifier="current_id", + namespace="test", + inputs={}, + next_nodes=["next_node"], + unites=None + ) + next_node = NodeTemplate( + node_name="next_node", + identifier="next_node", + namespace="test", + inputs={"input1": "${{store.test_field}}"}, + next_nodes=None, + unites=None + ) + def get_node_side_effect(identifier): + if identifier == "current_id": + return current_node + elif identifier == "next_node": + return next_node + return None + mock_template.get_node_by_identifier.side_effect = get_node_side_effect + mock_graph_template.get_valid = AsyncMock(return_value=mock_template) + + # Mock validate_dependencies to pass + mock_validate.return_value = None + + # Setup Store mock to return a value + mock_store.get_value = AsyncMock(return_value="store_value") + + # Setup State mock + mock_state_class.id = "id" + mock_current_state = MagicMock() + mock_current_state.run_id = "test_run" + mock_current_state.identifier = "current_id" + mock_current_state.outputs = {"field1": "output_value"} + mock_find = AsyncMock() + mock_find.to_list.return_value = [mock_current_state] + mock_find.set = AsyncMock() + mock_state_class.find.return_value = mock_find + mock_state_class.insert_many = AsyncMock() + + # Setup RegisteredNode mock + with patch('app.tasks.create_next_states.RegisteredNode') as mock_registered_node: + mock_registered_node_instance = MagicMock() + mock_registered_node_instance.inputs_schema = {"input1": {"type": "string"}} + mock_registered_node.get_by_name_and_namespace = AsyncMock(return_value=mock_registered_node_instance) + + with patch('app.tasks.create_next_states.create_model') as mock_create_model: + mock_input_model = MagicMock() + mock_input_model.model_fields = {"input1": MagicMock(annotation=str)} + mock_create_model.return_value = mock_input_model + + await create_next_states([PydanticObjectId()], "current_id", "test_namespace", "test_graph", {}) + + # Verify Store.get_value was called with correct parameters + mock_store.get_value.assert_called_once_with("test_run", "test_namespace", "test_graph", "test_field") + + @pytest.mark.asyncio + async def test_get_store_value_from_default(self): + """Test getting store value from default values when Store returns None""" + + with patch('app.tasks.create_next_states.GraphTemplate') as mock_graph_template, \ + patch('app.tasks.create_next_states.Store') as mock_store, \ + patch('app.tasks.create_next_states.State') as mock_state_class, \ + patch('app.tasks.create_next_states.validate_dependencies') as mock_validate: + + # Setup GraphTemplate mock with default values + mock_template = MagicMock() + mock_template.store_config = StoreConfig(default_values={"test_field": "default_value"}) + current_node = NodeTemplate( + node_name="test_node", + identifier="current_id", + namespace="test", + inputs={}, + next_nodes=["next_node"], + unites=None + ) + next_node = NodeTemplate( + node_name="next_node", + identifier="next_node", + namespace="test", + inputs={"input1": "${{store.test_field}}"}, + next_nodes=None, + unites=None + ) + def get_node_side_effect(identifier): + if identifier == "current_id": + return current_node + elif identifier == "next_node": + return next_node + return None + mock_template.get_node_by_identifier.side_effect = get_node_side_effect + mock_graph_template.get_valid = AsyncMock(return_value=mock_template) + + # Mock validate_dependencies to pass + mock_validate.return_value = None + + # Setup Store mock to return None (not found) + mock_store.get_value = AsyncMock(return_value=None) + + # Setup State mock + mock_state_class.id = "id" + mock_current_state = MagicMock() + mock_current_state.run_id = "test_run" + mock_current_state.identifier = "current_id" + mock_current_state.outputs = {"field1": "output_value"} + mock_find = AsyncMock() + mock_find.to_list.return_value = [mock_current_state] + mock_find.set = AsyncMock() + mock_state_class.find.return_value = mock_find + mock_state_class.insert_many = AsyncMock() + + # Setup RegisteredNode mock + with patch('app.tasks.create_next_states.RegisteredNode') as mock_registered_node: + mock_registered_node_instance = MagicMock() + mock_registered_node_instance.inputs_schema = {"input1": {"type": "string"}} + mock_registered_node.get_by_name_and_namespace = AsyncMock(return_value=mock_registered_node_instance) + + with patch('app.tasks.create_next_states.create_model') as mock_create_model: + mock_input_model = MagicMock() + mock_input_model.model_fields = {"input1": MagicMock(annotation=str)} + mock_create_model.return_value = mock_input_model + + # Should complete successfully using default value + await create_next_states([PydanticObjectId()], "current_id", "test_namespace", "test_graph", {}) + + # Verify Store.get_value was called + mock_store.get_value.assert_called_once_with("test_run", "test_namespace", "test_graph", "test_field") + + @pytest.mark.asyncio + async def test_get_store_value_not_found_error(self): + """Test error when store value is not found in Store or default values""" + + with patch('app.tasks.create_next_states.GraphTemplate') as mock_graph_template, \ + patch('app.tasks.create_next_states.Store') as mock_store, \ + patch('app.tasks.create_next_states.State') as mock_state_class, \ + patch('app.tasks.create_next_states.validate_dependencies') as mock_validate: + + # Setup GraphTemplate mock with no default values + mock_template = MagicMock() + mock_template.store_config = StoreConfig(default_values={}) + current_node = NodeTemplate( + node_name="test_node", + identifier="current_id", + namespace="test", + inputs={}, + next_nodes=["next_node"], + unites=None + ) + next_node = NodeTemplate( + node_name="next_node", + identifier="next_node", + namespace="test", + inputs={"input1": "${{store.missing_field}}"}, + next_nodes=None, + unites=None + ) + def get_node_side_effect(identifier): + if identifier == "current_id": + return current_node + elif identifier == "next_node": + return next_node + return None + mock_template.get_node_by_identifier.side_effect = get_node_side_effect + mock_graph_template.get_valid = AsyncMock(return_value=mock_template) + + # Mock validate_dependencies to pass + mock_validate.return_value = None + + # Setup Store mock to return None (not found) + mock_store.get_value = AsyncMock(return_value=None) + + # Setup State mock + mock_state_class.id = "id" + mock_current_state = MagicMock() + mock_current_state.run_id = "test_run" + mock_current_state.identifier = "current_id" + mock_current_state.outputs = {"field1": "output_value"} + mock_find = AsyncMock() + mock_find.to_list.return_value = [mock_current_state] + mock_find.set = AsyncMock() + mock_state_class.find.return_value = mock_find + mock_state_class.insert_many = AsyncMock() + + # Setup RegisteredNode mock + with patch('app.tasks.create_next_states.RegisteredNode') as mock_registered_node: + mock_registered_node_instance = MagicMock() + mock_registered_node_instance.inputs_schema = {"input1": {"type": "string"}} + mock_registered_node.get_by_name_and_namespace = AsyncMock(return_value=mock_registered_node_instance) + + with patch('app.tasks.create_next_states.create_model') as mock_create_model: + mock_input_model = MagicMock() + mock_input_model.model_fields = {"input1": MagicMock(annotation=str)} + mock_create_model.return_value = mock_input_model + + with pytest.raises(ValueError, match="Store value not found for field 'missing_field' in namespace 'test_namespace' and graph 'test_graph'"): + await create_next_states([PydanticObjectId()], "current_id", "test_namespace", "test_graph", {}) + + @pytest.mark.asyncio + async def test_get_store_value_multiple_fields_cache_isolation(self): + """Test that cache correctly isolates different run_id and field combinations""" + + with patch('app.tasks.create_next_states.GraphTemplate') as mock_graph_template, \ + patch('app.tasks.create_next_states.Store') as mock_store, \ + patch('app.tasks.create_next_states.State') as mock_state_class, \ + patch('app.tasks.create_next_states.validate_dependencies') as mock_validate: + + # Setup GraphTemplate mock + mock_template = MagicMock() + mock_template.store_config = StoreConfig(default_values={}) + current_node = NodeTemplate( + node_name="test_node", + identifier="current_id", + namespace="test", + inputs={}, + next_nodes=["next_node"], + unites=None + ) + next_node = NodeTemplate( + node_name="next_node", + identifier="next_node", + namespace="test", + inputs={"input1": "${{store.field1}}", "input2": "${{store.field2}}"}, + next_nodes=None, + unites=None + ) + def get_node_side_effect(identifier): + if identifier == "current_id": + return current_node + elif identifier == "next_node": + return next_node + return None + mock_template.get_node_by_identifier.side_effect = get_node_side_effect + mock_graph_template.get_valid = AsyncMock(return_value=mock_template) + + # Mock validate_dependencies to pass + mock_validate.return_value = None + + # Setup Store mock to return different values for different fields + def mock_get_value(run_id, namespace, graph_name, field): + if field == "field1": + return "value1" + elif field == "field2": + return "value2" + return None + + mock_store.get_value = AsyncMock(side_effect=mock_get_value) + + # Setup State mock + mock_state_class.id = "id" + mock_current_state = MagicMock() + mock_current_state.run_id = "test_run" + mock_current_state.identifier = "current_id" + mock_current_state.outputs = {"field1": "output_value"} + mock_find = AsyncMock() + mock_find.to_list.return_value = [mock_current_state] + mock_find.set = AsyncMock() + mock_state_class.find.return_value = mock_find + mock_state_class.insert_many = AsyncMock() + + # Setup RegisteredNode mock + with patch('app.tasks.create_next_states.RegisteredNode') as mock_registered_node: + mock_registered_node_instance = MagicMock() + mock_registered_node_instance.inputs_schema = {"input1": {"type": "string"}, "input2": {"type": "string"}} + mock_registered_node.get_by_name_and_namespace = AsyncMock(return_value=mock_registered_node_instance) + + with patch('app.tasks.create_next_states.create_model') as mock_create_model: + mock_input_model = MagicMock() + mock_input_model.model_fields = { + "input1": MagicMock(annotation=str), + "input2": MagicMock(annotation=str) + } + mock_create_model.return_value = mock_input_model + + await create_next_states([PydanticObjectId()], "current_id", "test_namespace", "test_graph", {}) + + # Verify Store.get_value was called for both fields + assert mock_store.get_value.call_count == 2 + mock_store.get_value.assert_any_call("test_run", "test_namespace", "test_graph", "field1") + mock_store.get_value.assert_any_call("test_run", "test_namespace", "test_graph", "field2") + + @pytest.mark.asyncio + async def test_get_store_value_default_fallback(self): + """Test that default values are used when Store.get_value returns None""" + + with patch('app.tasks.create_next_states.GraphTemplate') as mock_graph_template, \ + patch('app.tasks.create_next_states.Store') as mock_store, \ + patch('app.tasks.create_next_states.State') as mock_state_class, \ + patch('app.tasks.create_next_states.validate_dependencies') as mock_validate: + + # Setup GraphTemplate mock with default values + mock_template = MagicMock() + mock_template.store_config = StoreConfig(default_values={"test_field": "default_value"}) + current_node = NodeTemplate( + node_name="test_node", + identifier="current_id", + namespace="test", + inputs={}, + next_nodes=["next_node"], + unites=None + ) + next_node = NodeTemplate( + node_name="next_node", + identifier="next_node", + namespace="test", + inputs={"input1": "${{store.test_field}}"}, + next_nodes=None, + unites=None + ) + def get_node_side_effect(identifier): + if identifier == "current_id": + return current_node + elif identifier == "next_node": + return next_node + return None + mock_template.get_node_by_identifier.side_effect = get_node_side_effect + mock_graph_template.get_valid = AsyncMock(return_value=mock_template) + + # Mock validate_dependencies to pass + mock_validate.return_value = None + + # Setup Store mock to return None + mock_store.get_value = AsyncMock(return_value=None) + + # Setup State mock + mock_state_class.id = "id" + mock_current_state = MagicMock() + mock_current_state.run_id = "test_run" + mock_current_state.identifier = "current_id" + mock_current_state.outputs = {"field1": "output_value"} + mock_find = AsyncMock() + mock_find.to_list.return_value = [mock_current_state] + mock_find.set = AsyncMock() + mock_state_class.find.return_value = mock_find + mock_state_class.insert_many = AsyncMock() + + # Setup RegisteredNode mock + with patch('app.tasks.create_next_states.RegisteredNode') as mock_registered_node: + mock_registered_node_instance = MagicMock() + mock_registered_node_instance.inputs_schema = {"input1": {"type": "string"}} + mock_registered_node.get_by_name_and_namespace = AsyncMock(return_value=mock_registered_node_instance) + + with patch('app.tasks.create_next_states.create_model') as mock_create_model: + mock_input_model = MagicMock() + mock_input_model.model_fields = {"input1": MagicMock(annotation=str)} + mock_create_model.return_value = mock_input_model + + # Should complete successfully using default value + await create_next_states([PydanticObjectId()], "current_id", "test_namespace", "test_graph", {}) + + # Verify Store.get_value was called + mock_store.get_value.assert_called_once_with("test_run", "test_namespace", "test_graph", "test_field") + + @pytest.mark.asyncio + async def test_get_store_value_cache_key_isolation(self): + """Test that cache keys properly isolate different run_id and field combinations""" + + # This test ensures that (run_id1, field1) is cached separately from (run_id2, field1) and (run_id1, field2) + with patch('app.tasks.create_next_states.GraphTemplate') as mock_graph_template, \ + patch('app.tasks.create_next_states.Store') as mock_store, \ + patch('app.tasks.create_next_states.State') as mock_state_class, \ + patch('app.tasks.create_next_states.validate_dependencies') as mock_validate: + + # Setup GraphTemplate mock + mock_template = MagicMock() + mock_template.store_config = StoreConfig(default_values={}) + current_node = NodeTemplate( + node_name="test_node", + identifier="current_id", + namespace="test", + inputs={}, + next_nodes=["next_node"], + unites=None + ) + next_node = NodeTemplate( + node_name="next_node", + identifier="next_node", + namespace="test", + inputs={"input1": "${{store.test_field}}"}, + next_nodes=None, + unites=None + ) + def get_node_side_effect(identifier): + if identifier == "current_id": + return current_node + elif identifier == "next_node": + return next_node + return None + mock_template.get_node_by_identifier.side_effect = get_node_side_effect + mock_graph_template.get_valid = AsyncMock(return_value=mock_template) + + # Mock validate_dependencies to pass + mock_validate.return_value = None + + # Setup Store mock to return different values based on run_id + def mock_get_value(run_id, namespace, graph_name, field): + return f"value_{run_id}_{field}" + + mock_store.get_value = AsyncMock(side_effect=mock_get_value) + + # Setup State mock for first run + mock_state_class.id = "id" + mock_current_state1 = MagicMock() + mock_current_state1.run_id = "run1" + mock_current_state1.identifier = "current_id" + mock_current_state1.outputs = {"field1": "output_value"} + + mock_current_state2 = MagicMock() + mock_current_state2.run_id = "run2" + mock_current_state2.identifier = "current_id" + mock_current_state2.outputs = {"field1": "output_value"} + + mock_find = AsyncMock() + mock_find.to_list.side_effect = [[mock_current_state1], [mock_current_state2]] + mock_find.set = AsyncMock() + mock_state_class.find.return_value = mock_find + mock_state_class.insert_many = AsyncMock() + + # Setup RegisteredNode mock + with patch('app.tasks.create_next_states.RegisteredNode') as mock_registered_node: + mock_registered_node_instance = MagicMock() + mock_registered_node_instance.inputs_schema = {"input1": {"type": "string"}} + mock_registered_node.get_by_name_and_namespace = AsyncMock(return_value=mock_registered_node_instance) + + with patch('app.tasks.create_next_states.create_model') as mock_create_model: + mock_input_model = MagicMock() + mock_input_model.model_fields = {"input1": MagicMock(annotation=str)} + mock_create_model.return_value = mock_input_model + + # First call with run1 + await create_next_states([PydanticObjectId()], "current_id", "test_namespace", "test_graph", {}) + + # Second call with run2 + await create_next_states([PydanticObjectId()], "current_id", "test_namespace", "test_graph", {}) + + # Verify Store.get_value was called twice with different run_ids + assert mock_store.get_value.call_count == 2 + mock_store.get_value.assert_any_call("run1", "test_namespace", "test_graph", "test_field") + mock_store.get_value.assert_any_call("run2", "test_namespace", "test_graph", "test_field") + + @pytest.mark.asyncio + async def test_get_store_value_exception_handling(self): + """Test that exceptions from Store.get_value are properly propagated""" + + with patch('app.tasks.create_next_states.GraphTemplate') as mock_graph_template, \ + patch('app.tasks.create_next_states.Store') as mock_store, \ + patch('app.tasks.create_next_states.State') as mock_state_class, \ + patch('app.tasks.create_next_states.validate_dependencies') as mock_validate: + + # Setup GraphTemplate mock + mock_template = MagicMock() + mock_template.store_config = StoreConfig(default_values={}) + current_node = NodeTemplate( + node_name="test_node", + identifier="current_id", + namespace="test", + inputs={}, + next_nodes=["next_node"], + unites=None + ) + next_node = NodeTemplate( + node_name="next_node", + identifier="next_node", + namespace="test", + inputs={"input1": "${{store.test_field}}"}, + next_nodes=None, + unites=None + ) + def get_node_side_effect(identifier): + if identifier == "current_id": + return current_node + elif identifier == "next_node": + return next_node + return None + mock_template.get_node_by_identifier.side_effect = get_node_side_effect + mock_graph_template.get_valid = AsyncMock(return_value=mock_template) + + # Mock validate_dependencies to pass + mock_validate.return_value = None + + # Setup Store mock to raise an exception + mock_store.get_value = AsyncMock(side_effect=Exception("Database connection error")) + + # Setup State mock + mock_state_class.id = "id" + mock_current_state = MagicMock() + mock_current_state.run_id = "test_run" + mock_current_state.identifier = "current_id" + mock_current_state.outputs = {"field1": "output_value"} + mock_find = AsyncMock() + mock_find.to_list.return_value = [mock_current_state] + mock_find.set = AsyncMock() + mock_state_class.find.return_value = mock_find + mock_state_class.insert_many = AsyncMock() + + # Setup RegisteredNode mock + with patch('app.tasks.create_next_states.RegisteredNode') as mock_registered_node: + mock_registered_node_instance = MagicMock() + mock_registered_node_instance.inputs_schema = {"input1": {"type": "string"}} + mock_registered_node.get_by_name_and_namespace = AsyncMock(return_value=mock_registered_node_instance) + + with patch('app.tasks.create_next_states.create_model') as mock_create_model: + mock_input_model = MagicMock() + mock_input_model.model_fields = {"input1": MagicMock(annotation=str)} + mock_create_model.return_value = mock_input_model + + with pytest.raises(Exception, match="Database connection error"): + await create_next_states([PydanticObjectId()], "current_id", "test_namespace", "test_graph", {}) diff --git a/state-manager/tests/unit/test_main.py b/state-manager/tests/unit/test_main.py index e7042b38..8d03fe8b 100644 --- a/state-manager/tests/unit/test_main.py +++ b/state-manager/tests/unit/test_main.py @@ -207,8 +207,9 @@ async def test_lifespan_init_beanie_with_correct_models(self, mock_logs_manager, from app.models.db.state import State from app.models.db.graph_template_model import GraphTemplate from app.models.db.registered_node import RegisteredNode - - expected_models = [State, GraphTemplate, RegisteredNode] + from app.models.db.store import Store + + expected_models = [State, GraphTemplate, RegisteredNode, Store] assert document_models == expected_models diff --git a/state-manager/tests/unit/test_routes.py b/state-manager/tests/unit/test_routes.py index 82cce20b..85e6157f 100644 --- a/state-manager/tests/unit/test_routes.py +++ b/state-manager/tests/unit/test_routes.py @@ -1,6 +1,6 @@ from app.routes import router from app.models.enqueue_request import EnqueueRequestModel -from app.models.create_models import TriggerGraphRequestModel, CreateRequestModel +from app.models.trigger_model import TriggerGraphRequestModel from app.models.executed_models import ExecutedRequestModel from app.models.errored_models import ErroredRequestModel from app.models.graph_models import UpsertGraphTemplateRequest, UpsertGraphTemplateResponse @@ -26,7 +26,7 @@ def test_router_has_correct_routes(self): # State management routes assert any('/v0/namespace/{namespace_name}/states/enqueue' in path for path in paths) assert any('/v0/namespace/{namespace_name}/graph/{graph_name}/trigger' in path for path in paths) - assert any('/v0/namespace/{namespace_name}/graph/{graph_name}/states/create' in path for path in paths) + # Removed deprecated create states route assertion assert any('/v0/namespace/{namespace_name}/states/{state_id}/executed' in path for path in paths) assert any('/v0/namespace/{namespace_name}/states/{state_id}/errored' in path for path in paths) assert any('/v0/namespace/{namespace_name}/states/{state_id}/prune' in path for path in paths) @@ -80,36 +80,13 @@ def test_enqueue_request_model_validation(self): def test_trigger_graph_request_model_validation(self): """Test TriggerGraphRequestModel validation""" - # Test with valid data - valid_data = { - "states": [ - { - "identifier": "node1", - "inputs": {"input1": "value1"} - } - ] - } - model = TriggerGraphRequestModel(**valid_data) # type: ignore - assert len(model.states) == 1 - assert model.states[0].identifier == "node1" - assert model.states[0].inputs == {"input1": "value1"} - - def test_create_request_model_validation(self): - """Test CreateRequestModel validation""" - # Test with valid data valid_data = { - "run_id": "test-run-id", - "states": [ - { - "identifier": "node1", - "inputs": {"input1": "value1"} - } - ] + "store": {"s1": "v1"}, + "inputs": {"input1": "value1"} } - model = CreateRequestModel(**valid_data) - assert model.run_id == "test-run-id" - assert len(model.states) == 1 - assert model.states[0].identifier == "node1" + model = TriggerGraphRequestModel(**valid_data) + assert model.store == {"s1": "v1"} + assert model.inputs == {"input1": "value1"} def test_prune_request_model_validation(self): """Test PruneRequestModel validation""" @@ -331,7 +308,6 @@ def test_route_handlers_exist(self): from app.routes import ( enqueue_state, trigger_graph_route, - create_state, executed_state_route, errored_state_route, upsert_graph_template, @@ -347,7 +323,6 @@ def test_route_handlers_exist(self): # Verify all handlers are callable assert callable(enqueue_state) assert callable(trigger_graph_route) - assert callable(create_state) assert callable(executed_state_route) assert callable(errored_state_route) assert callable(upsert_graph_template) @@ -440,11 +415,10 @@ async def test_enqueue_state_without_request_id(self, mock_enqueue_states, mock_ async def test_trigger_graph_route_with_valid_api_key(self, mock_trigger_graph, mock_request): """Test trigger_graph_route with valid API key""" from app.routes import trigger_graph_route - from app.models.create_models import TriggerGraphRequestModel # Arrange mock_trigger_graph.return_value = MagicMock() - body = TriggerGraphRequestModel(states=[]) + body = TriggerGraphRequestModel() # Act result = await trigger_graph_route("test_namespace", "test_graph", body, mock_request, "valid_key") @@ -457,11 +431,10 @@ async def test_trigger_graph_route_with_valid_api_key(self, mock_trigger_graph, async def test_trigger_graph_route_with_invalid_api_key(self, mock_trigger_graph, mock_request): """Test trigger_graph_route with invalid API key""" from app.routes import trigger_graph_route - from app.models.create_models import TriggerGraphRequestModel from fastapi import HTTPException # Arrange - body = TriggerGraphRequestModel(states=[]) + body = TriggerGraphRequestModel() # Act & Assert with pytest.raises(HTTPException) as exc_info: @@ -470,22 +443,11 @@ async def test_trigger_graph_route_with_invalid_api_key(self, mock_trigger_graph assert exc_info.value.status_code == 401 assert exc_info.value.detail == "Invalid API key" - @patch('app.routes.create_states') - async def test_create_state_with_valid_api_key(self, mock_create_states, mock_request): - """Test create_state with valid API key""" - from app.routes import create_state - from app.models.create_models import CreateRequestModel - - # Arrange - mock_create_states.return_value = MagicMock() - body = CreateRequestModel(run_id="test_run", states=[]) - - # Act - result = await create_state("test_namespace", "test_graph", body, mock_request, "valid_key") - - # Assert - mock_create_states.assert_called_once_with("test_namespace", "test_graph", body, "test-request-id") - assert result == mock_create_states.return_value + def test_no_create_state_route(self): + from app.routes import router + routes = [route for route in router.routes if hasattr(route, 'path')] + paths = [route.path for route in routes] # type: ignore + assert not any('/v0/namespace/{namespace_name}/graph/{graph_name}/states/create' in path for path in paths) @patch('app.routes.executed_state') async def test_executed_state_route_with_valid_api_key(self, mock_executed_state, mock_request, mock_background_tasks):