Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion state-manager/app/models/node_template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
from typing import Any, Optional, List


class Unites(BaseModel):
identifier: str = Field(..., description="Identifier of the node")


class NodeTemplate(BaseModel):
node_name: str = Field(..., description="Name of the node")
namespace: str = Field(..., description="Namespace of the node")
identifier: str = Field(..., description="Identifier of the node")
inputs: dict[str, Any] = Field(..., description="Inputs of the node")
next_nodes: Optional[List[str]] = Field(None, description="Next nodes to execute")
next_nodes: Optional[List[str]] = Field(None, description="Next nodes to execute")
unites: Optional[List[Unites]] = Field(None, description="Unites of the node")
30 changes: 30 additions & 0 deletions state-manager/app/tasks/create_next_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@
from app.models.graph_template_validation_status import GraphTemplateValidationStatus
from app.models.db.registered_node import RegisteredNode
from app.models.state_status_enum import StateStatusEnum
from beanie.operators import NE
from app.singletons.logs_manager import LogsManager

from json_schema_to_pydantic import create_model

logger = LogsManager().get_logger()

async def create_next_state(state: State):
graph_template = None

Expand Down Expand Up @@ -50,6 +54,32 @@ async def create_next_state(state: State):
if not next_node_template:
continue

depends_satisfied = True
if next_node_template.unites is not None and len(next_node_template.unites) > 0:
pending_count = 0
for depend in next_node_template.unites:
if depend.identifier == state.identifier:
continue
else:
root_parent = state.parents.get(depend.identifier)
if root_parent is None:
raise Exception(f"Root parent of {depend.identifier} not found")

pending_count = await State.find(
State.identifier == depend.identifier,
State.namespace_name == state.namespace_name,
State.graph_name == state.graph_name,
NE(State.status, StateStatusEnum.SUCCESS),
{f"parents.{depend.identifier}": parents[depend.identifier]}
).count()
if pending_count > 0:
logger.info(f"Node {next_node_template.identifier} depends on {depend.identifier} but it is not satisfied")
depends_satisfied = False
break

if not depends_satisfied:
continue

registered_node = await RegisteredNode.find_one(RegisteredNode.name == next_node_template.node_name, RegisteredNode.namespace == next_node_template.namespace)

if not registered_node:
Expand Down
17 changes: 16 additions & 1 deletion state-manager/app/tasks/verify_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,18 @@ def dfs_visit(current_node: str, parent_node: str | None = None, current_path: l

return dependency_graph

async def verify_unites(graph_nodes: list[NodeTemplate], dependency_graph: dict | None, errors: list[str]):
if dependency_graph is None:
return

for node in graph_nodes:
if node.unites is None or len(node.unites) == 0:
continue
for depend in node.unites:
if depend.identifier not in dependency_graph[node.identifier]:
errors.append(f"Node {node.identifier} depends on {depend.identifier} which is not a dependency of {node.identifier}")


async def verify_graph(graph_template: GraphTemplate):
try:
errors = []
Expand All @@ -229,16 +241,19 @@ async def verify_graph(graph_template: GraphTemplate):
await verify_node_exists(graph_template.nodes, database_nodes, errors)
await verify_node_identifiers(graph_template.nodes, errors)
await verify_secrets(graph_template, database_nodes, errors)
dependency_graph = await verify_topology(graph_template.nodes, errors)
dependency_graph = await verify_topology(graph_template.nodes, errors)

if dependency_graph is not None and len(errors) == 0:
await verify_inputs(graph_template.nodes, database_nodes, dependency_graph, errors)

await verify_unites(graph_template.nodes, dependency_graph, errors)

if errors or dependency_graph is None:
graph_template.validation_status = GraphTemplateValidationStatus.INVALID
graph_template.validation_errors = errors
await graph_template.save()
return

graph_template.validation_status = GraphTemplateValidationStatus.VALID
graph_template.validation_errors = None
await graph_template.save()
Expand Down