Skip to content

Commit

Permalink
fix hierarchy schema update (#4858)
Browse files Browse the repository at this point in the history
* failing integration test for hierarchy schema update

* fix hierarchy update, more tests

* add changelogs
  • Loading branch information
ajtmccarty authored Nov 6, 2024
1 parent 4372ba1 commit a81f3ea
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 54 deletions.
107 changes: 54 additions & 53 deletions backend/infrahub/core/schema/schema_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ def process_validate(self) -> None:
def process_post_validation(self) -> None:
self.cleanup_inherited_elements()
self.add_groups()
self.add_hierarchy()
self.add_hierarchy_generic()
self.add_hierarchy_node()
self.generate_weight()
self.process_labels()
self.process_dropdowns()
Expand Down Expand Up @@ -1368,7 +1369,34 @@ def add_groups(self) -> None:
if changed:
self.set(name=node_name, schema=schema)

def add_hierarchy(self) -> None:
def _get_hierarchy_child_rel(self, peer: str, hierarchical: str, read_only: bool) -> RelationshipSchema:
return RelationshipSchema(
name="children",
identifier="parent__child",
peer=peer,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.MANY,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.INBOUND,
hierarchical=hierarchical,
read_only=read_only,
)

def _get_hierarchy_parent_rel(self, peer: str, hierarchical: str, read_only: bool) -> RelationshipSchema:
return RelationshipSchema(
name="parent",
identifier="parent__child",
peer=peer,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.ONE,
max_count=1,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.OUTBOUND,
hierarchical=hierarchical,
read_only=read_only,
)

def add_hierarchy_generic(self) -> None:
for generic_name in self.generics.keys():
generic = self.get_generic(name=generic_name, duplicate=False)

Expand All @@ -1380,36 +1408,16 @@ def add_hierarchy(self) -> None:

if "parent" not in generic.relationship_names:
generic.relationships.append(
RelationshipSchema(
name="parent",
identifier="parent__child",
peer=generic_name,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.ONE,
max_count=1,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.OUTBOUND,
hierarchical=generic_name,
read_only=read_only,
)
self._get_hierarchy_parent_rel(peer=generic_name, hierarchical=generic_name, read_only=read_only)
)
if "children" not in generic.relationship_names:
generic.relationships.append(
RelationshipSchema(
name="children",
identifier="parent__child",
peer=generic_name,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.MANY,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.INBOUND,
hierarchical=generic_name,
read_only=read_only,
)
self._get_hierarchy_child_rel(peer=generic_name, hierarchical=generic_name, read_only=read_only)
)

self.set(name=generic_name, schema=generic)

def add_hierarchy_node(self) -> None:
for node_name in self.nodes.keys():
node = self.get_node(name=node_name, duplicate=False)

Expand All @@ -1419,36 +1427,29 @@ def add_hierarchy(self) -> None:
node = node.duplicate()
read_only = InfrahubKind.IPPREFIX in node.inherit_from

if node.parent and "parent" not in node.relationship_names:
node.relationships.append(
RelationshipSchema(
name="parent",
identifier="parent__child",
peer=node.parent,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.ONE,
max_count=1,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.OUTBOUND,
hierarchical=node.hierarchy,
read_only=read_only,
if node.parent:
if "parent" not in node.relationship_names:
node.relationships.append(
self._get_hierarchy_parent_rel(
peer=node.parent, hierarchical=node.hierarchy, read_only=read_only
)
)
)

if node.children and "children" not in node.relationship_names:
node.relationships.append(
RelationshipSchema(
name="children",
identifier="parent__child",
peer=node.children,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.MANY,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.INBOUND,
hierarchical=node.hierarchy,
read_only=read_only,
else:
parent_rel = node.get_relationship(name="parent")
if parent_rel.peer != node.parent:
parent_rel.peer = node.parent

if node.children:
if "children" not in node.relationship_names:
node.relationships.append(
self._get_hierarchy_child_rel(
peer=node.children, hierarchical=node.hierarchy, read_only=read_only
)
)
)
else:
children_rel = node.get_relationship(name="children")
if children_rel.peer != node.children:
children_rel.peer = node.children

self.set(name=node_name, schema=node)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import pytest
from infrahub_sdk.client import InfrahubClient

from infrahub.core import registry
from infrahub.core.branch.models import Branch
from infrahub.core.initialization import create_branch
from infrahub.core.node import Node
from infrahub.core.schema import SchemaRoot
from infrahub.core.schema.attribute_schema import AttributeSchema
from infrahub.core.schema.generic_schema import GenericSchema
from infrahub.core.schema.node_schema import NodeSchema
from infrahub.database import InfrahubDatabase
from tests.helpers.test_app import TestInfrahubApp

PERSON_KIND = "TestingPerson"
CAR_KIND = "TestingCar"
MANUFACTURER_KIND_01 = "TestingManufacturer"
MANUFACTURER_KIND_03 = "TestingCarMaker"
TAG_KIND = "TestingTag"


class TestSchemaLifecycleBase(TestInfrahubApp):
@pytest.fixture(scope="class")
def schema_location_generic(self) -> GenericSchema:
return GenericSchema(
name="Generic",
namespace="Location",
hierarchical=True,
attributes=[AttributeSchema(name="name", kind="Text", unique=True)],
)

@pytest.fixture(scope="class")
def schema_location_country(self) -> NodeSchema:
return NodeSchema(
name="Country", namespace="Location", inherit_from=["LocationGeneric"], children="LocationSite", parent=""
)

@pytest.fixture(scope="class")
def schema_location_site(self) -> NodeSchema:
return NodeSchema(
name="Site", namespace="Location", inherit_from=["LocationGeneric"], children="", parent="LocationCountry"
)

@pytest.fixture(scope="class")
async def location_schema_01(
self,
schema_location_generic: GenericSchema,
schema_location_country: NodeSchema,
schema_location_site: NodeSchema,
) -> SchemaRoot:
return SchemaRoot(
version="1.0", generics=[schema_location_generic], nodes=[schema_location_country, schema_location_site]
)

@pytest.fixture(scope="class")
def schema_location_country_02(self) -> NodeSchema:
return NodeSchema(
name="Country",
namespace="Location",
inherit_from=["LocationGeneric"],
children="LocationMetro",
parent=None,
)

@pytest.fixture(scope="class")
def schema_location_metro_02(self) -> NodeSchema:
return NodeSchema(
name="Metro",
namespace="Location",
inherit_from=["LocationGeneric"],
children="LocationSite",
parent="LocationCountry",
)

@pytest.fixture(scope="class")
def schema_location_site_02(self) -> NodeSchema:
return NodeSchema(
name="Site", namespace="Location", inherit_from=["LocationGeneric"], children=None, parent="LocationMetro"
)

@pytest.fixture(scope="class")
async def location_schema_02(
self,
schema_location_generic: GenericSchema,
schema_location_country_02: NodeSchema,
schema_location_site_02: NodeSchema,
schema_location_metro_02: NodeSchema,
) -> SchemaRoot:
return SchemaRoot(
version="1.0",
generics=[schema_location_generic],
nodes=[schema_location_country_02, schema_location_metro_02, schema_location_site_02],
)

@pytest.fixture(scope="class")
async def initial_schema(
self, db: InfrahubDatabase, initialize_registry, default_branch: Branch, location_schema_01: SchemaRoot
) -> None:
branch_schema = registry.schema.get_schema_branch(name=default_branch.name)
tmp_schema = branch_schema.duplicate()
tmp_schema.load_schema(schema=location_schema_01)
tmp_schema.process()

await registry.schema.update_schema_branch(schema=tmp_schema, db=db, branch=default_branch.name, update_db=True)

@pytest.fixture(scope="class")
async def branch_1(self, db: InfrahubDatabase) -> Branch:
return await create_branch(db=db, branch_name="branch_1")

async def test_baseline(
self, db: InfrahubDatabase, client: InfrahubClient, initial_schema: dict[str, Node]
) -> None:
country_schema = await client.schema.get(kind="LocationCountry")
rels_by_name = {r.name: r for r in country_schema.relationships}
assert rels_by_name["parent"].peer == "LocationGeneric"
assert rels_by_name["children"].peer == "LocationSite"
site_schema = await client.schema.get(kind="LocationSite")
rels_by_name = {r.name: r for r in site_schema.relationships}
assert rels_by_name["parent"].peer == "LocationCountry"
assert rels_by_name["children"].peer == "LocationGeneric"

async def test_check_schema_02(self, client: InfrahubClient, branch_1: Branch, location_schema_02: SchemaRoot):
success, response = await client.schema.check(
schemas=[location_schema_02.model_dump(mode="json")], branch=branch_1.name
)
assert success
assert response == {
"diff": {
"added": {"LocationMetro": {"added": {}, "changed": {}, "removed": {}}},
"removed": {},
"changed": {
"LocationSite": {
"added": {},
"removed": {},
"changed": {
"parent": None,
"relationships": {
"added": {},
"removed": {},
"changed": {"parent": {"added": {}, "removed": {}, "changed": {"peer": None}}},
},
},
},
"LocationCountry": {
"added": {},
"removed": {},
"changed": {
"children": None,
"relationships": {
"added": {},
"removed": {},
"changed": {"children": {"added": {}, "removed": {}, "changed": {"peer": None}}},
},
},
},
"LocationGeneric": {"added": {}, "changed": {"used_by": None}, "removed": {}},
},
},
}

async def test_load_schema_02(
self, db: InfrahubDatabase, client: InfrahubClient, branch_1: Branch, location_schema_02: SchemaRoot
):
response = await client.schema.load(schemas=[location_schema_02.model_dump(mode="json")], branch=branch_1.name)
assert not response.errors

country_schema = await client.schema.get(kind="LocationCountry", branch=branch_1.name)
rels_by_name = {r.name: r for r in country_schema.relationships}
assert rels_by_name["parent"].peer == "LocationGeneric"
assert rels_by_name["children"].peer == "LocationMetro"
metro_schema = await client.schema.get(kind="LocationMetro", branch=branch_1.name)
rels_by_name = {r.name: r for r in metro_schema.relationships}
assert rels_by_name["parent"].peer == "LocationCountry"
assert rels_by_name["children"].peer == "LocationSite"
site_schema = await client.schema.get(kind="LocationSite", branch=branch_1.name)
rels_by_name = {r.name: r for r in site_schema.relationships}
assert rels_by_name["parent"].peer == "LocationMetro"
assert rels_by_name["children"].peer == "LocationGeneric"

country_schema = db.schema.get(name="LocationCountry", branch=branch_1, duplicate=False)
assert country_schema.parent == "" # noqa: PLC1901
assert country_schema.children == "LocationMetro"
metro_schema = db.schema.get(name="LocationMetro", branch=branch_1, duplicate=False)
assert metro_schema.parent == "LocationCountry"
assert metro_schema.children == "LocationSite"
site_schema = db.schema.get(name="LocationSite", branch=branch_1, duplicate=False)
assert site_schema.parent == "LocationMetro"
assert site_schema.children == "" # noqa: PLC1901
3 changes: 2 additions & 1 deletion backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2011,7 +2011,7 @@ async def fruit_tag_schema_global(db: InfrahubDatabase, group_schema, data_schem


@pytest.fixture
async def hierarchical_location_schema_simple(db: InfrahubDatabase, default_branch: Branch) -> None:
async def hierarchical_location_schema_simple(db: InfrahubDatabase, default_branch: Branch) -> SchemaRoot:
SCHEMA: dict[str, Any] = {
"generics": [
{
Expand Down Expand Up @@ -2072,6 +2072,7 @@ async def hierarchical_location_schema_simple(db: InfrahubDatabase, default_bran

schema = SchemaRoot(**SCHEMA)
registry.schema.register_schema(schema=schema, branch=default_branch.name)
return schema


@pytest.fixture
Expand Down
Loading

0 comments on commit a81f3ea

Please sign in to comment.