diff --git a/python-sdk/exospherehost/runtime.py b/python-sdk/exospherehost/runtime.py index 9495cb91..f202afaf 100644 --- a/python-sdk/exospherehost/runtime.py +++ b/python-sdk/exospherehost/runtime.py @@ -56,7 +56,6 @@ def __init__(self, namespace: str, name: str, state_manager_uri: str | None = No self._namespace = namespace self._key = key self._batch_size = batch_size - self._connected = False self._state_queue = Queue(maxsize=2*batch_size) self._workers = workers self._nodes = [] @@ -104,8 +103,39 @@ def _get_executed_endpoint(self, state_id: str): def _get_errored_endpoint(self, state_id: str): """Get the endpoint URL for notifying errored states.""" return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/states/{state_id}/errored" + + def _get_register_endpoint(self): + """Get the endpoint URL for registering nodes with runtime""" + return f"{self._state_manager_uri}/{str(self._state_manager_version)}/namespace/{self._namespace}/nodes/" + + async def _register_nodes(self): + """Register nodes with the runtime""" + async with ClientSession() as session: + endpoint = self._get_register_endpoint() + body = { + "runtime_name": self._name, + "runtime_namespace": self._namespace, + "nodes": [ + { + "name": node.get_unique_name(), + "namespace": self._namespace, + "inputs_schema": node.Inputs.model_json_schema(), + "outputs_schema": node.Outputs.model_json_schema(), + } for node in self._nodes + ] + } + headers = {"x-api-key": self._key} + + async with session.put(endpoint, json=body, headers=headers) as response: # type: ignore + res = await response.json() - def connect(self, nodes: List[BaseNode]): + if response.status != 200: + raise RuntimeError(f"Failed to register nodes: {res}") + + return res + + + async def _register(self, nodes: List[BaseNode]): """ Connect nodes to the runtime. @@ -121,7 +151,9 @@ def connect(self, nodes: List[BaseNode]): self._nodes = self._validate_nodes(nodes) self._node_names = [node.get_unique_name() for node in nodes] self._node_mapping = {node.get_unique_name(): node for node in self._nodes} - self._connected = True + + await self._register_nodes() + async def _enqueue_call(self): """ @@ -250,7 +282,7 @@ async def _worker(self): self._state_queue.task_done() # type: ignore - async def _start(self): + async def _start(self, nodes: List[BaseNode]): """ Start the runtime execution. @@ -260,15 +292,14 @@ async def _start(self): Raises: RuntimeError: If the runtime is not connected (no nodes registered) """ - if not self._connected: - raise RuntimeError("Runtime not connected, you need to call Runtime.connect() before calling Runtime.start()") + await self._register(nodes) poller = asyncio.create_task(self._enqueue()) worker_tasks = [asyncio.create_task(self._worker()) for _ in range(self._workers)] await asyncio.gather(poller, *worker_tasks) - def start(self): + def start(self, nodes: List[BaseNode]): """ Start the runtime execution. @@ -281,6 +312,6 @@ def start(self): """ try: loop = asyncio.get_running_loop() - return loop.create_task(self._start()) + return loop.create_task(self._start(nodes)) except RuntimeError: - asyncio.run(self._start()) + asyncio.run(self._start(nodes)) diff --git a/state-manager/app/controller/register_nodes.py b/state-manager/app/controller/register_nodes.py new file mode 100644 index 00000000..117de4df --- /dev/null +++ b/state-manager/app/controller/register_nodes.py @@ -0,0 +1,68 @@ +from ..models.register_nodes_request import RegisterNodesRequestModel +from ..models.register_nodes_response import RegisterNodesResponseModel, RegisteredNodeModel +from ..models.db.registered_node import RegisteredNode + +from app.singletons.logs_manager import LogsManager +from beanie.operators import Set + +logger = LogsManager().get_logger() + + +async def register_nodes(namespace_name: str, body: RegisterNodesRequestModel, x_exosphere_request_id: str) -> RegisterNodesResponseModel: + + try: + logger.info(f"Registering nodes for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + # Check if nodes already exist and update them, or create new ones + registered_nodes = [] + + for node_data in body.nodes: + # Check if node already exists + existing_node = await RegisteredNode.find_one( + RegisteredNode.name == node_data.name, + RegisteredNode.namespace == namespace_name + ) + + if existing_node: + # Update existing node + await existing_node.update( + Set({ + RegisteredNode.runtime_name: body.runtime_name, + RegisteredNode.runtime_namespace: namespace_name, + RegisteredNode.inputs_schema: node_data.inputs_schema, # type: ignore + RegisteredNode.outputs_schema: node_data.outputs_schema # type: ignore + })) + logger.info(f"Updated existing node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + else: + # Create new node + new_node = RegisteredNode( + name=node_data.name, + namespace=namespace_name, + runtime_name=body.runtime_name, + runtime_namespace=namespace_name, + inputs_schema=node_data.inputs_schema, + outputs_schema=node_data.outputs_schema + ) + await new_node.insert() + logger.info(f"Created new node {node_data.name} in namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + + registered_nodes.append( + RegisteredNodeModel( + name=node_data.name, + inputs_schema=node_data.inputs_schema, + outputs_schema=node_data.outputs_schema + ) + ) + + response = RegisterNodesResponseModel( + runtime_name=body.runtime_name, + registered_nodes=registered_nodes + ) + + logger.info(f"Successfully registered {len(registered_nodes)} nodes for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + return response + + except Exception as e: + logger.error(f"Error registering nodes for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id, error=e) + raise e \ No newline at end of file diff --git a/state-manager/app/main.py b/state-manager/app/main.py index f934f321..b60bc316 100644 --- a/state-manager/app/main.py +++ b/state-manager/app/main.py @@ -21,6 +21,7 @@ from .models.db.state import State from .models.db.namespace import Namespace from .models.db.graph_template_model import GraphTemplate +from .models.db.registered_node import RegisteredNode # injecting routes from .routes import router @@ -36,7 +37,7 @@ async def lifespan(app: FastAPI): # initializing beanie client = AsyncMongoClient(os.getenv("MONGO_URI")) db = client[os.getenv("MONGO_DATABASE_NAME", "exosphere-state-manager")] - await init_beanie(db, document_models=[State, Namespace, GraphTemplate]) + await init_beanie(db, document_models=[State, Namespace, GraphTemplate, RegisteredNode]) logger.info("beanie dbs initialized") # initialize secret diff --git a/state-manager/app/models/db/registered_node.py b/state-manager/app/models/db/registered_node.py new file mode 100644 index 00000000..46d421e1 --- /dev/null +++ b/state-manager/app/models/db/registered_node.py @@ -0,0 +1,12 @@ +from .base import BaseDatabaseModel +from pydantic import Field +from typing import Any + + +class RegisteredNode(BaseDatabaseModel): + name: str = Field(..., description="Unique name of the registered node") + namespace: str = Field(..., description="Namespace of the registered node") + runtime_name: str = Field(..., description="Name of the runtime that registered this node") + runtime_namespace: str = Field(..., description="Namespace of the runtime that registered this node") + inputs_schema: dict[str, Any] = Field(..., description="JSON schema for node inputs") + outputs_schema: dict[str, Any] = Field(..., description="JSON schema for node outputs") \ No newline at end of file diff --git a/state-manager/app/models/register_nodes_request.py b/state-manager/app/models/register_nodes_request.py new file mode 100644 index 00000000..a3fbed92 --- /dev/null +++ b/state-manager/app/models/register_nodes_request.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, Field +from typing import Any, List + + +class NodeRegistrationModel(BaseModel): + name: str = Field(..., description="Unique name of the node") + inputs_schema: dict[str, Any] = Field(..., description="JSON schema for node inputs") + outputs_schema: dict[str, Any] = Field(..., description="JSON schema for node outputs") + + +class RegisterNodesRequestModel(BaseModel): + runtime_name: str = Field(..., description="Name of the runtime registering the nodes") + nodes: List[NodeRegistrationModel] = Field(..., description="List of nodes to register") \ No newline at end of file diff --git a/state-manager/app/models/register_nodes_response.py b/state-manager/app/models/register_nodes_response.py new file mode 100644 index 00000000..52bbfeb9 --- /dev/null +++ b/state-manager/app/models/register_nodes_response.py @@ -0,0 +1,13 @@ +from pydantic import BaseModel, Field +from typing import Any, List + + +class RegisteredNodeModel(BaseModel): + name: str = Field(..., description="Name of the registered node") + inputs_schema: dict[str, Any] = Field(..., description="Inputs for the registered node") + outputs_schema: dict[str, Any] = Field(..., description="Outputs for the registered node") + + +class RegisterNodesResponseModel(BaseModel): + runtime_name: str = Field(..., description="Name of the runtime that registered the nodes") + registered_nodes: List[RegisteredNodeModel] = Field(..., description="List of successfully registered nodes") \ No newline at end of file diff --git a/state-manager/app/routes.py b/state-manager/app/routes.py index 1eecac60..b1d92227 100644 --- a/state-manager/app/routes.py +++ b/state-manager/app/routes.py @@ -21,6 +21,10 @@ from .models.graph_models import UpsertGraphTemplateRequest, UpsertGraphTemplateResponse from .controller.upsert_graph_template import upsert_graph_template as upsert_graph_template_controller +from .models.register_nodes_request import RegisterNodesRequestModel +from .models.register_nodes_response import RegisterNodesResponseModel +from .controller.register_nodes import register_nodes + logger = LogsManager().get_logger() @@ -124,4 +128,23 @@ async def upsert_graph_template(namespace_name: str, graph_name: str, body: Upse logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") - return await upsert_graph_template_controller(namespace_name, graph_name, body, x_exosphere_request_id) \ No newline at end of file + return await upsert_graph_template_controller(namespace_name, graph_name, body, x_exosphere_request_id) + + +@router.put( + "/nodes/", + response_model=RegisterNodesResponseModel, + status_code=status.HTTP_200_OK, + response_description="Nodes registered successfully", + tags=["nodes"] +) +async def register_nodes_route(namespace_name: str, body: RegisterNodesRequestModel, request: Request, api_key: str = Depends(check_api_key)): + x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4())) + + if api_key: + logger.info(f"API key is valid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + else: + logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key") + + return await register_nodes(namespace_name, body, x_exosphere_request_id) \ No newline at end of file