Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions state-manager/app/controller/executed_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,26 @@ async def executed_state(namespace_name: str, state_id: ObjectId, body: Executed
logger.info(f"Executed state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id)

state = await State.find_one(State.id == state_id)
if not state:
if not state or not state.id:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found")

if state.status != StateStatusEnum.QUEUED:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="State is not queued")

if len(body.outputs) == 0:
await State.find_one(State.id == state_id).set(
{"status": StateStatusEnum.EXECUTED, "outputs": {}, "parents": {**state.parents, state.identifier: ObjectId(state.id)}}
{"status": StateStatusEnum.EXECUTED, "outputs": {}, "parents": {**state.parents, state.identifier: state.id}}
)

background_tasks.add_task(create_next_state, state)

else:
await State.find_one(State.id == state_id).set(
{"status": StateStatusEnum.EXECUTED, "outputs": body.outputs[0], "parents": {**state.parents, state.identifier: ObjectId(state.id)}}
)

state.outputs = body.outputs[0]
state.status = StateStatusEnum.EXECUTED
state.parents = {**state.parents, state.identifier: state.id}
await state.save()

background_tasks.add_task(create_next_state, state)

for output in body.outputs[1:]:
Expand All @@ -46,7 +49,7 @@ async def executed_state(namespace_name: str, state_id: ObjectId, body: Executed
error=None,
parents={
**state.parents,
state.identifier: ObjectId(state.id)
state.identifier: state.id
}
)
await new_state.save()
Expand Down
3 changes: 1 addition & 2 deletions state-manager/app/controller/get_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from app.models.secrets_response import SecretsResponseModel
from app.models.db.state import State
from app.models.db.graph_template_model import GraphTemplate
from bson import ObjectId

logger = LogsManager().get_logger()

Expand All @@ -24,7 +23,7 @@ async def get_secrets(namespace_name: str, state_id: str, x_exosphere_request_id
"""
try:
# Get the state
state = await State.get(ObjectId(state_id))
state = await State.get(state_id)
if not state:
logger.error(f"State {state_id} not found", x_exosphere_request_id=x_exosphere_request_id)
raise ValueError(f"State {state_id} not found")
Expand Down
4 changes: 2 additions & 2 deletions state-manager/app/models/db/state.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from bson import ObjectId
from .base import BaseDatabaseModel
from ..state_status_enum import StateStatusEnum
from pydantic import Field
from beanie import PydanticObjectId
from typing import Any, Optional


Expand All @@ -15,4 +15,4 @@ class State(BaseDatabaseModel):
inputs: dict[str, Any] = Field(..., description="Inputs of the state")
outputs: dict[str, Any] = Field(..., description="Outputs of the state")
error: Optional[str] = Field(None, description="Error message")
parents: dict[str, ObjectId] = Field(default_factory=dict, description="Parents of the state")
parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state")
23 changes: 12 additions & 11 deletions state-manager/app/tasks/create_next_state.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import asyncio
import time

from bson import ObjectId

from app.models.db.state import State
from app.models.db.graph_template_model import GraphTemplate
from app.models.graph_template_validation_status import GraphTemplateValidationStatus
Expand All @@ -14,6 +12,8 @@
async def create_next_state(state: State):
graph_template = None

if state is None or state.id is None:
raise ValueError("State is not valid")
try:
start_time = time.time()
timeout_seconds = 300 # 5 minutes
Expand All @@ -37,9 +37,13 @@ async def create_next_state(state: State):

next_node_identifier = node_template.next_nodes
if not next_node_identifier:
raise Exception(f"Node template {state.identifier} has no next nodes")
state.status = StateStatusEnum.SUCCESS
await state.save()
return

cache_states = {}
cache_states = {}

parents = state.parents | {state.identifier: state.id}

for identifier in next_node_identifier:
next_node_template = graph_template.get_node_by_identifier(identifier)
Expand Down Expand Up @@ -72,15 +76,15 @@ async def create_next_state(state: State):
raise Exception(f"Invalid input placeholder format: '{placeholder_content}' for field {field_name}")

input_identifier = parts[0]
input_field = parts[2]
input_field = parts[2]

parent_id = state.parents.get(input_identifier)
parent_id = parents.get(input_identifier)

if not parent_id:
raise Exception(f"Parent identifier '{input_identifier}' not found in state parents.")

if parent_id not in cache_states:
dependent_state = await State.get(ObjectId(parent_id))
dependent_state = await State.get(parent_id)
if not dependent_state:
raise Exception(f"Dependent state {input_identifier} not found")
cache_states[parent_id] = dependent_state
Expand All @@ -106,10 +110,7 @@ async def create_next_state(state: State):
inputs=next_node_input_data,
outputs={},
error=None,
parents={
**state.parents,
next_node_template.identifier: ObjectId(state.id)
}
parents=parents
)

await new_state.save()
Expand Down