Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IFC-587 track conflicts with an attribute instead of by ID #4223

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 19 additions & 21 deletions backend/infrahub/core/diff/data_check_synchronizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from enum import Enum
from typing import TYPE_CHECKING

from infrahub.core.constants import BranchConflictKeep, InfrahubKind, ProposedChangeState
from infrahub.core.integrity.object_conflict.conflict_recorder import ObjectConflictValidatorRecorder
Expand All @@ -10,9 +9,6 @@
from .conflicts_extractor import DiffConflictsExtractor
from .model.path import ConflictSelection, EnrichedDiffConflict, EnrichedDiffRoot

if TYPE_CHECKING:
from infrahub.core.protocols import CoreProposedChange


class DiffDataCheckSynchronizer:
def __init__(
Expand All @@ -33,26 +29,28 @@ async def synchronize(self, enriched_diff: EnrichedDiffRoot) -> list[Node]:
)
if not proposed_changes:
return []
proposed_change: CoreProposedChange = proposed_changes[0]
enriched_conflicts = enriched_diff.get_all_conflicts()
data_conflicts = await self.conflicts_extractor.get_data_conflicts(enriched_diff_root=enriched_diff)
core_data_checks = await self.conflict_recorder.record_conflicts(
proposed_change_id=proposed_change.get_id(), conflicts=data_conflicts
)
core_data_checks_by_id = {cdc.get_id(): cdc for cdc in core_data_checks}
enriched_conflicts_by_id = {ec.uuid: ec for ec in enriched_conflicts}
for conflict_id, core_data_check in core_data_checks_by_id.items():
enriched_conflict = enriched_conflicts_by_id.get(conflict_id)
if not enriched_conflict:
continue
expected_keep_branch = self._get_keep_branch_for_enriched_conflict(enriched_conflict=enriched_conflict)
expected_keep_branch_value = (
expected_keep_branch.value if isinstance(expected_keep_branch, Enum) else expected_keep_branch
all_data_checks = []
for pc in proposed_changes:
core_data_checks = await self.conflict_recorder.record_conflicts(
proposed_change_id=pc.get_id(), conflicts=data_conflicts
)
if core_data_check.keep_branch.value != expected_keep_branch_value: # type: ignore[attr-defined]
core_data_check.keep_branch.value = expected_keep_branch_value # type: ignore[attr-defined]
await core_data_check.save(db=self.db)
return core_data_checks
all_data_checks.extend(core_data_checks)
core_data_checks_by_id = {cdc.enriched_conflict_id.value: cdc for cdc in core_data_checks} # type: ignore[attr-defined]
enriched_conflicts_by_id = {ec.uuid: ec for ec in enriched_conflicts}
for conflict_id, core_data_check in core_data_checks_by_id.items():
enriched_conflict = enriched_conflicts_by_id.get(conflict_id)
if not enriched_conflict:
continue
expected_keep_branch = self._get_keep_branch_for_enriched_conflict(enriched_conflict=enriched_conflict)
expected_keep_branch_value = (
expected_keep_branch.value if isinstance(expected_keep_branch, Enum) else expected_keep_branch
)
if core_data_check.keep_branch.value != expected_keep_branch_value: # type: ignore[attr-defined]
core_data_check.keep_branch.value = expected_keep_branch_value # type: ignore[attr-defined]
await core_data_check.save(db=self.db)
return all_data_checks

def _get_keep_branch_for_enriched_conflict(
self, enriched_conflict: EnrichedDiffConflict
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ async def record_conflicts(self, proposed_change_id: str, conflicts: Sequence[Ob
await self.initialize_validator(validator)

previous_checks = await validator.checks.get_peers(db=self.db) # type: ignore[attr-defined]
previous_checks_by_conflict_id = {check.enriched_conflict_id.value: check for check in previous_checks.values()}
is_success = False

current_checks: list[Node] = []
Expand Down Expand Up @@ -56,9 +57,9 @@ async def record_conflicts(self, proposed_change_id: str, conflicts: Sequence[Ob
for conflict in conflicts:
conflicts_data = [conflict.to_conflict_dict()]
conflict_obj = None
if conflict.conflict_id and conflict.conflict_id in previous_checks:
conflict_obj = previous_checks[conflict.conflict_id]
check_ids_to_keep.add(conflict.conflict_id)
if conflict.conflict_id and conflict.conflict_id in previous_checks_by_conflict_id:
conflict_obj = previous_checks_by_conflict_id[conflict.conflict_id]
check_ids_to_keep.add(conflict_obj.get_id())
if not conflict_obj:
for previous_check in previous_checks.values():
if previous_check.conflicts.value == conflicts_data: # type: ignore[attr-defined]
Expand All @@ -70,7 +71,7 @@ async def record_conflicts(self, proposed_change_id: str, conflicts: Sequence[Ob

await conflict_obj.new(
db=self.db,
id=conflict.conflict_id,
enriched_conflict_id=conflict.conflict_id,
label=conflict.label,
origin="internal",
kind="DataIntegrity",
Expand Down
2 changes: 2 additions & 0 deletions backend/infrahub/core/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ class CoreCustomWebhook(CoreWebhook, CoreTaskTarget):
class CoreDataCheck(CoreCheck):
conflicts: JSONAttribute
keep_branch: Enum
enriched_conflict_id: StringOptional


class CoreDataValidator(CoreValidator):
Expand Down Expand Up @@ -389,6 +390,7 @@ class CoreRepositoryValidator(CoreValidator):

class CoreSchemaCheck(CoreCheck):
conflicts: JSONAttribute
enriched_conflict_id: StringOptional


class CoreSchemaValidator(CoreValidator):
Expand Down
2 changes: 2 additions & 0 deletions backend/infrahub/core/schema/definitions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1317,6 +1317,7 @@
"attributes": [
{"name": "conflicts", "kind": "JSON"},
{"name": "keep_branch", "enum": BranchConflictKeep.available_types(), "kind": "Text", "optional": True},
{"name": "enriched_conflict_id", "kind": "Text", "optional": True},
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this requires a migration b/c CoreDataChecks are not long-lived objects and the new field is only used with the new diff

],
},
{
Expand All @@ -1342,6 +1343,7 @@
"branch": BranchSupportType.AGNOSTIC.value,
"attributes": [
{"name": "conflicts", "kind": "JSON"},
{"name": "enriched_conflict_id", "kind": "Text", "optional": True},
],
},
{
Expand Down
11 changes: 7 additions & 4 deletions backend/infrahub/graphql/mutations/diff_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,18 @@ async def mutate(
selection = ConflictSelection(data.selected_branch.value) if data.selected_branch else None
await diff_repo.update_conflict_by_id(conflict_id=data.conflict_id, selection=selection)

core_data_check = await NodeManager.get_one(db=context.db, id=data.conflict_id, kind=InfrahubKind.DATACHECK)
if not core_data_check:
core_data_checks = await NodeManager.query(
db=context.db, schema=InfrahubKind.DATACHECK, filters={"enriched_conflict_id__value": data.conflict_id}
)
if not core_data_checks:
return cls(ok=True)
if data.selected_branch is GraphQlConflictSelection.BASE_BRANCH:
keep_branch = BranchConflictKeep.TARGET
elif data.selected_branch is GraphQlConflictSelection.DIFF_BRANCH:
keep_branch = BranchConflictKeep.SOURCE
else:
keep_branch = None
core_data_check.keep_branch.value = keep_branch
await core_data_check.save(db=context.db)
for cdc in core_data_checks:
cdc.keep_branch.value = keep_branch
await cdc.save(db=context.db)
return cls(ok=True)
104 changes: 82 additions & 22 deletions backend/tests/integration/diff/test_diff_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,9 @@ async def get_branch_diff(db: InfrahubDatabase, branch: Branch) -> EnrichedDiffR
)

async def _get_proposed_change_and_data_validator(self, db) -> tuple[Node, Node]:
pcs = await NodeManager.query(db=db, schema=InfrahubKind.PROPOSEDCHANGE, filters={"source_branch": BRANCH_NAME})
pcs = await NodeManager.query(
db=db, schema=InfrahubKind.PROPOSEDCHANGE, filters={"name__value": PROPOSED_CHANGE_NAME}
)
assert len(pcs) == 1
pc = pcs[0]
validators = await pc.validations.get_peers(db=db)
Expand Down Expand Up @@ -767,7 +769,8 @@ async def test_create_proposed_change_data_checks_created(

_, data_validator = await self._get_proposed_change_and_data_validator(db=db)
core_data_checks = await data_validator.checks.get_peers(db=db) # type: ignore[attr-defined]
assert set(core_data_checks.keys()) == {
data_checks_by_conflict_id = {cdc.enriched_conflict_id.value: cdc for cdc in core_data_checks.values()}
assert set(data_checks_by_conflict_id.keys()) == {
attribute_value_conflict.conflict_id,
peer_conflict.conflict_id,
cardinality_one_property_conflict_a.conflict_id,
Expand All @@ -776,13 +779,21 @@ async def test_create_proposed_change_data_checks_created(
node_removed_conflict.conflict_id,
node_removed_attribute_value_conflict.conflict_id,
}
attr_value_data_check = core_data_checks[attribute_value_conflict.conflict_id]
peer_data_check = core_data_checks[peer_conflict.conflict_id]
cardinality_one_property_data_check_a = core_data_checks[cardinality_one_property_conflict_a.conflict_id]
cardinality_one_property_data_check_b = core_data_checks[cardinality_one_property_conflict_b.conflict_id]
cardinality_many_property_data_check = core_data_checks[cardinality_many_property_conflict.conflict_id]
node_removed_data_check = core_data_checks[node_removed_conflict.conflict_id]
node_removed_attr_value_data_check = core_data_checks[node_removed_attribute_value_conflict.conflict_id]
attr_value_data_check = data_checks_by_conflict_id[attribute_value_conflict.conflict_id]
peer_data_check = data_checks_by_conflict_id[peer_conflict.conflict_id]
cardinality_one_property_data_check_a = data_checks_by_conflict_id[
cardinality_one_property_conflict_a.conflict_id
]
cardinality_one_property_data_check_b = data_checks_by_conflict_id[
cardinality_one_property_conflict_b.conflict_id
]
cardinality_many_property_data_check = data_checks_by_conflict_id[
cardinality_many_property_conflict.conflict_id
]
node_removed_data_check = data_checks_by_conflict_id[node_removed_conflict.conflict_id]
node_removed_attr_value_data_check = data_checks_by_conflict_id[
node_removed_attribute_value_conflict.conflict_id
]
assert attr_value_data_check.keep_branch.value.value == attribute_value_conflict.keep_branch.value
assert peer_data_check.keep_branch.value is None
assert cardinality_one_property_data_check_a.keep_branch.value is None
Expand Down Expand Up @@ -835,8 +846,9 @@ async def test_resolve_peer_conflict(
# check CoreDataChecks
_, data_validator = await self._get_proposed_change_and_data_validator(db=db)
core_data_checks = await data_validator.checks.get_peers(db=db) # type: ignore[attr-defined]
assert peer_conflict.conflict_id in core_data_checks
peer_data_check = core_data_checks[peer_conflict.conflict_id]
data_checks_by_conflict_id = {cdc.enriched_conflict_id.value: cdc for cdc in core_data_checks.values()}
assert peer_conflict.conflict_id in data_checks_by_conflict_id
peer_data_check = data_checks_by_conflict_id[peer_conflict.conflict_id]
assert peer_data_check.keep_branch.value.value is peer_conflict.keep_branch.value

async def test_resolve_peer_property_conflict(
Expand Down Expand Up @@ -927,11 +939,12 @@ async def test_resolve_peer_property_conflict(
# check CoreDataChecks
_, data_validator = await self._get_proposed_change_and_data_validator(db=db)
core_data_checks = await data_validator.checks.get_peers(db=db) # type: ignore[attr-defined]
assert cardinality_one_property_conflict_a.conflict_id in core_data_checks
assert cardinality_one_property_conflict_b.conflict_id in core_data_checks
data_check_a = core_data_checks[cardinality_one_property_conflict_a.conflict_id]
data_checks_by_conflict_id = {cdc.enriched_conflict_id.value: cdc for cdc in core_data_checks.values()}
assert cardinality_one_property_conflict_a.conflict_id in data_checks_by_conflict_id
assert cardinality_one_property_conflict_b.conflict_id in data_checks_by_conflict_id
data_check_a = data_checks_by_conflict_id[cardinality_one_property_conflict_a.conflict_id]
assert data_check_a.keep_branch.value.value == cardinality_one_property_conflict_a.keep_branch.value
data_check_b = core_data_checks[cardinality_one_property_conflict_b.conflict_id]
data_check_b = data_checks_by_conflict_id[cardinality_one_property_conflict_b.conflict_id]
assert data_check_b.keep_branch.value.value == cardinality_one_property_conflict_b.keep_branch.value

async def test_resolve_cardinality_many_property_conflict(
Expand Down Expand Up @@ -987,8 +1000,9 @@ async def test_resolve_cardinality_many_property_conflict(
# check CoreDataChecks
_, data_validator = await self._get_proposed_change_and_data_validator(db=db)
core_data_checks = await data_validator.checks.get_peers(db=db) # type: ignore[attr-defined]
assert cardinality_many_property_conflict.conflict_id in core_data_checks
peer_data_check = core_data_checks[cardinality_many_property_conflict.conflict_id]
data_checks_by_conflict_id = {cdc.enriched_conflict_id.value: cdc for cdc in core_data_checks.values()}
assert cardinality_many_property_conflict.conflict_id in data_checks_by_conflict_id
peer_data_check = data_checks_by_conflict_id[cardinality_many_property_conflict.conflict_id]
assert peer_data_check.keep_branch.value.value is cardinality_many_property_conflict.keep_branch.value

async def test_merge_fails_with_conflicts(
Expand Down Expand Up @@ -1059,12 +1073,15 @@ async def test_diff_resolve_node_removed_conflicts(
# check CoreDataChecks
_, data_validator = await self._get_proposed_change_and_data_validator(db=db)
core_data_checks = await data_validator.checks.get_peers(db=db) # type: ignore[attr-defined]
data_checks_by_conflict_id = {cdc.enriched_conflict_id.value: cdc for cdc in core_data_checks.values()}
assert {
node_removed_conflict.conflict_id,
node_removed_attribute_value_conflict.conflict_id,
} <= set(core_data_checks.keys())
node_removed_data_check = core_data_checks[node_removed_conflict.conflict_id]
node_removed_attr_value_data_check = core_data_checks[node_removed_attribute_value_conflict.conflict_id]
} <= set(data_checks_by_conflict_id.keys())
node_removed_data_check = data_checks_by_conflict_id[node_removed_conflict.conflict_id]
node_removed_attr_value_data_check = data_checks_by_conflict_id[
node_removed_attribute_value_conflict.conflict_id
]
assert node_removed_data_check.keep_branch.value.value == node_removed_conflict.keep_branch.value
assert (
node_removed_attr_value_data_check.keep_branch.value.value
Expand Down Expand Up @@ -1094,11 +1111,54 @@ async def test_expected_core_data_checks(
# check CoreDataChecks
_, data_validator = await self._get_proposed_change_and_data_validator(db=db)
core_data_checks = await data_validator.checks.get_peers(db=db) # type: ignore[attr-defined]
assert set(core_data_checks.keys()) == {tc.conflict_id for tc in tracked_conflicts}
data_checks_by_conflict_id = {cdc.enriched_conflict_id.value: cdc for cdc in core_data_checks.values()}
assert set(data_checks_by_conflict_id.keys()) == {tc.conflict_id for tc in tracked_conflicts}
for tracked_conflict in tracked_conflicts:
data_check = core_data_checks[tracked_conflict.conflict_id]
data_check = data_checks_by_conflict_id[tracked_conflict.conflict_id]
assert data_check.keep_branch.value.value == tracked_conflict.keep_branch.value

async def test_create_another_proposed_change_data_checks_created(
self, db: InfrahubDatabase, initial_dataset, default_branch, client: InfrahubClient
) -> None:
# verify duplicate data checks can be created
result = await client.execute_graphql(
query=PROPOSED_CHANGE_CREATE,
variables={
"name": PROPOSED_CHANGE_NAME + "2",
"source_branch": BRANCH_NAME,
"destination_branch": default_branch.name,
},
)
assert result["CoreProposedChangeCreate"]["object"]["id"]
pc_id = result["CoreProposedChangeCreate"]["object"]["id"]
attribute_value_conflict = self.retrieve_item("attribute_value")
peer_conflict = self.retrieve_item("peer_conflict")
cardinality_one_property_conflict_a = self.retrieve_item("cardinality_one_property_conflict_a")
cardinality_one_property_conflict_b = self.retrieve_item("cardinality_one_property_conflict_b")
cardinality_many_property_conflict = self.retrieve_item("cardinality_many_property_conflict")
node_removed_conflict = self.retrieve_item("node_removed")
node_removed_attribute_value_conflict = self.retrieve_item("node_removed_attribute_value")

pc = await NodeManager.get_one(db=db, id=pc_id)
validators = await pc.validations.get_peers(db=db)
data_validator = None
for v in validators.values():
if v.get_kind() == InfrahubKind.DATAVALIDATOR:
data_validator = v
assert data_validator
core_data_checks = await data_validator.checks.get_peers(db=db) # type: ignore[attr-defined]
data_checks_by_conflict_id = {cdc.enriched_conflict_id.value: cdc for cdc in core_data_checks.values()}
assert set(data_checks_by_conflict_id.keys()) == {
attribute_value_conflict.conflict_id,
peer_conflict.conflict_id,
cardinality_one_property_conflict_a.conflict_id,
cardinality_one_property_conflict_b.conflict_id,
cardinality_many_property_conflict.conflict_id,
node_removed_conflict.conflict_id,
node_removed_attribute_value_conflict.conflict_id,
}
assert len(core_data_checks) == len(data_checks_by_conflict_id)

async def test_merge_proposed_change(
self, db: InfrahubDatabase, initial_dataset, default_branch, client: InfrahubClient
) -> None:
Expand Down
4 changes: 4 additions & 0 deletions python_sdk/infrahub_sdk/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ class CoreCustomWebhook(CoreWebhook, CoreTaskTarget):
class CoreDataCheck(CoreCheck):
conflicts: JSONAttribute
keep_branch: Enum
enriched_conflict_id: StringOptional


class CoreDataValidator(CoreValidator):
Expand Down Expand Up @@ -395,6 +396,7 @@ class CoreRepositoryValidator(CoreValidator):

class CoreSchemaCheck(CoreCheck):
conflicts: JSONAttribute
enriched_conflict_id: StringOptional


class CoreSchemaValidator(CoreValidator):
Expand Down Expand Up @@ -683,6 +685,7 @@ class CoreCustomWebhookSync(CoreWebhookSync, CoreTaskTargetSync):
class CoreDataCheckSync(CoreCheckSync):
conflicts: JSONAttribute
keep_branch: Enum
enriched_conflict_id: StringOptional


class CoreDataValidatorSync(CoreValidatorSync):
Expand Down Expand Up @@ -811,6 +814,7 @@ class CoreRepositoryValidatorSync(CoreValidatorSync):

class CoreSchemaCheckSync(CoreCheckSync):
conflicts: JSONAttribute
enriched_conflict_id: StringOptional


class CoreSchemaValidatorSync(CoreValidatorSync):
Expand Down
Loading