Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix hierarchy schema update #4858

Merged
merged 3 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 54 additions & 53 deletions backend/infrahub/core/schema/schema_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,8 @@ def process_validate(self) -> None:
def process_post_validation(self) -> None:
self.cleanup_inherited_elements()
self.add_groups()
self.add_hierarchy()
self.add_hierarchy_generic()
self.add_hierarchy_node()
self.generate_weight()
self.process_labels()
self.process_dropdowns()
Expand Down Expand Up @@ -1368,7 +1369,34 @@ def add_groups(self) -> None:
if changed:
self.set(name=node_name, schema=schema)

def add_hierarchy(self) -> None:
def _get_hierarchy_child_rel(self, peer: str, hierarchical: str, read_only: bool) -> RelationshipSchema:
return RelationshipSchema(
name="children",
identifier="parent__child",
peer=peer,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.MANY,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.INBOUND,
hierarchical=hierarchical,
read_only=read_only,
)

def _get_hierarchy_parent_rel(self, peer: str, hierarchical: str, read_only: bool) -> RelationshipSchema:
return RelationshipSchema(
name="parent",
identifier="parent__child",
peer=peer,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.ONE,
max_count=1,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.OUTBOUND,
hierarchical=hierarchical,
read_only=read_only,
)

def add_hierarchy_generic(self) -> None:
for generic_name in self.generics.keys():
generic = self.get_generic(name=generic_name, duplicate=False)

Expand All @@ -1380,36 +1408,16 @@ def add_hierarchy(self) -> None:

if "parent" not in generic.relationship_names:
generic.relationships.append(
RelationshipSchema(
name="parent",
identifier="parent__child",
peer=generic_name,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.ONE,
max_count=1,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.OUTBOUND,
hierarchical=generic_name,
read_only=read_only,
)
self._get_hierarchy_parent_rel(peer=generic_name, hierarchical=generic_name, read_only=read_only)
)
if "children" not in generic.relationship_names:
generic.relationships.append(
RelationshipSchema(
name="children",
identifier="parent__child",
peer=generic_name,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.MANY,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.INBOUND,
hierarchical=generic_name,
read_only=read_only,
)
self._get_hierarchy_child_rel(peer=generic_name, hierarchical=generic_name, read_only=read_only)
)

self.set(name=generic_name, schema=generic)

def add_hierarchy_node(self) -> None:
for node_name in self.nodes.keys():
node = self.get_node(name=node_name, duplicate=False)

Expand All @@ -1419,36 +1427,29 @@ def add_hierarchy(self) -> None:
node = node.duplicate()
read_only = InfrahubKind.IPPREFIX in node.inherit_from

if node.parent and "parent" not in node.relationship_names:
node.relationships.append(
RelationshipSchema(
name="parent",
identifier="parent__child",
peer=node.parent,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.ONE,
max_count=1,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.OUTBOUND,
hierarchical=node.hierarchy,
read_only=read_only,
if node.parent:
if "parent" not in node.relationship_names:
node.relationships.append(
self._get_hierarchy_parent_rel(
peer=node.parent, hierarchical=node.hierarchy, read_only=read_only
)
)
)

if node.children and "children" not in node.relationship_names:
node.relationships.append(
RelationshipSchema(
name="children",
identifier="parent__child",
peer=node.children,
kind=RelationshipKind.HIERARCHY,
cardinality=RelationshipCardinality.MANY,
branch=BranchSupportType.AWARE,
direction=RelationshipDirection.INBOUND,
hierarchical=node.hierarchy,
read_only=read_only,
else:
parent_rel = node.get_relationship(name="parent")
if parent_rel.peer != node.parent:
parent_rel.peer = node.parent

if node.children:
if "children" not in node.relationship_names:
node.relationships.append(
self._get_hierarchy_child_rel(
peer=node.children, hierarchical=node.hierarchy, read_only=read_only
)
)
)
else:
children_rel = node.get_relationship(name="children")
if children_rel.peer != node.children:
children_rel.peer = node.children

self.set(name=node_name, schema=node)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import pytest
from infrahub_sdk.client import InfrahubClient

from infrahub.core import registry
from infrahub.core.branch.models import Branch
from infrahub.core.initialization import create_branch
from infrahub.core.node import Node
from infrahub.core.schema import SchemaRoot
from infrahub.core.schema.attribute_schema import AttributeSchema
from infrahub.core.schema.generic_schema import GenericSchema
from infrahub.core.schema.node_schema import NodeSchema
from infrahub.database import InfrahubDatabase
from tests.helpers.test_app import TestInfrahubApp

PERSON_KIND = "TestingPerson"
CAR_KIND = "TestingCar"
MANUFACTURER_KIND_01 = "TestingManufacturer"
MANUFACTURER_KIND_03 = "TestingCarMaker"
TAG_KIND = "TestingTag"


class TestSchemaLifecycleBase(TestInfrahubApp):
@pytest.fixture(scope="class")
def schema_location_generic(self) -> GenericSchema:
return GenericSchema(
name="Generic",
namespace="Location",
hierarchical=True,
attributes=[AttributeSchema(name="name", kind="Text", unique=True)],
)

@pytest.fixture(scope="class")
def schema_location_country(self) -> NodeSchema:
return NodeSchema(
name="Country", namespace="Location", inherit_from=["LocationGeneric"], children="LocationSite", parent=""
)

@pytest.fixture(scope="class")
def schema_location_site(self) -> NodeSchema:
return NodeSchema(
name="Site", namespace="Location", inherit_from=["LocationGeneric"], children="", parent="LocationCountry"
)

@pytest.fixture(scope="class")
async def location_schema_01(
self,
schema_location_generic: GenericSchema,
schema_location_country: NodeSchema,
schema_location_site: NodeSchema,
) -> SchemaRoot:
return SchemaRoot(
version="1.0", generics=[schema_location_generic], nodes=[schema_location_country, schema_location_site]
)

@pytest.fixture(scope="class")
def schema_location_country_02(self) -> NodeSchema:
return NodeSchema(
name="Country",
namespace="Location",
inherit_from=["LocationGeneric"],
children="LocationMetro",
parent=None,
)

@pytest.fixture(scope="class")
def schema_location_metro_02(self) -> NodeSchema:
return NodeSchema(
name="Metro",
namespace="Location",
inherit_from=["LocationGeneric"],
children="LocationSite",
parent="LocationCountry",
)

@pytest.fixture(scope="class")
def schema_location_site_02(self) -> NodeSchema:
return NodeSchema(
name="Site", namespace="Location", inherit_from=["LocationGeneric"], children=None, parent="LocationMetro"
)

@pytest.fixture(scope="class")
async def location_schema_02(
self,
schema_location_generic: GenericSchema,
schema_location_country_02: NodeSchema,
schema_location_site_02: NodeSchema,
schema_location_metro_02: NodeSchema,
) -> SchemaRoot:
return SchemaRoot(
version="1.0",
generics=[schema_location_generic],
nodes=[schema_location_country_02, schema_location_metro_02, schema_location_site_02],
)

@pytest.fixture(scope="class")
async def initial_schema(
self, db: InfrahubDatabase, initialize_registry, default_branch: Branch, location_schema_01: SchemaRoot
) -> None:
branch_schema = registry.schema.get_schema_branch(name=default_branch.name)
tmp_schema = branch_schema.duplicate()
tmp_schema.load_schema(schema=location_schema_01)
tmp_schema.process()

await registry.schema.update_schema_branch(schema=tmp_schema, db=db, branch=default_branch.name, update_db=True)

@pytest.fixture(scope="class")
async def branch_1(self, db: InfrahubDatabase) -> Branch:
return await create_branch(db=db, branch_name="branch_1")

async def test_baseline(
self, db: InfrahubDatabase, client: InfrahubClient, initial_schema: dict[str, Node]
) -> None:
country_schema = await client.schema.get(kind="LocationCountry")
rels_by_name = {r.name: r for r in country_schema.relationships}
assert rels_by_name["parent"].peer == "LocationGeneric"
assert rels_by_name["children"].peer == "LocationSite"
site_schema = await client.schema.get(kind="LocationSite")
rels_by_name = {r.name: r for r in site_schema.relationships}
assert rels_by_name["parent"].peer == "LocationCountry"
assert rels_by_name["children"].peer == "LocationGeneric"

async def test_check_schema_02(self, client: InfrahubClient, branch_1: Branch, location_schema_02: SchemaRoot):
success, response = await client.schema.check(
schemas=[location_schema_02.model_dump(mode="json")], branch=branch_1.name
)
assert success
assert response == {
"diff": {
"added": {"LocationMetro": {"added": {}, "changed": {}, "removed": {}}},
"removed": {},
"changed": {
"LocationSite": {
"added": {},
"removed": {},
"changed": {
"parent": None,
"relationships": {
"added": {},
"removed": {},
"changed": {"parent": {"added": {}, "removed": {}, "changed": {"peer": None}}},
},
},
},
"LocationCountry": {
"added": {},
"removed": {},
"changed": {
"children": None,
"relationships": {
"added": {},
"removed": {},
"changed": {"children": {"added": {}, "removed": {}, "changed": {"peer": None}}},
},
},
},
"LocationGeneric": {"added": {}, "changed": {"used_by": None}, "removed": {}},
},
},
}

async def test_load_schema_02(
self, db: InfrahubDatabase, client: InfrahubClient, branch_1: Branch, location_schema_02: SchemaRoot
):
response = await client.schema.load(schemas=[location_schema_02.model_dump(mode="json")], branch=branch_1.name)
assert not response.errors

country_schema = await client.schema.get(kind="LocationCountry", branch=branch_1.name)
rels_by_name = {r.name: r for r in country_schema.relationships}
assert rels_by_name["parent"].peer == "LocationGeneric"
assert rels_by_name["children"].peer == "LocationMetro"
metro_schema = await client.schema.get(kind="LocationMetro", branch=branch_1.name)
rels_by_name = {r.name: r for r in metro_schema.relationships}
assert rels_by_name["parent"].peer == "LocationCountry"
assert rels_by_name["children"].peer == "LocationSite"
site_schema = await client.schema.get(kind="LocationSite", branch=branch_1.name)
rels_by_name = {r.name: r for r in site_schema.relationships}
assert rels_by_name["parent"].peer == "LocationMetro"
assert rels_by_name["children"].peer == "LocationGeneric"

country_schema = db.schema.get(name="LocationCountry", branch=branch_1, duplicate=False)
assert country_schema.parent == "" # noqa: PLC1901
assert country_schema.children == "LocationMetro"
metro_schema = db.schema.get(name="LocationMetro", branch=branch_1, duplicate=False)
assert metro_schema.parent == "LocationCountry"
assert metro_schema.children == "LocationSite"
site_schema = db.schema.get(name="LocationSite", branch=branch_1, duplicate=False)
assert site_schema.parent == "LocationMetro"
assert site_schema.children == "" # noqa: PLC1901
3 changes: 2 additions & 1 deletion backend/tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2011,7 +2011,7 @@ async def fruit_tag_schema_global(db: InfrahubDatabase, group_schema, data_schem


@pytest.fixture
async def hierarchical_location_schema_simple(db: InfrahubDatabase, default_branch: Branch) -> None:
async def hierarchical_location_schema_simple(db: InfrahubDatabase, default_branch: Branch) -> SchemaRoot:
SCHEMA: dict[str, Any] = {
"generics": [
{
Expand Down Expand Up @@ -2072,6 +2072,7 @@ async def hierarchical_location_schema_simple(db: InfrahubDatabase, default_bran

schema = SchemaRoot(**SCHEMA)
registry.schema.register_schema(schema=schema, branch=default_branch.name)
return schema


@pytest.fixture
Expand Down
Loading
Loading