Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions python-sdk/exospherehost/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def __init__(self, namespace: str, name: str, state_manager_uri: str | None = No
self._namespace = namespace
self._key = key
self._batch_size = batch_size
self._connected = False
self._state_queue = Queue(maxsize=2*batch_size)
self._workers = workers
self._nodes = []
Expand Down Expand Up @@ -104,8 +103,39 @@ def _get_executed_endpoint(self, state_id: str):
def _get_errored_endpoint(self, state_id: str):
"""Get the endpoint URL for notifying errored states."""
return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/states/{state_id}/errored"

def _get_register_endpoint(self):
"""Get the endpoint URL for registering nodes with runtime"""
return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/nodes/"

async def _register_nodes(self):
"""Register nodes with the runtime"""
async with ClientSession() as session:
endpoint = self._get_register_endpoint()
body = {
"runtime_name": self._name,
"runtime_namespace": self._namespace,
"nodes": [
{
"name": node.get_unique_name(),
"namespace": self._namespace,
"inputs_schema": node.Inputs.model_json_schema(),
"outputs_schema": node.Outputs.model_json_schema(),
} for node in self._nodes
]
}
headers = {"x-api-key": self._key}

async with session.put(endpoint, json=body, headers=headers) as response: # type: ignore
res = await response.json()

def connect(self, nodes: List[BaseNode]):
if response.status != 200:
raise RuntimeError(f"Failed to register nodes: {res}")

return res


async def _register(self, nodes: List[BaseNode]):
"""
Connect nodes to the runtime.

Expand All @@ -121,7 +151,9 @@ def connect(self, nodes: List[BaseNode]):
self._nodes = self._validate_nodes(nodes)
self._node_names = [node.get_unique_name() for node in nodes]
self._node_mapping = {node.get_unique_name(): node for node in self._nodes}
self._connected = True

await self._register_nodes()


async def _enqueue_call(self):
"""
Expand Down Expand Up @@ -250,7 +282,7 @@ async def _worker(self):

self._state_queue.task_done() # type: ignore

async def _start(self):
async def _start(self, nodes: List[BaseNode]):
"""
Start the runtime execution.

Expand All @@ -260,15 +292,14 @@ async def _start(self):
Raises:
RuntimeError: If the runtime is not connected (no nodes registered)
"""
if not self._connected:
raise RuntimeError("Runtime not connected, you need to call Runtime.connect() before calling Runtime.start()")
await self._register(nodes)

poller = asyncio.create_task(self._enqueue())
worker_tasks = [asyncio.create_task(self._worker()) for _ in range(self._workers)]

await asyncio.gather(poller, *worker_tasks)

def start(self):
def start(self, nodes: List[BaseNode]):
"""
Start the runtime execution.

Expand All @@ -281,6 +312,6 @@ def start(self):
"""
try:
loop = asyncio.get_running_loop()
return loop.create_task(self._start())
return loop.create_task(self._start(nodes))
except RuntimeError:
asyncio.run(self._start())
asyncio.run(self._start(nodes))
68 changes: 68 additions & 0 deletions state-manager/app/controller/register_nodes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from ..models.register_nodes_request import RegisterNodesRequestModel
from ..models.register_nodes_response import RegisterNodesResponseModel, RegisteredNodeModel
from ..models.db.registered_node import RegisteredNode

from app.singletons.logs_manager import LogsManager
from beanie.operators import Set

logger = LogsManager().get_logger()


async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x_exosphere_request_id: str) -> RegisterNodesResponseModel:

try:
logger.info(f"Registering nodes for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)

# Check if nodes already exist and update them, or create new ones
registered_nodes = []

for node_data in body.nodes:
# Check if node already exists
existing_node = await RegisteredNode.find_one(
RegisteredNode.name == node_data.name,
RegisteredNode.namespace == namespace_name
)

if existing_node:
# Update existing node
await existing_node.update(
Set({
RegisteredNode.runtime_name: body.runtime_name,
RegisteredNode.runtime_namespace: namespace_name,
RegisteredNode.inputs_schema: node_data.inputs_schema, # type: ignore
RegisteredNode.outputs_schema: node_data.outputs_schema # type: ignore
}))
logger.info(f"Updated existing node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)

else:
# Create new node
new_node = RegisteredNode(
name=node_data.name,
namespace=namespace_name,
runtime_name=body.runtime_name,
runtime_namespace=namespace_name,
inputs_schema=node_data.inputs_schema,
outputs_schema=node_data.outputs_schema
)
await new_node.insert()
logger.info(f"Created new node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)

registered_nodes.append(
RegisteredNodeModel(
name=node_data.name,
inputs_schema=node_data.inputs_schema,
outputs_schema=node_data.outputs_schema
)
)

response = RegisterNodesResponseModel(
runtime_name=body.runtime_name,
registered_nodes=registered_nodes
)

logger.info(f"Successfully registered {len(registered_nodes)} nodes for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
return response

except Exception as e:
logger.error(f"Error registering nodes for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id, error=e)
raise e
3 changes: 2 additions & 1 deletion state-manager/app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .models.db.state import State
from .models.db.namespace import Namespace
from .models.db.graph_template_model import GraphTemplate
from .models.db.registered_node import RegisteredNode

# injecting routes
from .routes import router
Expand All @@ -36,7 +37,7 @@ async def lifespan(app: FastAPI):
# initializing beanie
client = AsyncMongoClient(os.getenv("MONGO_URI"))
db = client[os.getenv("MONGO_DATABASE_NAME", "exosphere-state-manager")]
await init_beanie(db, document_models=[State, Namespace, GraphTemplate])
await init_beanie(db, document_models=[State, Namespace, GraphTemplate, RegisteredNode])
logger.info("beanie dbs initialized")

# initialize secret
Expand Down
12 changes: 12 additions & 0 deletions state-manager/app/models/db/registered_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from .base import BaseDatabaseModel
from pydantic import Field
from typing import Any


class RegisteredNode(BaseDatabaseModel):
name: str = Field(..., description="Unique name of the registered node")
namespace: str = Field(..., description="Namespace of the registered node")
runtime_name: str = Field(..., description="Name of the runtime that registered this node")
runtime_namespace: str = Field(..., description="Namespace of the runtime that registered this node")
inputs_schema: dict[str, Any] = Field(..., description="JSON schema for node inputs")
outputs_schema: dict[str, Any] = Field(..., description="JSON schema for node outputs")
13 changes: 13 additions & 0 deletions state-manager/app/models/register_nodes_request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import Any, List


class NodeRegistrationModel(BaseModel):
name: str = Field(..., description="Unique name of the node")
inputs_schema: dict[str, Any] = Field(..., description="JSON schema for node inputs")
outputs_schema: dict[str, Any] = Field(..., description="JSON schema for node outputs")


class RegisterNodesRequestModel(BaseModel):
runtime_name: str = Field(..., description="Name of the runtime registering the nodes")
nodes: List[NodeRegistrationModel] = Field(..., description="List of nodes to register")
13 changes: 13 additions & 0 deletions state-manager/app/models/register_nodes_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from pydantic import BaseModel, Field
from typing import Any, List


class RegisteredNodeModel(BaseModel):
name: str = Field(..., description="Name of the registered node")
inputs_schema: dict[str, Any] = Field(..., description="Inputs for the registered node")
outputs_schema: dict[str, Any] = Field(..., description="Outputs for the registered node")


class RegisterNodesResponseModel(BaseModel):
runtime_name: str = Field(..., description="Name of the runtime that registered the nodes")
registered_nodes: List[RegisteredNodeModel] = Field(..., description="List of successfully registered nodes")
25 changes: 24 additions & 1 deletion state-manager/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
from .models.graph_models import UpsertGraphTemplateRequest, UpsertGraphTemplateResponse
from .controller.upsert_graph_template import upsert_graph_template as upsert_graph_template_controller

from .models.register_nodes_request import RegisterNodesRequestModel
from .models.register_nodes_response import RegisterNodesResponseModel
from .controller.register_nodes import register_nodes



logger = LogsManager().get_logger()
Expand Down Expand Up @@ -124,4 +128,23 @@ async def upsert_graph_template(namespace_name: str, graph_name: str, body: Upse
logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")

return await upsert_graph_template_controller(namespace_name, graph_name, body, x_exosphere_request_id)
return await upsert_graph_template_controller(namespace_name, graph_name, body, x_exosphere_request_id)


@router.put(
"/nodes/",
response_model=RegisterNodesResponseModel,
status_code=status.HTTP_200_OK,
response_description="Nodes registered successfully",
tags=["nodes"]
)
async def register_nodes_route(namespace_name: str, body: RegisterNodesRequestModel, request: Request, api_key: str = Depends(check_api_key)):
x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4()))

if api_key:
logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
else:
logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")

return await register_nodes(namespace_name, body, x_exosphere_request_id)