Skip to content
2 changes: 1 addition & 1 deletion state-manager/app/models/node_template_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ class NodeTemplate(BaseModel):
identifier: str = Field(..., description="Identifier of the node")
inputs: dict[str, Any] = Field(..., description="Inputs of the node")
next_nodes: Optional[List[str]] = Field(None, description="Next nodes to execute")
unites: Optional[List[Unites]] = Field(None, description="Unites of the node")
unites: Optional[Unites] = Field(None, description="Unites of the node")
117 changes: 70 additions & 47 deletions state-manager/app/tasks/create_next_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,24 +40,21 @@ async def mark_success_states(state_ids: list[PydanticObjectId]):


async def check_unites_satisfied(namespace: str, graph_name: str, node_template: NodeTemplate, parents: dict[str, PydanticObjectId]) -> bool:
if node_template.unites is None or len(node_template.unites) == 0:
if node_template.unites is None:
return True

for unit in node_template.unites:
unites_id = parents.get(unit.identifier)
if not unites_id:
raise ValueError(f"Unit identifier not found in parents: {unit.identifier}")
else:
pending_count = await State.find(
State.identifier == unit.identifier,
unites_id = parents.get(node_template.unites.identifier)
if not unites_id:
raise ValueError(f"Unit identifier not found in parents: {node_template.unites.identifier}")
else:
if await State.find(
State.namespace_name == namespace,
State.graph_name == graph_name,
NE(State.status, StateStatusEnum.SUCCESS),
{
f"parents.{unit.identifier}": unites_id
f"parents.{node_template.unites.identifier}": unites_id
}
).count()
if pending_count > 0:
).count() > 0:
return False
return True

Expand Down Expand Up @@ -107,6 +104,41 @@ def validate_dependencies(next_state_node_template: NodeTemplate, next_state_inp
raise AttributeError(f"Output field '{dependent.field}' not found on state '{dependent.identifier}' for template '{next_state_node_template.identifier}'")


def generate_next_state(next_state_input_model: Type[BaseModel], next_state_node_template: NodeTemplate, parents: dict[str, State], current_state: State) -> State:
next_state_input_data = {}

for field_name, _ in next_state_input_model.model_fields.items():
dependency_string = get_dependents(next_state_node_template.inputs[field_name])

for key in sorted(dependency_string.dependents.keys()):
if dependency_string.dependents[key].identifier == current_state.identifier:
if dependency_string.dependents[key].field not in current_state.outputs:
raise AttributeError(f"Output field '{dependency_string.dependents[key].field}' not found on current state '{current_state.identifier}' for template '{next_state_node_template.identifier}'")
dependency_string.dependents[key].value = current_state.outputs[dependency_string.dependents[key].field]
else:
dependency_string.dependents[key].value = parents[dependency_string.dependents[key].identifier].outputs[dependency_string.dependents[key].field]

next_state_input_data[field_name] = dependency_string.generate_string()

new_parents = {
**current_state.parents,
current_state.identifier: current_state.id
}

return State(
node_name=next_state_node_template.node_name,
identifier=next_state_node_template.identifier,
namespace_name=next_state_node_template.namespace,
graph_name=current_state.graph_name,
status=StateStatusEnum.CREATED,
parents=new_parents,
inputs=next_state_input_data,
outputs={},
run_id=current_state.run_id,
error=None
)


async def create_next_states(state_ids: list[PydanticObjectId], identifier: str, namespace: str, graph_name: str, parents_ids: dict[str, PydanticObjectId]):

try:
Expand Down Expand Up @@ -161,56 +193,47 @@ async def get_input_model(node_template: NodeTemplate) -> Type[BaseModel]:
for parent_state in parent_states:
parents[parent_state.identifier] = parent_state

pending_unites = []

for next_state_identifier in next_state_identifiers:
next_state_node_template = graph_template.get_node_by_identifier(next_state_identifier)
if not next_state_node_template:
raise ValueError(f"Next state node template not found for identifier: {next_state_identifier}")

if not await check_unites_satisfied(namespace, graph_name, next_state_node_template, parents_ids):
if next_state_node_template.unites is not None:
pending_unites.append(next_state_identifier)
continue

next_state_input_model = await get_input_model(next_state_node_template)
validate_dependencies(next_state_node_template, next_state_input_model, identifier, parents)

for current_state in current_states:
next_state_input_data = {}

for field_name, _ in next_state_input_model.model_fields.items():
dependency_string = get_dependents(next_state_node_template.inputs[field_name])

for key in sorted(dependency_string.dependents.keys()):
if dependency_string.dependents[key].identifier == identifier:
if dependency_string.dependents[key].field not in current_state.outputs:
raise AttributeError(f"Output field '{dependency_string.dependents[key].field}' not found on current state '{identifier}' for template '{next_state_node_template.identifier}'")
dependency_string.dependents[key].value = current_state.outputs[dependency_string.dependents[key].field]
else:
dependency_string.dependents[key].value = parents[dependency_string.dependents[key].identifier].outputs[dependency_string.dependents[key].field]

next_state_input_data[field_name] = dependency_string.generate_string()

new_parents = {
**parents_ids,
identifier: current_state.id
}

new_states.append(
State(
node_name=next_state_node_template.node_name,
identifier=next_state_node_template.identifier,
namespace_name=next_state_node_template.namespace,
graph_name=graph_name,
status=StateStatusEnum.CREATED,
parents=new_parents,
inputs=next_state_input_data,
outputs={},
run_id=current_state.run_id,
error=None
)
)
new_states.append(generate_next_state(next_state_input_model, next_state_node_template, parents, current_state))

await State.insert_many(new_states)
if len(new_states) > 0:
await State.insert_many(new_states)
await mark_success_states(state_ids)

# handle unites
new_unit_states = []
for pending_unites_identifier in pending_unites:
next_state_node_template = graph_template.get_node_by_identifier(pending_unites_identifier)
if not next_state_node_template:
raise ValueError(f"Next state node template not found for identifier: {pending_unites_identifier}")

if not await check_unites_satisfied(namespace, graph_name, next_state_node_template, parents_ids):
continue

next_state_input_model = await get_input_model(next_state_node_template)
validate_dependencies(next_state_node_template, next_state_input_model, identifier, parents)

assert next_state_node_template.unites is not None
parent_state = parents[next_state_node_template.unites.identifier]

new_unit_states.append(generate_next_state(next_state_input_model, next_state_node_template, parents, parent_state))

if len(new_unit_states) > 0:
await State.insert_many(new_unit_states)

except Exception as e:
await State.find(
Expand Down
8 changes: 4 additions & 4 deletions state-manager/app/tasks/verify_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,11 @@ async def verify_unites(graph_nodes: list[NodeTemplate], dependency_graph: dict
return

for node in graph_nodes:
if node.unites is None or len(node.unites) == 0:
if node.unites is None:
continue
for depend in node.unites:
if depend.identifier not in dependency_graph[node.identifier]:
errors.append(f"Node {node.identifier} depends on {depend.identifier} which is not a dependency of {node.identifier}")

if node.unites.identifier not in dependency_graph[node.identifier]:
errors.append(f"Node {node.identifier} depends on {node.unites.identifier} which is not a dependency of {node.identifier}")


async def verify_graph(graph_template: GraphTemplate):
Expand Down
55 changes: 55 additions & 0 deletions state-manager/tests/unit/models/test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pytest
from datetime import datetime
from app.models.db.base import BaseDatabaseModel


class TestBaseDatabaseModel:
"""Test cases for BaseDatabaseModel"""

def test_base_model_field_definitions(self):
"""Test that BaseDatabaseModel has the expected fields"""
# Check that the model has the expected fields
model_fields = BaseDatabaseModel.model_fields

assert 'created_at' in model_fields
assert 'updated_at' in model_fields

# Check field descriptions
assert model_fields['created_at'].description == "Date and time when the model was created"
assert model_fields['updated_at'].description == "Date and time when the model was last updated"

def test_base_model_abc_inheritance(self):
"""Test that BaseDatabaseModel is an abstract base class"""
# Should not be able to instantiate BaseDatabaseModel directly
with pytest.raises(Exception): # Could be TypeError or CollectionWasNotInitialized
BaseDatabaseModel()

def test_base_model_document_inheritance(self):
"""Test that BaseDatabaseModel inherits from Document"""
# Check that it has the expected base classes
bases = BaseDatabaseModel.__bases__
assert len(bases) >= 2 # Should have at least ABC and Document as base classes

def test_base_model_has_update_updated_at_method(self):
"""Test that BaseDatabaseModel has the update_updated_at method"""
assert hasattr(BaseDatabaseModel, 'update_updated_at')
assert callable(BaseDatabaseModel.update_updated_at)

def test_base_model_field_types(self):
"""Test that BaseDatabaseModel fields have correct types"""
model_fields = BaseDatabaseModel.model_fields

# Check that created_at and updated_at are datetime fields
created_at_field = model_fields['created_at']
updated_at_field = model_fields['updated_at']

assert created_at_field.annotation == datetime
assert updated_at_field.annotation == datetime

def test_base_model_has_before_event_decorator(self):
"""Test that BaseDatabaseModel uses the before_event decorator"""
# Check that the update_updated_at method exists and is callable
update_method = BaseDatabaseModel.update_updated_at

# The method should exist and be callable
assert callable(update_method)
107 changes: 107 additions & 0 deletions state-manager/tests/unit/models/test_graph_template_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import pytest
from unittest.mock import patch
import base64
from app.models.db.graph_template_model import GraphTemplate


class TestGraphTemplate:
"""Test cases for GraphTemplate model"""

def test_validate_secrets_valid(self):
"""Test validation of valid secrets"""
valid_secrets = {
"secret1": "valid_encrypted_string_that_is_long_enough_for_testing_32_chars",
"secret2": "another_valid_encrypted_string_that_is_long_enough_for_testing_32",
}

# Mock base64 decoding to succeed
with patch("base64.urlsafe_b64decode", return_value=b"x" * 20):
result = GraphTemplate.validate_secrets(valid_secrets)

assert result == valid_secrets

def test_validate_secrets_empty_name(self):
"""Test validation with empty secret name"""
invalid_secrets = {"": "valid_value"}

with pytest.raises(ValueError, match="Secrets cannot be empty"):
GraphTemplate.validate_secrets(invalid_secrets)

def test_validate_secrets_empty_value(self):
"""Test validation with empty secret value"""
invalid_secrets = {"secret1": ""}

with pytest.raises(ValueError, match="Secrets cannot be empty"):
GraphTemplate.validate_secrets(invalid_secrets)

def test_validate_secret_value_too_short(self):
"""Test validation of secret value that's too short"""
short_value = "short"

with pytest.raises(ValueError, match="Value appears to be too short for an encrypted string"):
GraphTemplate._validate_secret_value(short_value)

def test_validate_secret_value_invalid_base64(self):
"""Test validation of secret value with invalid base64"""
invalid_base64 = "invalid_base64_string_that_is_long_enough_but_not_valid_base64"

with pytest.raises(ValueError, match="Value is not valid URL-safe base64 encoded"):
GraphTemplate._validate_secret_value(invalid_base64)

def test_validate_secret_value_valid(self):
"""Test validation of valid secret value"""
# Create a valid base64 string that decodes to at least 12 bytes and is long enough
valid_bytes = b"x" * 20
valid_base64 = base64.urlsafe_b64encode(valid_bytes).decode()
# Pad to make it at least 32 characters
padded_base64 = valid_base64 + "x" * (32 - len(valid_base64))

# Should not raise any exception
GraphTemplate._validate_secret_value(padded_base64)

def test_validate_secrets_with_long_valid_strings(self):
"""Test validation with properly long secret values"""
long_secrets = {
"secret1": "x" * 50, # 50 characters
"secret2": "y" * 100, # 100 characters
}

# Mock base64 decoding to succeed
with patch("base64.urlsafe_b64decode", return_value=b"x" * 20):
result = GraphTemplate.validate_secrets(long_secrets)

assert result == long_secrets

def test_validate_secret_value_exactly_32_chars(self):
"""Test validation with exactly 32 character string"""
exactly_32 = "x" * 32

# Mock base64 decoding to succeed
with patch("base64.urlsafe_b64decode", return_value=b"x" * 20):
# Should not raise exception
GraphTemplate._validate_secret_value(exactly_32)

def test_validate_secret_value_31_chars_fails(self):
"""Test validation with 31 character string fails"""
exactly_31 = "x" * 31

with pytest.raises(ValueError, match="Value appears to be too short for an encrypted string"):
GraphTemplate._validate_secret_value(exactly_31)

def test_validate_secret_value_base64_decode_exception(self):
"""Test validation when base64 decoding raises exception"""
invalid_base64 = "invalid_base64_string_that_is_long_enough_but_not_valid_base64"

with pytest.raises(ValueError, match="Value is not valid URL-safe base64 encoded"):
GraphTemplate._validate_secret_value(invalid_base64)

def test_validate_secret_value_decoded_exactly_12_bytes(self):
"""Test validation with decoded value exactly 12 bytes"""
# Create a valid base64 string that decodes to exactly 12 bytes and is long enough
exactly_12_bytes = b"x" * 12
base64_string = base64.urlsafe_b64encode(exactly_12_bytes).decode()
# Pad to make it at least 32 characters
padded_base64 = base64_string + "x" * (32 - len(base64_string))

# Should not raise exception
GraphTemplate._validate_secret_value(padded_base64)
Loading
Loading