Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix issue that causes illegal diff query #5291

Merged
merged 2 commits into from
Dec 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 32 additions & 27 deletions backend/infrahub/core/diff/coordinator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterable

from infrahub import lock
from infrahub.core import registry
Expand All @@ -14,7 +14,6 @@
EnrichedDiffs,
NameTrackingId,
NodeFieldSpecifier,
TimeRange,
TrackingId,
)

Expand Down Expand Up @@ -219,6 +218,30 @@ async def recalculate(
log.debug(f"Diff recalculation complete for {base_branch.name} - {diff_branch.name}")
return enriched_diffs.diff_branch_diff

def _get_ordered_diff_pairs(
self, diff_pairs: Iterable[EnrichedDiffs], allow_overlap: bool = False
) -> list[EnrichedDiffs]:
ordered_diffs = sorted(diff_pairs, key=lambda d: d.diff_branch_diff.from_time)
if allow_overlap:
return ordered_diffs
ordered_diffs_no_overlaps: list[EnrichedDiffs] = []
for candidate_diff_pair in ordered_diffs:
if not ordered_diffs_no_overlaps:
ordered_diffs_no_overlaps.append(candidate_diff_pair)
continue
# no time overlap
previous_diff = ordered_diffs_no_overlaps[-1].diff_branch_diff
candidate_diff = candidate_diff_pair.diff_branch_diff
if previous_diff.to_time <= candidate_diff.from_time:
ordered_diffs_no_overlaps.append(candidate_diff_pair)
continue
previous_interval = previous_diff.time_range
candidate_interval = candidate_diff.time_range
# keep the diff that covers the larger time frame
if candidate_interval > previous_interval:
ordered_diffs_no_overlaps[-1] = candidate_diff_pair
return ordered_diffs_no_overlaps
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strip out overlapping diffs. when two diffs overlap, keep the one that covers the greater time range


async def _update_diffs(
self,
base_branch: Branch,
Expand Down Expand Up @@ -272,15 +295,17 @@ async def _get_aggregated_enriched_diffs(
if not partial_enriched_diffs:
return await self._get_enriched_diff(diff_request=diff_request)

remaining_diffs = sorted(partial_enriched_diffs, key=lambda d: d.diff_branch_diff.from_time)
ordered_diffs = self._get_ordered_diff_pairs(diff_pairs=partial_enriched_diffs, allow_overlap=False)
ordered_diff_reprs = [repr(d) for d in ordered_diffs]
log.debug(f"Ordered diffs for aggregation: {ordered_diff_reprs}")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a little logging to help us troubleshoot if this comes up again

current_time = diff_request.from_time
previous_diffs: EnrichedDiffs | None = None
while current_time < diff_request.to_time:
if remaining_diffs and remaining_diffs[0].diff_branch_diff.from_time == current_time:
current_diffs = remaining_diffs.pop(0)
if ordered_diffs and ordered_diffs[0].diff_branch_diff.from_time == current_time:
current_diffs = ordered_diffs.pop(0)
else:
if remaining_diffs:
end_time = remaining_diffs[0].diff_branch_diff.from_time
if ordered_diffs:
end_time = ordered_diffs[0].diff_branch_diff.from_time
else:
end_time = diff_request.to_time
if previous_diffs is None:
Expand Down Expand Up @@ -322,26 +347,6 @@ async def _get_enriched_diff(self, diff_request: EnrichedDiffRequest) -> Enriche
enriched_diff_pair = await self.diff_enricher.enrich(calculated_diffs=calculated_diff_pair)
return enriched_diff_pair

def _get_missing_time_ranges(
self, time_ranges: list[TimeRange], from_time: Timestamp, to_time: Timestamp
) -> list[TimeRange]:
if not time_ranges:
return [TimeRange(from_time=from_time, to_time=to_time)]
sorted_time_ranges = sorted(time_ranges, key=lambda tr: tr.from_time)
missing_time_ranges = []
if sorted_time_ranges[0].from_time > from_time:
missing_time_ranges.append(TimeRange(from_time=from_time, to_time=sorted_time_ranges[0].from_time))
index = 0
while index < len(sorted_time_ranges) - 1:
this_diff = sorted_time_ranges[index]
next_diff = sorted_time_ranges[index + 1]
if this_diff.to_time < next_diff.from_time:
missing_time_ranges.append(TimeRange(from_time=this_diff.to_time, to_time=next_diff.from_time))
index += 1
if sorted_time_ranges[-1].to_time < to_time:
missing_time_ranges.append(TimeRange(from_time=sorted_time_ranges[-1].to_time, to_time=to_time))
return missing_time_ranges

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dead code

def _get_node_field_specifiers(self, enriched_diff: EnrichedDiffRoot) -> set[NodeFieldSpecifier]:
specifiers: set[NodeFieldSpecifier] = set()
schema_branch = registry.schema.get_schema_branch(name=enriched_diff.diff_branch_name)
Expand Down
16 changes: 16 additions & 0 deletions backend/infrahub/core/diff/model/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from neo4j.graph import Node as Neo4jNode
from neo4j.graph import Path as Neo4jPath
from neo4j.graph import Relationship as Neo4jRelationship
from pendulum import Interval

from infrahub.graphql.initialization import GraphqlContext

Expand Down Expand Up @@ -394,6 +395,10 @@ class EnrichedDiffRoot(BaseSummary):
def __hash__(self) -> int:
return hash(self.uuid)

@property
def time_range(self) -> Interval:
return self.to_time.obj - self.from_time.obj

def get_nodes_without_parents(self) -> set[EnrichedDiffNode]:
nodes_with_parent_uuids = set()
for n in self.nodes:
Expand Down Expand Up @@ -483,6 +488,17 @@ class EnrichedDiffs:
base_branch_diff: EnrichedDiffRoot
diff_branch_diff: EnrichedDiffRoot

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"branch_uuid={self.diff_branch_diff.uuid},"
f"base_uuid={self.base_branch_diff.uuid},"
f"branch_name={self.diff_branch_name},"
f"base_name={self.base_branch_name},"
f"from_time={self.diff_branch_diff.from_time},"
f"to_time={self.diff_branch_diff.to_time})"
)

@classmethod
def from_calculated_diffs(cls, calculated_diffs: CalculatedDiffs) -> EnrichedDiffs:
base_branch_diff = EnrichedDiffRoot.from_calculated_diff(
Expand Down
6 changes: 4 additions & 2 deletions backend/infrahub/core/query/diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,8 +767,10 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
// -------------------------------------
MATCH diff_rel_path = (root:Root)<-[r_root:IS_PART_OF]-(n:Node)-[r_node]-(p)-[diff_rel {branch: $branch_name}]->(q)
WHERE (node_field_specifiers_list IS NULL OR [n.uuid, p.name] IN node_field_specifiers_list)
AND (from_time <= diff_rel.from < $to_time)
AND (diff_rel.to IS NULL OR (from_time <= diff_rel.to < $to_time))
AND (
(from_time <= diff_rel.from < $to_time)
OR (from_time <= diff_rel.to < $to_time)
)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix in the query to calculate a diff

// exclude attributes and relationships under added/removed nodes, attrs, and rels b/c they are covered above
AND ALL(
r in [r_root, r_node]
Expand Down
65 changes: 65 additions & 0 deletions backend/tests/unit/core/diff/test_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
from infrahub.core.constants import DiffAction
from infrahub.core.constants.database import DatabaseEdgeType
from infrahub.core.diff.coordinator import DiffCoordinator
from infrahub.core.diff.model.path import BranchTrackingId
from infrahub.core.diff.repository.repository import DiffRepository
from infrahub.core.initialization import create_branch
from infrahub.core.manager import NodeManager
from infrahub.core.node import Node
from infrahub.core.timestamp import Timestamp
from infrahub.database import InfrahubDatabase
from infrahub.dependencies.registry import get_component_registry

Expand Down Expand Up @@ -44,3 +47,65 @@ async def test_node_deleted_after_branching(
assert prop_diff.action is DiffAction.REMOVED
assert prop_diff.conflict is None
assert prop_diff.new_value is None

async def test_overlapping_diffs(self, db: InfrahubDatabase, default_branch: Branch, person_john_main: Node):
branch = await create_branch(db=db, branch_name="branch")
component_registry = get_component_registry()
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=branch)
diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=branch)
original_height = person_john_main.height.value
person_john_branch = await NodeManager.get_one(db=db, branch=branch, id=person_john_main.id)

# t0
t0 = Timestamp()
person_john_branch.height.value = 1
await person_john_branch.save(db=db)
# t1
t1 = Timestamp()
person_john_branch.height.value = 2
await person_john_branch.save(db=db)
# t2
# diff from t0 - t2
await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch)
person_john_branch.height.value = 3
await person_john_branch.save(db=db)
# t3
t3 = Timestamp()
# overlapping diff from t1 to t3
arbitrary_diff = await diff_coordinator.create_or_update_arbitrary_timeframe_diff(
base_branch=default_branch, diff_branch=branch, from_time=t1, to_time=t3
)

full_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch)

# check that only one branch-tracking diff exists for this branch
tracking_diff = await diff_repository.get_one(
diff_branch_name=branch.name, tracking_id=BranchTrackingId(name=branch.name)
)
assert tracking_diff == full_diff
# test that arbitrary diff still exists
retrieved_arbitrary_diff = await diff_repository.get_one(
diff_branch_name=branch.name, diff_id=arbitrary_diff.uuid
)
assert retrieved_arbitrary_diff == arbitrary_diff

# validate content of the diff
assert full_diff.base_branch_name == default_branch.name
assert full_diff.diff_branch_name == branch.name
assert full_diff.from_time < t0
assert full_diff.to_time > t3
assert len(full_diff.nodes) == 1
diff_node = full_diff.nodes.pop()
assert diff_node.uuid == person_john_main.id
assert diff_node.action is DiffAction.UPDATED
assert not diff_node.relationships
assert len(diff_node.attributes) == 1
diff_attribute = diff_node.attributes.pop()
assert diff_attribute.name == "height"
assert diff_attribute.action is DiffAction.UPDATED
assert len(diff_attribute.properties) == 1
diff_property = diff_attribute.properties.pop()
assert diff_property.property_type is DatabaseEdgeType.HAS_VALUE
assert diff_property.action is DiffAction.UPDATED
assert diff_property.previous_value == str(original_height)
assert diff_property.new_value == "3"
8 changes: 7 additions & 1 deletion backend/tests/unit/core/diff/test_diff_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ async def test_relationship_one_property_branch_update(
single_relationship = single_relationships_by_peer_id[person_john_main.id]
assert single_relationship.peer_id == person_john_main.id
assert single_relationship.action is DiffAction.REMOVED
assert len(single_relationship.properties) == 2
assert len(single_relationship.properties) == 3
assert before_main_change < single_relationship.changed_at < after_main_change
property_diff_by_type = {p.property_type: p for p in single_relationship.properties}
property_diff = property_diff_by_type[DatabaseEdgeType.IS_RELATED]
Expand All @@ -709,6 +709,12 @@ async def test_relationship_one_property_branch_update(
assert property_diff.new_value is None
assert property_diff.action is DiffAction.REMOVED
assert before_main_change < property_diff.changed_at < after_main_change
property_diff = property_diff_by_type[DatabaseEdgeType.IS_PROTECTED]
assert property_diff.property_type == DatabaseEdgeType.IS_PROTECTED
assert property_diff.previous_value is False
assert property_diff.new_value is None
assert property_diff.action is DiffAction.REMOVED
assert before_main_change < property_diff.changed_at < after_main_change
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fix to this unit test after the fix to the diff query



async def test_add_node_branch(
Expand Down
Loading