Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions python-sdk/exospherehost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,10 @@ async def execute(self, inputs: Inputs) -> Outputs:
from ._version import version as __version__
from .runtime import Runtime
from .node.BaseNode import BaseNode
from .statemanager import StateManager, TriggerState
from .statemanager import StateManager
from .signals import PruneSignal, ReQueueAfterSignal
from .models import UnitesStrategyEnum, UnitesModel, GraphNodeModel, RetryStrategyEnum, RetryPolicyModel, StoreConfigModel

VERSION = __version__

__all__ = ["Runtime", "BaseNode", "StateManager", "TriggerState", "VERSION", "PruneSignal", "ReQueueAfterSignal"]
__all__ = ["Runtime", "BaseNode", "StateManager", "VERSION", "PruneSignal", "ReQueueAfterSignal", "UnitesStrategyEnum", "UnitesModel", "GraphNodeModel", "RetryStrategyEnum", "RetryPolicyModel", "StoreConfigModel"]
2 changes: 1 addition & 1 deletion python-sdk/exospherehost/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "0.0.2b4"
version = "0.0.2b5"
160 changes: 160 additions & 0 deletions python-sdk/exospherehost/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from pydantic import BaseModel, Field, field_validator
from typing import Any, Optional, List
from enum import Enum


class UnitesStrategyEnum(str, Enum):
ALL_SUCCESS = "ALL_SUCCESS"
ALL_DONE = "ALL_DONE"


class UnitesModel(BaseModel):
identifier: str = Field(..., description="Identifier of the node")
strategy: UnitesStrategyEnum = Field(default=UnitesStrategyEnum.ALL_SUCCESS, description="Strategy of the unites")


class GraphNodeModel(BaseModel):
node_name: str = Field(..., description="Name of the node")
namespace: str = Field(..., description="Namespace of the node")
identifier: str = Field(..., description="Identifier of the node")
inputs: dict[str, Any] = Field(..., description="Inputs of the node")
next_nodes: Optional[List[str]] = Field(None, description="Next nodes to execute")
unites: Optional[UnitesModel] = Field(None, description="Unites of the node")

@field_validator('node_name')
@classmethod
def validate_node_name(cls, v: str) -> str:
trimmed_v = v.strip()
if trimmed_v == "" or trimmed_v is None:
raise ValueError("Node name cannot be empty")
return trimmed_v

@field_validator('identifier')
@classmethod
def validate_identifier(cls, v: str) -> str:
trimmed_v = v.strip()
if trimmed_v == "" or trimmed_v is None:
raise ValueError("Node identifier cannot be empty")
elif trimmed_v == "store":
raise ValueError("Node identifier cannot be reserved word 'store'")
return trimmed_v

@field_validator('next_nodes')
@classmethod
def validate_next_nodes(cls, v: Optional[List[str]]) -> Optional[List[str]]:
identifiers = set()
errors = []
trimmed_v = []

if v is not None:
for next_node_identifier in v:
trimmed_next_node_identifier = next_node_identifier.strip()

if trimmed_next_node_identifier == "" or trimmed_next_node_identifier is None:
errors.append("Next node identifier cannot be empty")
continue

if trimmed_next_node_identifier in identifiers:
errors.append(f"Next node identifier {trimmed_next_node_identifier} is not unique")
continue

identifiers.add(trimmed_next_node_identifier)
trimmed_v.append(trimmed_next_node_identifier)
if errors:
raise ValueError("\n".join(errors))
return trimmed_v

@field_validator('unites')
@classmethod
def validate_unites(cls, v: Optional[UnitesModel]) -> Optional[UnitesModel]:
trimmed_v = v
if v is not None:
trimmed_v = UnitesModel(identifier=v.identifier.strip(), strategy=v.strategy)
if trimmed_v.identifier == "" or trimmed_v.identifier is None:
raise ValueError("Unites identifier cannot be empty")
return trimmed_v


class RetryStrategyEnum(str, Enum):
EXPONENTIAL = "EXPONENTIAL"
EXPONENTIAL_FULL_JITTER = "EXPONENTIAL_FULL_JITTER"
EXPONENTIAL_EQUAL_JITTER = "EXPONENTIAL_EQUAL_JITTER"

LINEAR = "LINEAR"
LINEAR_FULL_JITTER = "LINEAR_FULL_JITTER"
LINEAR_EQUAL_JITTER = "LINEAR_EQUAL_JITTER"

FIXED = "FIXED"
FIXED_FULL_JITTER = "FIXED_FULL_JITTER"
FIXED_EQUAL_JITTER = "FIXED_EQUAL_JITTER"


class RetryPolicyModel(BaseModel):
max_retries: int = Field(default=3, description="The maximum number of retries", ge=0)
strategy: RetryStrategyEnum = Field(default=RetryStrategyEnum.EXPONENTIAL, description="The method of retry")
backoff_factor: int = Field(default=2000, description="The backoff factor in milliseconds (default: 2000 = 2 seconds)", gt=0)
exponent: int = Field(default=2, description="The exponent for the exponential retry strategy", gt=0)
max_delay: int | None = Field(default=None, description="The maximum delay in milliseconds (no default limit when None)", gt=0)


class StoreConfigModel(BaseModel):
required_keys: list[str] = Field(default_factory=list, description="Required keys of the store")
default_values: dict[str, str] = Field(default_factory=dict, description="Default values of the store")

@field_validator("required_keys")
@classmethod
def validate_required_keys(cls, v: list[str]) -> list[str]:
errors = []
keys = set()
trimmed_keys = []

for key in v:
trimmed_key = key.strip() if key is not None else ""

if trimmed_key == "":
errors.append("Key cannot be empty or contain only whitespace")
continue

if '.' in trimmed_key:
errors.append(f"Key '{trimmed_key}' cannot contain '.' character")
continue

if trimmed_key in keys:
errors.append(f"Key '{trimmed_key}' is duplicated")
continue

keys.add(trimmed_key)
trimmed_keys.append(trimmed_key)

if len(errors) > 0:
raise ValueError("\n".join(errors))
return trimmed_keys

@field_validator("default_values")
@classmethod
def validate_default_values(cls, v: dict[str, str]) -> dict[str, str]:
errors = []
keys = set()
normalized_dict = {}

for key, value in v.items():
trimmed_key = key.strip() if key is not None else ""

if trimmed_key == "":
errors.append("Key cannot be empty or contain only whitespace")
continue

if '.' in trimmed_key:
errors.append(f"Key '{trimmed_key}' cannot contain '.' character")
continue

if trimmed_key in keys:
errors.append(f"Key '{trimmed_key}' is duplicated")
continue

keys.add(trimmed_key)
normalized_dict[trimmed_key] = str(value)

if len(errors) > 0:
raise ValueError("\n".join(errors))
return normalized_dict
55 changes: 12 additions & 43 deletions python-sdk/exospherehost/statemanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,40 +3,7 @@
import asyncio
import time

from typing import Any
from pydantic import BaseModel


class TriggerState(BaseModel):
"""
Represents a trigger state for graph execution.

A trigger state contains an identifier and a set of input parameters that
will be passed to the graph when it is triggered for execution.

Attributes:
identifier (str): A unique identifier for this trigger state. This is used
to distinguish between different trigger states and may be used by the
graph to determine how to process the trigger.
inputs (dict[str, str]): A dictionary of input parameters that will be
passed to the graph. The keys are parameter names and values are
parameter values, both as strings.

Example:
```python
# Create a trigger state with identifier and inputs
trigger_state = TriggerState(
identifier="user-login",
inputs={
"user_id": "12345",
"session_token": "abc123def456",
"timestamp": "2024-01-15T10:30:00Z"
}
)
```
"""
identifier: str
inputs: dict[str, str]
from .models import GraphNodeModel, RetryPolicyModel, StoreConfigModel


class StateManager:
Expand Down Expand Up @@ -67,7 +34,7 @@ def _get_upsert_graph_endpoint(self, graph_name: str):
def _get_get_graph_endpoint(self, graph_name: str):
return f"{self._state_manager_uri}/{self._state_manager_version}/namespace/{self._namespace}/graph/{graph_name}"

async def trigger(self, graph_name: str, inputs: dict[str, str] | None = None, store: dict[str, str] | None = None):
async def trigger(self, graph_name: str, inputs: dict[str, str] | None = None, store: dict[str, str] | None = None, start_delay: int = 0):
"""
Trigger execution of a graph.

Expand All @@ -82,7 +49,8 @@ async def trigger(self, graph_name: str, inputs: dict[str, str] | None = None, s
graph. Strings only.
store (dict[str, str] | None): Optional key-value store that will be merged
into the graph-level store before execution (beta).

start_delay (int): Optional delay in milliseconds before the graph starts execution.

Returns:
dict: JSON payload returned by the state-manager API.

Expand All @@ -108,6 +76,7 @@ async def trigger(self, graph_name: str, inputs: dict[str, str] | None = None, s
store = {}

body = {
"start_delay": start_delay,
"inputs": inputs,
"store": store
}
Expand Down Expand Up @@ -156,7 +125,7 @@ async def get_graph(self, graph_name: str):
raise Exception(f"Failed to get graph: {response.status} {await response.text()}")
return await response.json()

async def upsert_graph(self, graph_name: str, graph_nodes: list[dict[str, Any]], secrets: dict[str, str], retry_policy: dict[str, Any] | None = None, store_config: dict[str, Any] | None = None, validation_timeout: int = 60, polling_interval: int = 1):
async def upsert_graph(self, graph_name: str, graph_nodes: list[GraphNodeModel], secrets: dict[str, str], retry_policy: RetryPolicyModel | None = None, store_config: StoreConfigModel | None = None, validation_timeout: int = 60, polling_interval: int = 1):
"""
Create or update a graph definition.

Expand All @@ -169,10 +138,10 @@ async def upsert_graph(self, graph_name: str, graph_nodes: list[dict[str, Any]],

Args:
graph_name (str): Graph identifier.
graph_nodes (list[dict[str, Any]]): Graph node list.
graph_nodes (list[GraphNodeModel]): List of graph node models defining the workflow.
secrets (dict[str, str]): Secrets available to all nodes.
retry_policy (dict[str, Any] | None): Optional per-node retry policy.
store_config (dict[str, Any] | None): Beta configuration for the
retry_policy (RetryPolicyModel | None): Optional per-node retry policy configuration.
store_config (StoreConfigModel | None): Beta configuration for the
graph-level store (schema is subject to change).
validation_timeout (int): Seconds to wait for validation (default 60).
polling_interval (int): Polling interval in seconds (default 1).
Expand All @@ -189,13 +158,13 @@ async def upsert_graph(self, graph_name: str, graph_nodes: list[dict[str, Any]],
}
body = {
"secrets": secrets,
"nodes": graph_nodes
"nodes": [node.model_dump() for node in graph_nodes]
}

if retry_policy is not None:
body["retry_policy"] = retry_policy
body["retry_policy"] = retry_policy.model_dump()
if store_config is not None:
body["store_config"] = store_config
body["store_config"] = store_config.model_dump()

async with aiohttp.ClientSession() as session:
async with session.put(endpoint, json=body, headers=headers) as response: # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion python-sdk/tests/test_coverage_additions.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def test_statemanager_trigger_defaults(monkeypatch):
# Verify it sent empty inputs/store when omitted
mock_session.post.assert_called_once()
_, kwargs = mock_session.post.call_args
assert kwargs["json"] == {"inputs": {}, "store": {}}
assert kwargs["json"] == {"inputs": {}, "store": {}, "start_delay": 0}


class _DummyNode(BaseNode):
Expand Down
23 changes: 14 additions & 9 deletions python-sdk/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio
from unittest.mock import AsyncMock, patch, MagicMock
from pydantic import BaseModel
from exospherehost import Runtime, BaseNode, StateManager, TriggerState
from exospherehost import Runtime, BaseNode, StateManager


def create_mock_aiohttp_session():
Expand Down Expand Up @@ -205,21 +205,26 @@ async def test_state_manager_graph_lifecycle(self, mock_env_vars):
sm = StateManager(namespace="test_namespace")

# Test graph creation
from exospherehost.models import GraphNodeModel
graph_nodes = [
{"name": "IntegrationTestNode", "type": "test"}
GraphNodeModel(
node_name="IntegrationTestNode",
namespace="test_namespace",
identifier="IntegrationTestNode",
inputs={"type": "test"},
next_nodes=None,
unites=None
)
]
secrets = {"api_key": "test_key", "database_url": "db://test"}

result = await sm.upsert_graph("test_graph", graph_nodes, secrets, validation_timeout=10, polling_interval=0.1) # type: ignore
assert result["validation_status"] == "VALID"

# Test graph triggering
trigger_state = TriggerState(
identifier="test_trigger",
inputs={"user_id": "123", "action": "login"}
)
trigger_state = {"identifier": "test_trigger", "inputs": {"user_id": "123", "action": "login"}}

trigger_result = await sm.trigger("test_graph", inputs=trigger_state.inputs)
trigger_result = await sm.trigger("test_graph", inputs=trigger_state["inputs"])
assert trigger_result == {"status": "triggered"}


Expand Down Expand Up @@ -448,10 +453,10 @@ async def test_state_manager_error_propagation(self, mock_env_vars):
mock_session_class.return_value = mock_session

sm = StateManager(namespace="error_test")
trigger_state = TriggerState(identifier="test", inputs={"key": "value"})
trigger_state = {"identifier": "test", "inputs": {"key": "value"}}

with pytest.raises(Exception, match="Failed to trigger state: 404 Graph not found"):
await sm.trigger("nonexistent_graph", inputs=trigger_state.inputs)
await sm.trigger("nonexistent_graph", inputs=trigger_state["inputs"])


class TestConcurrencyIntegration:
Expand Down
Loading