diff --git a/state-manager/app/controller/executed_state.py b/state-manager/app/controller/executed_state.py index b07fe167..f6f8ac9d 100644 --- a/state-manager/app/controller/executed_state.py +++ b/state-manager/app/controller/executed_state.py @@ -15,7 +15,7 @@ async def executed_state(namespace_name: str, state_id: ObjectId, body: Executed logger.info(f"Executed state {state_id} for namespace {namespace_name}", x_exosphere_request_id=x_exosphere_request_id) state = await State.find_one(State.id == state_id) - if not state: + if not state or not state.id: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="State not found") if state.status != StateStatusEnum.QUEUED: @@ -23,15 +23,18 @@ async def executed_state(namespace_name: str, state_id: ObjectId, body: Executed if len(body.outputs) == 0: await State.find_one(State.id == state_id).set( - {"status": StateStatusEnum.EXECUTED, "outputs": {}, "parents": {**state.parents, state.identifier: ObjectId(state.id)}} + {"status": StateStatusEnum.EXECUTED, "outputs": {}, "parents": {**state.parents, state.identifier: state.id}} ) background_tasks.add_task(create_next_state, state) else: - await State.find_one(State.id == state_id).set( - {"status": StateStatusEnum.EXECUTED, "outputs": body.outputs[0], "parents": {**state.parents, state.identifier: ObjectId(state.id)}} - ) + + state.outputs = body.outputs[0] + state.status = StateStatusEnum.EXECUTED + state.parents = {**state.parents, state.identifier: state.id} + await state.save() + background_tasks.add_task(create_next_state, state) for output in body.outputs[1:]: @@ -46,7 +49,7 @@ async def executed_state(namespace_name: str, state_id: ObjectId, body: Executed error=None, parents={ **state.parents, - state.identifier: ObjectId(state.id) + state.identifier: state.id } ) await new_state.save() diff --git a/state-manager/app/controller/get_secrets.py b/state-manager/app/controller/get_secrets.py index ae7ea3ba..f33625ff 100644 --- a/state-manager/app/controller/get_secrets.py +++ b/state-manager/app/controller/get_secrets.py @@ -2,7 +2,6 @@ from app.models.secrets_response import SecretsResponseModel from app.models.db.state import State from app.models.db.graph_template_model import GraphTemplate -from bson import ObjectId logger = LogsManager().get_logger() @@ -24,7 +23,7 @@ async def get_secrets(namespace_name: str, state_id: str, x_exosphere_request_id """ try: # Get the state - state = await State.get(ObjectId(state_id)) + state = await State.get(state_id) if not state: logger.error(f"State {state_id} not found", x_exosphere_request_id=x_exosphere_request_id) raise ValueError(f"State {state_id} not found") diff --git a/state-manager/app/models/db/state.py b/state-manager/app/models/db/state.py index 6989e1a5..dcf5757c 100644 --- a/state-manager/app/models/db/state.py +++ b/state-manager/app/models/db/state.py @@ -1,7 +1,7 @@ -from bson import ObjectId from .base import BaseDatabaseModel from ..state_status_enum import StateStatusEnum from pydantic import Field +from beanie import PydanticObjectId from typing import Any, Optional @@ -15,4 +15,4 @@ class State(BaseDatabaseModel): inputs: dict[str, Any] = Field(..., description="Inputs of the state") outputs: dict[str, Any] = Field(..., description="Outputs of the state") error: Optional[str] = Field(None, description="Error message") - parents: dict[str, ObjectId] = Field(default_factory=dict, description="Parents of the state") \ No newline at end of file + parents: dict[str, PydanticObjectId] = Field(default_factory=dict, description="Parents of the state") \ No newline at end of file diff --git a/state-manager/app/tasks/create_next_state.py b/state-manager/app/tasks/create_next_state.py index 268e818f..8c4c30cf 100644 --- a/state-manager/app/tasks/create_next_state.py +++ b/state-manager/app/tasks/create_next_state.py @@ -1,8 +1,6 @@ import asyncio import time -from bson import ObjectId - from app.models.db.state import State from app.models.db.graph_template_model import GraphTemplate from app.models.graph_template_validation_status import GraphTemplateValidationStatus @@ -14,6 +12,8 @@ async def create_next_state(state: State): graph_template = None + if state is None or state.id is None: + raise ValueError("State is not valid") try: start_time = time.time() timeout_seconds = 300 # 5 minutes @@ -37,9 +37,13 @@ async def create_next_state(state: State): next_node_identifier = node_template.next_nodes if not next_node_identifier: - raise Exception(f"Node template {state.identifier} has no next nodes") + state.status = StateStatusEnum.SUCCESS + await state.save() + return - cache_states = {} + cache_states = {} + + parents = state.parents | {state.identifier: state.id} for identifier in next_node_identifier: next_node_template = graph_template.get_node_by_identifier(identifier) @@ -72,15 +76,15 @@ async def create_next_state(state: State): raise Exception(f"Invalid input placeholder format: '{placeholder_content}' for field {field_name}") input_identifier = parts[0] - input_field = parts[2] + input_field = parts[2] - parent_id = state.parents.get(input_identifier) + parent_id = parents.get(input_identifier) if not parent_id: raise Exception(f"Parent identifier '{input_identifier}' not found in state parents.") if parent_id not in cache_states: - dependent_state = await State.get(ObjectId(parent_id)) + dependent_state = await State.get(parent_id) if not dependent_state: raise Exception(f"Dependent state {input_identifier} not found") cache_states[parent_id] = dependent_state @@ -106,10 +110,7 @@ async def create_next_state(state: State): inputs=next_node_input_data, outputs={}, error=None, - parents={ - **state.parents, - next_node_template.identifier: ObjectId(state.id) - } + parents=parents ) await new_state.save()