Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion python-sdk/README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,24 @@
# ExosphereHost Python SDK
This SDK is official python SDK for ExosphereHost and interacting with exospherehost.
This is the official Python SDK for ExosphereHost and for interacting with ExosphereHost.

## Node Creation
You can simply connect to exosphere state manager and start creating your nodes, as shown in sample below:

```python
from exospherehost import Runtime, BaseNode
from typing import Any
import os

class SampleNode(BaseNode):
async def execute(self, inputs: dict[str, Any]) -> dict[str, Any]:
print(inputs)
return {"message": "success"}

runtime = Runtime("SampleNamespace", os.getenv("EXOSPHERE_STATE_MANAGER_URI", "http://localhost:8000"), os.getenv("EXOSPHERE_API_KEY", ""))

runtime.connect([SampleNode()])
runtime.start()
```

## Support
For first-party support and questions, do not hesitate to reach out to us at <nivedit@exosphere.host>.
9 changes: 5 additions & 4 deletions python-sdk/exospherehost/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._version import VERSION
from ._version import version as __version__
from .runtime import Runtime
from .node.BaseNode import BaseNode

__version__ = VERSION
VERSION = __version__

def test():
print(f"ExosphereHost PySDK v{VERSION}")
__all__ = ["Runtime", "BaseNode", "VERSION"]
2 changes: 1 addition & 1 deletion python-sdk/exospherehost/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = "0.0.3b"
version = "0.0.4b"
4 changes: 2 additions & 2 deletions python-sdk/exospherehost/node/BaseNode.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Optional, Any


class BaseNode(ABC):
Expand All @@ -8,7 +8,7 @@ def __init__(self, unique_name: Optional[str] = None):
self.unique_name: Optional[str] = unique_name

@abstractmethod
async def execute(self):
async def execute(self, inputs: dict[str, Any]) -> dict[str, Any]:
pass

def get_unique_name(self) -> str:
Expand Down
1 change: 0 additions & 1 deletion python-sdk/exospherehost/node/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
from .BaseNode import BaseNode
78 changes: 59 additions & 19 deletions python-sdk/exospherehost/runtime.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import asyncio
from asyncio import Queue, sleep
import logging
from typing import Any, List
from .node import BaseNode
from .node.BaseNode import BaseNode
from aiohttp import ClientSession
from logging import getLogger

logger = getLogger(__name__)

class Runtime:

def __init__(self, namespace: str, state_manager_uri: str, batch_size: int = 16, workers=4, state_manage_version: int = 0, poll_interval: int = 10):
def __init__(self, namespace: str, state_manager_uri: str, key: str, batch_size: int = 16, workers=4, state_manage_version: str = "v0", poll_interval: int = 1):
self._namespace = namespace
self._key = key
self._batch_size = batch_size
self._connected = False
self._state_queue = Queue(maxsize=2*batch_size)
Expand All @@ -34,33 +37,61 @@ def _get_executed_endpoint(self, state_id: str):
def _get_errored_endpoint(self, state_id: str):
return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/states/{state_id}/errored"

async def connect(self, nodes: List[BaseNode]):
def connect(self, nodes: List[BaseNode]):
self._nodes = self._validate_nodes(nodes)
self._node_names = [node.get_unique_name() for node in nodes]
self._node_mapping = {node.get_unique_name(): node for node in self._nodes}
self._connected = True

async def _enqueue_call(self):
async with ClientSession() as session:
async with session.post(self._get_enque_endpoint(), json={"nodes": self._node_names, "batch_size": self._batch_size}) as response:
return await response.json()
endpoint = self._get_enque_endpoint()
body = {"nodes": self._node_names, "batch_size": self._batch_size}
headers = {"x-api-key": self._key}

async with session.post(endpoint, json=body, headers=headers) as response:
res = await response.json()

if response.status != 200:
logger.error(f"Failed to enqueue states: {res}")

return res

async def _enqueue(self):
if self._state_queue.qsize() < self._batch_size:
data = await self._enqueue_call()
for state in data["states"]:
await self._state_queue.put(state)
await sleep(self._poll_interval)
while True:
try:
if self._state_queue.qsize() < self._batch_size:
data = await self._enqueue_call()
for state in data["states"]:
await self._state_queue.put(state)
except Exception as e:
logger.error(f"Error enqueuing states: {e}")

await sleep(self._poll_interval)

async def _notify_executed(self, state_id: str, outputs: dict[str, Any]):
async with ClientSession() as session:
async with session.post(self._get_executed_endpoint(state_id), json={"outputs": outputs}) as response:
return await response.json()
endpoint = self._get_executed_endpoint(state_id)
body = {"outputs": outputs}
headers = {"x-api-key": self._key}

async with session.post(endpoint, json=body, headers=headers) as response:
res = await response.json()

if response.status != 200:
logger.error(f"Failed to notify executed state {state_id}: {res}")

async def _notify_errored(self, state_id: str, error: str):
async with ClientSession() as session:
async with session.post(self._get_errored_endpoint(state_id), json={"error": error}) as response:
return await response.json()
endpoint = self._get_errored_endpoint(state_id)
body = {"error": error}
headers = {"x-api-key": self._key}

async with session.post(endpoint, json=body, headers=headers) as response:
res = await response.json()

if response.status != 200:
logger.error(f"Failed to notify errored state {state_id}: {res}")

def _validate_nodes(self, nodes: List[BaseNode]):
invalid_nodes = []
Expand All @@ -81,15 +112,24 @@ async def _worker(self):
try:
node = self._node_mapping[state["node_name"]]
outputs = await node.execute(state["inputs"]) # type: ignore
await self._notify_executed(state["id"], outputs)
await self._notify_executed(state["state_id"], outputs)
except Exception as e:
await self._notify_errored(state["id"], str(e))
await self._notify_errored(state["state_id"], str(e))

self._state_queue.task_done() # type: ignore

async def start(self):
async def _start(self):
if not self._connected:
raise RuntimeError("Runtime not connected, you need to call Runtime.connect() before calling Runtime.start()")

poller = asyncio.create_task(self._enqueue())
worker_tasks = [asyncio.create_task(self._worker()) for _ in range(self._workers)]

await asyncio.gather(poller, *worker_tasks)
await asyncio.gather(poller, *worker_tasks)

def start(self):
try:
loop = asyncio.get_running_loop()
return loop.create_task(self._start())
except RuntimeError:
asyncio.run(self._start())
2 changes: 1 addition & 1 deletion python-sdk/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Documentation = "https://docs.exosphere.host"
Repository = "https://github.com/exospherehost/exospherehost/tree/main/python-sdk"

[tool.setuptools.dynamic]
version = {attr = "exospherehost.__version__"}
version = {attr = "exospherehost._version.version"}

[dependency-groups]
dev = [
Expand Down
13 changes: 13 additions & 0 deletions python-sdk/sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from exospherehost import Runtime, BaseNode
from typing import Any
import os

class SampleNode(BaseNode):
async def execute(self, inputs: dict[str, Any]) -> dict[str, Any]:
print(inputs)
return {"message": "success"}

runtime = Runtime("SampleNamespace", os.getenv("EXOSPHERE_STATE_MANAGER_URI", "http://localhost:8000"), os.getenv("EXOSPHERE_API_KEY", ""))

runtime.connect([SampleNode()])
runtime.start()
Loading