Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions state-manager/app/controller/create_states.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,44 @@
from fastapi import HTTPException

from app.singletons.logs_manager import LogsManager
from app.models.create_models import CreateRequestModel, CreateResponseModel, ResponseStateModel
from app.models.state_status_enum import StateStatusEnum
from app.models.db.state import State
from app.models.db.graph_template_model import GraphTemplate
from app.models.node_template_model import NodeTemplate

from beanie.operators import In
from bson import ObjectId

logger = LogsManager().get_logger()

async def create_states(namespace_name: str, body: CreateRequestModel, x_exosphere_request_id: str) -> CreateResponseModel:

def get_node_template(graph_template: GraphTemplate, identifier: str) -> NodeTemplate:
node = graph_template.get_node_by_identifier(identifier)
if not node:
raise HTTPException(status_code=404, detail="Node template not found")
return node


async def create_states(namespace_name: str, graph_name: str, body: CreateRequestModel, x_exosphere_request_id: str) -> CreateResponseModel:
try:
states = []
logger.info(f"Creating states for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)

graph_template = await GraphTemplate.find_one(GraphTemplate.name == graph_name, GraphTemplate.namespace == namespace_name)
if not graph_template:
raise HTTPException(status_code=404, detail="Graph template not found")

for state in body.states:

node_template = get_node_template(graph_template, state.identifier)

states.append(
State(
node_name=state.node_name,
namespace_name=namespace_name,
graph_name=state.graph_name,
identifier=state.identifier,
node_name=node_template.node_name,
namespace_name=node_template.namespace,
graph_name=graph_name,
status=StateStatusEnum.CREATED,
inputs=state.inputs,
outputs={},
Expand Down
1 change: 1 addition & 0 deletions state-manager/app/controller/enqueue_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ async def enqueue_states(namespace_name: str, body: EnqueueRequestModel, x_exosp
StateModel(
state_id=str(state.id),
node_name=state.node_name,
identifier=state.identifier,
inputs=state.inputs,
created_at=state.created_at
)
Expand Down
4 changes: 2 additions & 2 deletions state-manager/app/models/create_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@


class RequestStateModel(BaseModel):
node_name: str = Field(..., description="Name of the node of the state")
graph_name: str = Field(..., description="Name of the graph template for this state")
identifier: str = Field(..., description="Unique identifier of the node template within the graph template")
inputs: dict[str, Any] = Field(..., description="Inputs of the state")


class ResponseStateModel(BaseModel):
state_id: str = Field(..., description="ID of the state")
node_name: str = Field(..., description="Name of the node of the state")
identifier: str = Field(..., description="Identifier of the node for which state is created")
graph_name: str = Field(..., description="Name of the graph template for this state")
inputs: dict[str, Any] = Field(..., description="Inputs of the state")
created_at: datetime = Field(..., description="Date and time when the state was created")
Expand Down
8 changes: 8 additions & 0 deletions state-manager/app/models/db/graph_template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Dict
from app.utils.encrypter import get_encrypter


class GraphTemplate(BaseDatabaseModel):
name: str = Field(..., description="Name of the graph")
namespace: str = Field(..., description="Namespace of the graph")
Expand All @@ -26,6 +27,13 @@ class Settings:
)
]

def get_node_by_identifier(self, identifier: str) -> NodeTemplate | None:
"""Get a node by its identifier using O(1) dictionary lookup."""
for node in self.nodes:
if node.identifier == identifier:
return node
return None

@field_validator('secrets')
@classmethod
def validate_secrets(cls, v: Dict[str, str]) -> Dict[str, str]:
Expand Down
1 change: 1 addition & 0 deletions state-manager/app/models/db/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class State(BaseDatabaseModel):

node_name: str = Field(..., description="Name of the node of the state")
namespace_name: str = Field(..., description="Name of the namespace of the state")
identifier: str = Field(..., description="Identifier of the node for which state is created")
graph_name: str = Field(..., description="Name of the graph template for this state")
status: StateStatusEnum = Field(..., description="Status of the state")
inputs: dict[str, Any] = Field(..., description="Inputs of the state")
Expand Down
1 change: 1 addition & 0 deletions state-manager/app/models/enqueue_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
class StateModel(BaseModel):
state_id: str = Field(..., description="ID of the state")
node_name: str = Field(..., description="Name of the node of the state")
identifier: str = Field(..., description="Identifier of the node for which state is created")
inputs: dict[str, Any] = Field(..., description="Inputs of the state")
created_at: datetime = Field(..., description="Date and time when the state was created")

Expand Down
6 changes: 3 additions & 3 deletions state-manager/app/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,13 @@ async def enqueue_state(namespace_name: str, body: EnqueueRequestModel, request:


@router.post(
"/states/create",
"/graph/{graph_name}/states/create",
response_model=CreateResponseModel,
status_code=status.HTTP_200_OK,
response_description="States created successfully",
tags=["state"]
)
async def create_state(namespace_name: str, body: CreateRequestModel, request: Request, api_key: str = Depends(check_api_key)):
async def create_state(namespace_name: str, graph_name: str, body: CreateRequestModel, request: Request, api_key: str = Depends(check_api_key)):

x_exosphere_request_id = getattr(request.state, "x_exosphere_request_id", str(uuid4()))

Expand All @@ -71,7 +71,7 @@ async def create_state(namespace_name: str, body: CreateRequestModel, request: R
logger.error(f"API key is invalid for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key")

return await create_states(namespace_name, body, x_exosphere_request_id)
return await create_states(namespace_name, graph_name, body, x_exosphere_request_id)


@router.post(
Expand Down