Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion python-sdk/exospherehost/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "0.0.7b9"
version = "0.0.7b10"
20 changes: 16 additions & 4 deletions python-sdk/exospherehost/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,11 @@ async def _get_secrets(self, state_id: str) -> Dict[str, str]:
logger.error(f"Failed to get secrets for state {state_id}: {res}")
return {}

return res
if "secrets" in res:
return res["secrets"]
else:
logger.error(f"'secrets' not found in response for state {state_id}")
return {}

def _validate_nodes(self):
"""
Expand Down Expand Up @@ -352,6 +356,12 @@ def _validate_nodes(self):
if len(errors) > 0:
raise ValueError("Following errors while validating nodes: " + "\n".join(errors))

def _need_secrets(self, node: type[BaseNode]) -> bool:
"""
Check if the node needs secrets.
"""
return len(node.Secrets.model_fields.keys()) > 0

async def _worker(self, idx: int):
"""
Worker task that processes states from the queue.
Expand All @@ -369,10 +379,12 @@ async def _worker(self, idx: int):
node = self._node_mapping[state["node_name"]]
logger.info(f"Executing state {state['state_id']} for node {node.__name__}")

secrets = await self._get_secrets(state["state_id"])
logger.info(f"Got secrets for state {state['state_id']} for node {node.__name__}")
secrets = {}
if self._need_secrets(node):
secrets = await self._get_secrets(state["state_id"])
logger.info(f"Got secrets for state {state['state_id']} for node {node.__name__}")

outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets["secrets"])) # type: ignore
outputs = await node()._execute(node.Inputs(**state["inputs"]), node.Secrets(**secrets))
logger.info(f"Got outputs for state {state['state_id']} for node {node.__name__}")

if outputs is None:
Expand Down
87 changes: 87 additions & 0 deletions python-sdk/tests/test_base_node_abstract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import pytest
from pydantic import BaseModel
from exospherehost.node.BaseNode import BaseNode


class TestBaseNodeAbstract:
"""Test the abstract BaseNode class and its NotImplementedError."""

def test_base_node_abstract_execute(self):
"""Test that BaseNode.execute raises NotImplementedError."""
# Create a concrete subclass that implements execute but raises NotImplementedError
class ConcreteNode(BaseNode):
class Inputs(BaseModel):
name: str

class Outputs(BaseModel):
message: str

class Secrets(BaseModel):
pass

async def execute(self):
raise NotImplementedError("execute method must be implemented by all concrete node classes")

node = ConcreteNode()

with pytest.raises(NotImplementedError, match="execute method must be implemented by all concrete node classes"):
# This should raise NotImplementedError
import asyncio
asyncio.run(node.execute())

def test_base_node_abstract_execute_with_inputs(self):
"""Test that BaseNode._execute raises NotImplementedError when execute is not implemented."""
# Create a concrete subclass that implements execute but raises NotImplementedError
class ConcreteNode(BaseNode):
class Inputs(BaseModel):
name: str

class Outputs(BaseModel):
message: str

class Secrets(BaseModel):
pass

async def execute(self):
raise NotImplementedError("execute method must be implemented by all concrete node classes")

node = ConcreteNode()

with pytest.raises(NotImplementedError, match="execute method must be implemented by all concrete node classes"):
# This should raise NotImplementedError
import asyncio
asyncio.run(node._execute(node.Inputs(name="test"), node.Secrets())) # type: ignore

def test_base_node_initialization(self):
"""Test that BaseNode initializes correctly."""
# Create a concrete subclass
class ConcreteNode(BaseNode):
class Inputs(BaseModel):
name: str

class Outputs(BaseModel):
message: str

class Secrets(BaseModel):
pass

async def execute(self):
return self.Outputs(message="test")

node = ConcreteNode()
assert node.inputs is None

def test_base_node_inputs_class(self):
"""Test that BaseNode has Inputs class."""
assert hasattr(BaseNode, 'Inputs')
assert issubclass(BaseNode.Inputs, BaseModel)

def test_base_node_outputs_class(self):
"""Test that BaseNode has Outputs class."""
assert hasattr(BaseNode, 'Outputs')
assert issubclass(BaseNode.Outputs, BaseModel)

def test_base_node_secrets_class(self):
"""Test that BaseNode has Secrets class."""
assert hasattr(BaseNode, 'Secrets')
assert issubclass(BaseNode.Secrets, BaseModel)
10 changes: 5 additions & 5 deletions python-sdk/tests/test_runtime_comprehensive.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def test_worker_successful_execution(self, runtime_config):
with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \
patch('exospherehost.runtime.Runtime._notify_executed') as mock_notify_executed:

mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}}
mock_get_secrets.return_value = {"api_key": "test_key"}
mock_notify_executed.return_value = None

runtime = Runtime(**runtime_config)
Expand Down Expand Up @@ -327,7 +327,7 @@ async def test_worker_with_list_output(self, runtime_config):
with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \
patch('exospherehost.runtime.Runtime._notify_executed') as mock_notify_executed:

mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}}
mock_get_secrets.return_value = {"api_key": "test_key"}
mock_notify_executed.return_value = None

runtime = Runtime(**runtime_config)
Expand Down Expand Up @@ -362,7 +362,7 @@ async def test_worker_with_none_output(self, runtime_config):
with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \
patch('exospherehost.runtime.Runtime._notify_executed') as mock_notify_executed:

mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}}
mock_get_secrets.return_value = {"api_key": "test_key"}
mock_notify_executed.return_value = None

runtime = Runtime(**runtime_config)
Expand Down Expand Up @@ -394,7 +394,7 @@ async def test_worker_execution_error(self, runtime_config):
with patch('exospherehost.runtime.Runtime._get_secrets') as mock_get_secrets, \
patch('exospherehost.runtime.Runtime._notify_errored') as mock_notify_errored:

mock_get_secrets.return_value = {"secrets": {"api_key": "test_key"}}
mock_get_secrets.return_value = {"api_key": "test_key"}
mock_notify_errored.return_value = None

runtime = Runtime(**runtime_config)
Expand Down Expand Up @@ -511,7 +511,7 @@ async def test_get_secrets_success(self, runtime_config):
runtime = Runtime(**runtime_config)
result = await runtime._get_secrets("test_state_1")

assert result == {"secrets": {"api_key": "secret_key"}}
assert result == {"api_key": "secret_key"}

@pytest.mark.asyncio
async def test_get_secrets_failure(self, runtime_config):
Expand Down
205 changes: 205 additions & 0 deletions python-sdk/tests/test_runtime_edge_cases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
import pytest
import asyncio
import warnings
from unittest.mock import AsyncMock, patch, MagicMock
from pydantic import BaseModel
from exospherehost.runtime import Runtime, _setup_default_logging
from exospherehost.node.BaseNode import BaseNode


class MockTestNode(BaseNode):
class Inputs(BaseModel):
name: str

class Outputs(BaseModel):
message: str

class Secrets(BaseModel):
api_key: str

async def execute(self):
return self.Outputs(message=f"Hello {self.inputs.name}") # type: ignore


class MockTestNodeWithNonStringFields(BaseNode):
class Inputs(BaseModel):
name: str
count: int # This should cause validation error

class Outputs(BaseModel):
message: str

class Secrets(BaseModel):
api_key: str

async def execute(self):
return self.Outputs(message=f"Hello {self.inputs.name}") # type: ignore


class MockTestNodeWithoutSecrets(BaseNode):
class Inputs(BaseModel):
name: str

class Outputs(BaseModel):
message: str

class Secrets(BaseModel):
pass # Empty secrets

async def execute(self):
return self.Outputs(message=f"Hello {self.inputs.name}") # type: ignore


class MockTestNodeWithError(BaseNode):
class Inputs(BaseModel):
should_fail: str

class Outputs(BaseModel):
result: str

class Secrets(BaseModel):
api_key: str

async def execute(self):
if self.inputs.should_fail == "true": # type: ignore
raise ValueError("Test error")
return self.Outputs(result="success")


class TestRuntimeEdgeCases:
"""Test edge cases and error handling scenarios in the Runtime class."""

def test_setup_default_logging_disabled(self, monkeypatch):
"""Test that _setup_default_logging returns early when disabled."""
monkeypatch.setenv('EXOSPHERE_DISABLE_DEFAULT_LOGGING', 'true')

# This should not raise any exceptions and should return early
_setup_default_logging()

def test_setup_default_logging_invalid_level(self, monkeypatch):
"""Test _setup_default_logging with invalid log level."""
monkeypatch.setenv('EXOSPHERE_LOG_LEVEL', 'INVALID_LEVEL')

# Should fall back to INFO level
_setup_default_logging()

def test_runtime_validation_non_string_fields(self):
"""Test that Runtime validates node fields are strings."""
with pytest.raises(ValueError, match="must be of type str"):
Runtime(
namespace="test",
name="test",
nodes=[MockTestNodeWithNonStringFields],
state_manager_uri="http://localhost:8080",
key="test_key"
)

def test_runtime_validation_duplicate_node_names(self):
"""Test that Runtime validates no duplicate node names."""
# Create two classes with the same name
class TestNode1(MockTestNode):
pass

class TestNode2(MockTestNode):
pass

# Rename the second class to have the same name as the first
TestNode2.__name__ = "TestNode1"

# Suppress the RuntimeWarning about unawaited coroutines
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message=".*coroutine.*was never awaited.*", category=RuntimeWarning)
with pytest.raises(ValueError, match="Duplicate node class names found"):
Runtime(
namespace="test",
name="test",
nodes=[TestNode1, TestNode2],
state_manager_uri="http://localhost:8080",
key="test_key"
)

def test_need_secrets_empty_secrets(self):
"""Test _need_secrets with empty secrets class."""
runtime = Runtime(
namespace="test",
name="test",
nodes=[MockTestNodeWithoutSecrets],
state_manager_uri="http://localhost:8080",
key="test_key"
)

# Should return False for empty secrets
assert not runtime._need_secrets(MockTestNodeWithoutSecrets)

def test_need_secrets_with_secrets(self):
"""Test _need_secrets with secrets class that has fields."""
runtime = Runtime(
namespace="test",
name="test",
nodes=[MockTestNode],
state_manager_uri="http://localhost:8080",
key="test_key"
)

# Should return True for secrets with fields
assert runtime._need_secrets(MockTestNode)

@pytest.mark.asyncio
async def test_enqueue_error_handling(self):
"""Test error handling in _enqueue method."""
runtime = Runtime(
namespace="test",
name="test",
nodes=[MockTestNode],
state_manager_uri="http://localhost:8080",
key="test_key"
)

# Mock _enqueue_call to raise an exception
with patch.object(runtime, '_enqueue_call', side_effect=Exception("Test error")):
# This should not raise an exception but log an error
# We'll test this by checking that the method doesn't crash
task = asyncio.create_task(runtime._enqueue())
await asyncio.sleep(0.1) # Let it run briefly
task.cancel()
try:
await task
except asyncio.CancelledError:
pass

def test_start_without_running_loop(self):
"""Test start method when no event loop is running."""
runtime = Runtime(
namespace="test",
name="test",
nodes=[MockTestNode],
state_manager_uri="http://localhost:8080",
key="test_key"
)

# Mock _start to avoid actual execution
with patch.object(runtime, '_start', new_callable=AsyncMock):
# This should not raise an exception
result = runtime.start()
assert result is None

def test_start_with_running_loop(self):
"""Test start method when an event loop is already running."""
runtime = Runtime(
namespace="test",
name="test",
nodes=[MockTestNode],
state_manager_uri="http://localhost:8080",
key="test_key"
)

# Mock _start to avoid actual execution
with patch.object(runtime, '_start', new_callable=AsyncMock):
# Create a mock loop
mock_loop = MagicMock()
mock_task = MagicMock()
mock_loop.create_task.return_value = mock_task

with patch('asyncio.get_running_loop', return_value=mock_loop):
result = runtime.start()
assert result == mock_task
Loading
Loading