diff --git a/python-sdk/exospherehost/_version.py b/python-sdk/exospherehost/_version.py index 1ec1aab7..1295569a 100644 --- a/python-sdk/exospherehost/_version.py +++ b/python-sdk/exospherehost/_version.py @@ -1 +1 @@ -version = "0.0.7b9" \ No newline at end of file +version = "0.0.7b10" \ No newline at end of file diff --git a/python-sdk/exospherehost/runtime.py b/python-sdk/exospherehost/runtime.py index 3ceadcbe..89786c58 100644 --- a/python-sdk/exospherehost/runtime.py +++ b/python-sdk/exospherehost/runtime.py @@ -293,7 +293,11 @@ async def _get_secrets(self, state_id: str) -> Dict[str, str]: logger.error(f"Failed to get secrets for state {state_id}: {res}") return {} - return res + if "secrets" in res: + return res["secrets"] + else: + logger.error(f"'secrets' not found in response for state {state_id}") + return {} def _validate_nodes(self): """ @@ -352,6 +356,12 @@ def _validate_nodes(self): if len(errors) > 0: raise ValueError("Following errors while validating nodes: " + "\n".join(errors)) + def _need_secrets(self, node: type[BaseNode]) -> bool: + """ + Check if the node needs secrets. + """ + return len(node.Secrets.model_fields.keys()) > 0 + async def _worker(self, idx: int): """ Worker task that processes states from the queue. @@ -369,10 +379,12 @@ async def _worker(self, idx: int): node = self._node_mapping[state["node_name"]] logger.info(f"Executing state {state['state_id']} for node {node.__name__}") - secrets = await self._get_secrets(state["state_id"]) - logger.info(f"Got secrets for state {state['state_id']} for node {node.__name__}") + secrets = {} + if self._need_secrets(node): + secrets = await self._get_secrets(state["state_id"]) + logger.info(f"Got secrets for state {state['state_id']} for node {node.__name__}") - outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"])) # type: ignore + outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets)) logger.info(f"Got outputs for state {state['state_id']} for node {node.__name__}") if outputs is None: diff --git a/python-sdk/tests/test_base_node_abstract.py b/python-sdk/tests/test_base_node_abstract.py new file mode 100644 index 00000000..8476411a --- /dev/null +++ b/python-sdk/tests/test_base_node_abstract.py @@ -0,0 +1,87 @@ +import pytest +from pydantic import BaseModel +from exospherehost.node.BaseNode import BaseNode + + +class TestBaseNodeAbstract: + """Test the abstract BaseNode class and its NotImplementedError.""" + + def test_base_node_abstract_execute(self): + """Test that BaseNode.execute raises NotImplementedError.""" + # Create a concrete subclass that implements execute but raises NotImplementedError + class ConcreteNode(BaseNode): + class Inputs(BaseModel): + name: str + + class Outputs(BaseModel): + message: str + + class Secrets(BaseModel): + pass + + async def execute(self): + raise NotImplementedError("execute method must be implemented by all concrete node classes") + + node = ConcreteNode() + + with pytest.raises(NotImplementedError, match="execute method must be implemented by all concrete node classes"): + # This should raise NotImplementedError + import asyncio + asyncio.run(node.execute()) + + def test_base_node_abstract_execute_with_inputs(self): + """Test that BaseNode._execute raises NotImplementedError when execute is not implemented.""" + # Create a concrete subclass that implements execute but raises NotImplementedError + class ConcreteNode(BaseNode): + class Inputs(BaseModel): + name: str + + class Outputs(BaseModel): + message: str + + class Secrets(BaseModel): + pass + + async def execute(self): + raise NotImplementedError("execute method must be implemented by all concrete node classes") + + node = ConcreteNode() + + with pytest.raises(NotImplementedError, match="execute method must be implemented by all concrete node classes"): + # This should raise NotImplementedError + import asyncio + asyncio.run(node._execute(node.Inputs(name="test"), node.Secrets())) # type: ignore + + def test_base_node_initialization(self): + """Test that BaseNode initializes correctly.""" + # Create a concrete subclass + class ConcreteNode(BaseNode): + class Inputs(BaseModel): + name: str + + class Outputs(BaseModel): + message: str + + class Secrets(BaseModel): + pass + + async def execute(self): + return self.Outputs(message="test") + + node = ConcreteNode() + assert node.inputs is None + + def test_base_node_inputs_class(self): + """Test that BaseNode has Inputs class.""" + assert hasattr(BaseNode, 'Inputs') + assert issubclass(BaseNode.Inputs, BaseModel) + + def test_base_node_outputs_class(self): + """Test that BaseNode has Outputs class.""" + assert hasattr(BaseNode, 'Outputs') + assert issubclass(BaseNode.Outputs, BaseModel) + + def test_base_node_secrets_class(self): + """Test that BaseNode has Secrets class.""" + assert hasattr(BaseNode, 'Secrets') + assert issubclass(BaseNode.Secrets, BaseModel) \ No newline at end of file diff --git a/python-sdk/tests/test_runtime_comprehensive.py b/python-sdk/tests/test_runtime_comprehensive.py index a5f4b346..8013629f 100644 --- a/python-sdk/tests/test_runtime_comprehensive.py +++ b/python-sdk/tests/test_runtime_comprehensive.py @@ -282,7 +282,7 @@ async def test_worker_successful_execution(self, runtime_config): with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \ patch('exospherehost.runtime.Runtime._notify_executed') as mock_notify_executed: - mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}} + mock_get_secrets.return_value = {"api_key": "test_key"} mock_notify_executed.return_value = None runtime = Runtime(**runtime_config) @@ -327,7 +327,7 @@ async def test_worker_with_list_output(self, runtime_config): with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \ patch('exospherehost.runtime.Runtime._notify_executed') as mock_notify_executed: - mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}} + mock_get_secrets.return_value = {"api_key": "test_key"} mock_notify_executed.return_value = None runtime = Runtime(**runtime_config) @@ -362,7 +362,7 @@ async def test_worker_with_none_output(self, runtime_config): with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \ patch('exospherehost.runtime.Runtime._notify_executed') as mock_notify_executed: - mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}} + mock_get_secrets.return_value = {"api_key": "test_key"} mock_notify_executed.return_value = None runtime = Runtime(**runtime_config) @@ -394,7 +394,7 @@ async def test_worker_execution_error(self, runtime_config): with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \ patch('exospherehost.runtime.Runtime._notify_errored') as mock_notify_errored: - mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}} + mock_get_secrets.return_value = {"api_key": "test_key"} mock_notify_errored.return_value = None runtime = Runtime(**runtime_config) @@ -511,7 +511,7 @@ async def test_get_secrets_success(self, runtime_config): runtime = Runtime(**runtime_config) result = await runtime._get_secrets("test_state_1") - assert result == {"secrets": {"api_key": "secret_key"}} + assert result == {"api_key": "secret_key"} @pytest.mark.asyncio async def test_get_secrets_failure(self, runtime_config): diff --git a/python-sdk/tests/test_runtime_edge_cases.py b/python-sdk/tests/test_runtime_edge_cases.py new file mode 100644 index 00000000..62f82ff7 --- /dev/null +++ b/python-sdk/tests/test_runtime_edge_cases.py @@ -0,0 +1,205 @@ +import pytest +import asyncio +import warnings +from unittest.mock import AsyncMock, patch, MagicMock +from pydantic import BaseModel +from exospherehost.runtime import Runtime, _setup_default_logging +from exospherehost.node.BaseNode import BaseNode + + +class MockTestNode(BaseNode): + class Inputs(BaseModel): + name: str + + class Outputs(BaseModel): + message: str + + class Secrets(BaseModel): + api_key: str + + async def execute(self): + return self.Outputs(message=f"Hello {self.inputs.name}") # type: ignore + + +class MockTestNodeWithNonStringFields(BaseNode): + class Inputs(BaseModel): + name: str + count: int # This should cause validation error + + class Outputs(BaseModel): + message: str + + class Secrets(BaseModel): + api_key: str + + async def execute(self): + return self.Outputs(message=f"Hello {self.inputs.name}") # type: ignore + + +class MockTestNodeWithoutSecrets(BaseNode): + class Inputs(BaseModel): + name: str + + class Outputs(BaseModel): + message: str + + class Secrets(BaseModel): + pass # Empty secrets + + async def execute(self): + return self.Outputs(message=f"Hello {self.inputs.name}") # type: ignore + + +class MockTestNodeWithError(BaseNode): + class Inputs(BaseModel): + should_fail: str + + class Outputs(BaseModel): + result: str + + class Secrets(BaseModel): + api_key: str + + async def execute(self): + if self.inputs.should_fail == "true": # type: ignore + raise ValueError("Test error") + return self.Outputs(result="success") + + +class TestRuntimeEdgeCases: + """Test edge cases and error handling scenarios in the Runtime class.""" + + def test_setup_default_logging_disabled(self, monkeypatch): + """Test that _setup_default_logging returns early when disabled.""" + monkeypatch.setenv('EXOSPHERE_DISABLE_DEFAULT_LOGGING', 'true') + + # This should not raise any exceptions and should return early + _setup_default_logging() + + def test_setup_default_logging_invalid_level(self, monkeypatch): + """Test _setup_default_logging with invalid log level.""" + monkeypatch.setenv('EXOSPHERE_LOG_LEVEL', 'INVALID_LEVEL') + + # Should fall back to INFO level + _setup_default_logging() + + def test_runtime_validation_non_string_fields(self): + """Test that Runtime validates node fields are strings.""" + with pytest.raises(ValueError, match="must be of type str"): + Runtime( + namespace="test", + name="test", + nodes=[MockTestNodeWithNonStringFields], + state_manager_uri="http://localhost:8080", + key="test_key" + ) + + def test_runtime_validation_duplicate_node_names(self): + """Test that Runtime validates no duplicate node names.""" + # Create two classes with the same name + class TestNode1(MockTestNode): + pass + + class TestNode2(MockTestNode): + pass + + # Rename the second class to have the same name as the first + TestNode2.__name__ = "TestNode1" + + # Suppress the RuntimeWarning about unawaited coroutines + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=".*coroutine.*was never awaited.*", category=RuntimeWarning) + with pytest.raises(ValueError, match="Duplicate node class names found"): + Runtime( + namespace="test", + name="test", + nodes=[TestNode1, TestNode2], + state_manager_uri="http://localhost:8080", + key="test_key" + ) + + def test_need_secrets_empty_secrets(self): + """Test _need_secrets with empty secrets class.""" + runtime = Runtime( + namespace="test", + name="test", + nodes=[MockTestNodeWithoutSecrets], + state_manager_uri="http://localhost:8080", + key="test_key" + ) + + # Should return False for empty secrets + assert not runtime._need_secrets(MockTestNodeWithoutSecrets) + + def test_need_secrets_with_secrets(self): + """Test _need_secrets with secrets class that has fields.""" + runtime = Runtime( + namespace="test", + name="test", + nodes=[MockTestNode], + state_manager_uri="http://localhost:8080", + key="test_key" + ) + + # Should return True for secrets with fields + assert runtime._need_secrets(MockTestNode) + + @pytest.mark.asyncio + async def test_enqueue_error_handling(self): + """Test error handling in _enqueue method.""" + runtime = Runtime( + namespace="test", + name="test", + nodes=[MockTestNode], + state_manager_uri="http://localhost:8080", + key="test_key" + ) + + # Mock _enqueue_call to raise an exception + with patch.object(runtime, '_enqueue_call', side_effect=Exception("Test error")): + # This should not raise an exception but log an error + # We'll test this by checking that the method doesn't crash + task = asyncio.create_task(runtime._enqueue()) + await asyncio.sleep(0.1) # Let it run briefly + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + def test_start_without_running_loop(self): + """Test start method when no event loop is running.""" + runtime = Runtime( + namespace="test", + name="test", + nodes=[MockTestNode], + state_manager_uri="http://localhost:8080", + key="test_key" + ) + + # Mock _start to avoid actual execution + with patch.object(runtime, '_start', new_callable=AsyncMock): + # This should not raise an exception + result = runtime.start() + assert result is None + + def test_start_with_running_loop(self): + """Test start method when an event loop is already running.""" + runtime = Runtime( + namespace="test", + name="test", + nodes=[MockTestNode], + state_manager_uri="http://localhost:8080", + key="test_key" + ) + + # Mock _start to avoid actual execution + with patch.object(runtime, '_start', new_callable=AsyncMock): + # Create a mock loop + mock_loop = MagicMock() + mock_task = MagicMock() + mock_loop.create_task.return_value = mock_task + + with patch('asyncio.get_running_loop', return_value=mock_loop): + result = runtime.start() + assert result == mock_task \ No newline at end of file diff --git a/python-sdk/tests/test_runtime_validation.py b/python-sdk/tests/test_runtime_validation.py index 95f30e65..d2f04a15 100644 --- a/python-sdk/tests/test_runtime_validation.py +++ b/python-sdk/tests/test_runtime_validation.py @@ -1,4 +1,5 @@ import pytest +import warnings from pydantic import BaseModel from exospherehost.runtime import Runtime from exospherehost.node.BaseNode import BaseNode @@ -81,7 +82,9 @@ def test_node_validation_errors(monkeypatch): def test_duplicate_node_names_raise(monkeypatch): monkeypatch.setenv("EXOSPHERE_STATE_MANAGER_URI", "http://sm") monkeypatch.setenv("EXOSPHERE_API_KEY", "k") - class AnotherGood(BaseNode): + + # Create two classes with the same name using a different approach + class GoodNode1(BaseNode): class Inputs(BaseModel): name: str class Outputs(BaseModel): @@ -90,6 +93,22 @@ class Secrets(BaseModel): api_key: str async def execute(self): return self.Outputs(message="ok") - AnotherGood.__name__ = "GoodNode" # force duplicate name - with pytest.raises(ValueError): - Runtime(namespace="ns", name="rt", nodes=[GoodNode, AnotherGood]) \ No newline at end of file + + class GoodNode2(BaseNode): + class Inputs(BaseModel): + name: str + class Outputs(BaseModel): + message: str + class Secrets(BaseModel): + api_key: str + async def execute(self): + return self.Outputs(message="ok") + + # Use the same name for both classes + GoodNode2.__name__ = "GoodNode1" + + # Suppress the RuntimeWarning about unawaited coroutines + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message=".*coroutine.*was never awaited.*", category=RuntimeWarning) + with pytest.raises(ValueError): + Runtime(namespace="ns", name="rt", nodes=[GoodNode1, GoodNode2]) \ No newline at end of file