diff --git a/backend/infrahub/core/diff/coordinator.py b/backend/infrahub/core/diff/coordinator.py index 276c777401..4606e8825b 100644 --- a/backend/infrahub/core/diff/coordinator.py +++ b/backend/infrahub/core/diff/coordinator.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Iterable from infrahub import lock from infrahub.core import registry @@ -14,7 +14,6 @@ EnrichedDiffs, NameTrackingId, NodeFieldSpecifier, - TimeRange, TrackingId, ) @@ -219,6 +218,30 @@ async def recalculate( log.debug(f"Diff recalculation complete for {base_branch.name} - {diff_branch.name}") return enriched_diffs.diff_branch_diff + def _get_ordered_diff_pairs( + self, diff_pairs: Iterable[EnrichedDiffs], allow_overlap: bool = False + ) -> list[EnrichedDiffs]: + ordered_diffs = sorted(diff_pairs, key=lambda d: d.diff_branch_diff.from_time) + if allow_overlap: + return ordered_diffs + ordered_diffs_no_overlaps: list[EnrichedDiffs] = [] + for candidate_diff_pair in ordered_diffs: + if not ordered_diffs_no_overlaps: + ordered_diffs_no_overlaps.append(candidate_diff_pair) + continue + # no time overlap + previous_diff = ordered_diffs_no_overlaps[-1].diff_branch_diff + candidate_diff = candidate_diff_pair.diff_branch_diff + if previous_diff.to_time <= candidate_diff.from_time: + ordered_diffs_no_overlaps.append(candidate_diff_pair) + continue + previous_interval = previous_diff.time_range + candidate_interval = candidate_diff.time_range + # keep the diff that covers the larger time frame + if candidate_interval > previous_interval: + ordered_diffs_no_overlaps[-1] = candidate_diff_pair + return ordered_diffs_no_overlaps + async def _update_diffs( self, base_branch: Branch, @@ -272,15 +295,17 @@ async def _get_aggregated_enriched_diffs( if not partial_enriched_diffs: return await self._get_enriched_diff(diff_request=diff_request) - remaining_diffs = sorted(partial_enriched_diffs, key=lambda d: d.diff_branch_diff.from_time) + ordered_diffs = self._get_ordered_diff_pairs(diff_pairs=partial_enriched_diffs, allow_overlap=False) + ordered_diff_reprs = [repr(d) for d in ordered_diffs] + log.debug(f"Ordered diffs for aggregation: {ordered_diff_reprs}") current_time = diff_request.from_time previous_diffs: EnrichedDiffs | None = None while current_time < diff_request.to_time: - if remaining_diffs and remaining_diffs[0].diff_branch_diff.from_time == current_time: - current_diffs = remaining_diffs.pop(0) + if ordered_diffs and ordered_diffs[0].diff_branch_diff.from_time == current_time: + current_diffs = ordered_diffs.pop(0) else: - if remaining_diffs: - end_time = remaining_diffs[0].diff_branch_diff.from_time + if ordered_diffs: + end_time = ordered_diffs[0].diff_branch_diff.from_time else: end_time = diff_request.to_time if previous_diffs is None: @@ -322,26 +347,6 @@ async def _get_enriched_diff(self, diff_request: EnrichedDiffRequest) -> Enriche enriched_diff_pair = await self.diff_enricher.enrich(calculated_diffs=calculated_diff_pair) return enriched_diff_pair - def _get_missing_time_ranges( - self, time_ranges: list[TimeRange], from_time: Timestamp, to_time: Timestamp - ) -> list[TimeRange]: - if not time_ranges: - return [TimeRange(from_time=from_time, to_time=to_time)] - sorted_time_ranges = sorted(time_ranges, key=lambda tr: tr.from_time) - missing_time_ranges = [] - if sorted_time_ranges[0].from_time > from_time: - missing_time_ranges.append(TimeRange(from_time=from_time, to_time=sorted_time_ranges[0].from_time)) - index = 0 - while index < len(sorted_time_ranges) - 1: - this_diff = sorted_time_ranges[index] - next_diff = sorted_time_ranges[index + 1] - if this_diff.to_time < next_diff.from_time: - missing_time_ranges.append(TimeRange(from_time=this_diff.to_time, to_time=next_diff.from_time)) - index += 1 - if sorted_time_ranges[-1].to_time < to_time: - missing_time_ranges.append(TimeRange(from_time=sorted_time_ranges[-1].to_time, to_time=to_time)) - return missing_time_ranges - def _get_node_field_specifiers(self, enriched_diff: EnrichedDiffRoot) -> set[NodeFieldSpecifier]: specifiers: set[NodeFieldSpecifier] = set() schema_branch = registry.schema.get_schema_branch(name=enriched_diff.diff_branch_name) diff --git a/backend/infrahub/core/diff/model/path.py b/backend/infrahub/core/diff/model/path.py index ff210a53d1..d1d0b32563 100644 --- a/backend/infrahub/core/diff/model/path.py +++ b/backend/infrahub/core/diff/model/path.py @@ -20,6 +20,7 @@ from neo4j.graph import Node as Neo4jNode from neo4j.graph import Path as Neo4jPath from neo4j.graph import Relationship as Neo4jRelationship + from pendulum import Interval from infrahub.graphql.initialization import GraphqlContext @@ -394,6 +395,10 @@ class EnrichedDiffRoot(BaseSummary): def __hash__(self) -> int: return hash(self.uuid) + @property + def time_range(self) -> Interval: + return self.to_time.obj - self.from_time.obj + def get_nodes_without_parents(self) -> set[EnrichedDiffNode]: nodes_with_parent_uuids = set() for n in self.nodes: @@ -483,6 +488,17 @@ class EnrichedDiffs: base_branch_diff: EnrichedDiffRoot diff_branch_diff: EnrichedDiffRoot + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"branch_uuid={self.diff_branch_diff.uuid}," + f"base_uuid={self.base_branch_diff.uuid}," + f"branch_name={self.diff_branch_name}," + f"base_name={self.base_branch_name}," + f"from_time={self.diff_branch_diff.from_time}," + f"to_time={self.diff_branch_diff.to_time})" + ) + @classmethod def from_calculated_diffs(cls, calculated_diffs: CalculatedDiffs) -> EnrichedDiffs: base_branch_diff = EnrichedDiffRoot.from_calculated_diff( diff --git a/backend/infrahub/core/query/diff.py b/backend/infrahub/core/query/diff.py index 5b71330e44..c73385c062 100644 --- a/backend/infrahub/core/query/diff.py +++ b/backend/infrahub/core/query/diff.py @@ -767,8 +767,10 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None: // ------------------------------------- MATCH diff_rel_path = (root:Root)<-[r_root:IS_PART_OF]-(n:Node)-[r_node]-(p)-[diff_rel {branch: $branch_name}]->(q) WHERE (node_field_specifiers_list IS NULL OR [n.uuid, p.name] IN node_field_specifiers_list) - AND (from_time <= diff_rel.from < $to_time) - AND (diff_rel.to IS NULL OR (from_time <= diff_rel.to < $to_time)) + AND ( + (from_time <= diff_rel.from < $to_time) + OR (from_time <= diff_rel.to < $to_time) + ) // exclude attributes and relationships under added/removed nodes, attrs, and rels b/c they are covered above AND ALL( r in [r_root, r_node] diff --git a/backend/tests/unit/core/diff/test_coordinator.py b/backend/tests/unit/core/diff/test_coordinator.py index 1e73123792..a1ed15c255 100644 --- a/backend/tests/unit/core/diff/test_coordinator.py +++ b/backend/tests/unit/core/diff/test_coordinator.py @@ -2,9 +2,12 @@ from infrahub.core.constants import DiffAction from infrahub.core.constants.database import DatabaseEdgeType from infrahub.core.diff.coordinator import DiffCoordinator +from infrahub.core.diff.model.path import BranchTrackingId +from infrahub.core.diff.repository.repository import DiffRepository from infrahub.core.initialization import create_branch from infrahub.core.manager import NodeManager from infrahub.core.node import Node +from infrahub.core.timestamp import Timestamp from infrahub.database import InfrahubDatabase from infrahub.dependencies.registry import get_component_registry @@ -44,3 +47,65 @@ async def test_node_deleted_after_branching( assert prop_diff.action is DiffAction.REMOVED assert prop_diff.conflict is None assert prop_diff.new_value is None + + async def test_overlapping_diffs(self, db: InfrahubDatabase, default_branch: Branch, person_john_main: Node): + branch = await create_branch(db=db, branch_name="branch") + component_registry = get_component_registry() + diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=branch) + diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=branch) + original_height = person_john_main.height.value + person_john_branch = await NodeManager.get_one(db=db, branch=branch, id=person_john_main.id) + + # t0 + t0 = Timestamp() + person_john_branch.height.value = 1 + await person_john_branch.save(db=db) + # t1 + t1 = Timestamp() + person_john_branch.height.value = 2 + await person_john_branch.save(db=db) + # t2 + # diff from t0 - t2 + await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch) + person_john_branch.height.value = 3 + await person_john_branch.save(db=db) + # t3 + t3 = Timestamp() + # overlapping diff from t1 to t3 + arbitrary_diff = await diff_coordinator.create_or_update_arbitrary_timeframe_diff( + base_branch=default_branch, diff_branch=branch, from_time=t1, to_time=t3 + ) + + full_diff = await diff_coordinator.update_branch_diff(base_branch=default_branch, diff_branch=branch) + + # check that only one branch-tracking diff exists for this branch + tracking_diff = await diff_repository.get_one( + diff_branch_name=branch.name, tracking_id=BranchTrackingId(name=branch.name) + ) + assert tracking_diff == full_diff + # test that arbitrary diff still exists + retrieved_arbitrary_diff = await diff_repository.get_one( + diff_branch_name=branch.name, diff_id=arbitrary_diff.uuid + ) + assert retrieved_arbitrary_diff == arbitrary_diff + + # validate content of the diff + assert full_diff.base_branch_name == default_branch.name + assert full_diff.diff_branch_name == branch.name + assert full_diff.from_time < t0 + assert full_diff.to_time > t3 + assert len(full_diff.nodes) == 1 + diff_node = full_diff.nodes.pop() + assert diff_node.uuid == person_john_main.id + assert diff_node.action is DiffAction.UPDATED + assert not diff_node.relationships + assert len(diff_node.attributes) == 1 + diff_attribute = diff_node.attributes.pop() + assert diff_attribute.name == "height" + assert diff_attribute.action is DiffAction.UPDATED + assert len(diff_attribute.properties) == 1 + diff_property = diff_attribute.properties.pop() + assert diff_property.property_type is DatabaseEdgeType.HAS_VALUE + assert diff_property.action is DiffAction.UPDATED + assert diff_property.previous_value == str(original_height) + assert diff_property.new_value == "3" diff --git a/backend/tests/unit/core/diff/test_diff_calculator.py b/backend/tests/unit/core/diff/test_diff_calculator.py index 4944fefa4a..bb60c0c475 100644 --- a/backend/tests/unit/core/diff/test_diff_calculator.py +++ b/backend/tests/unit/core/diff/test_diff_calculator.py @@ -694,7 +694,7 @@ async def test_relationship_one_property_branch_update( single_relationship = single_relationships_by_peer_id[person_john_main.id] assert single_relationship.peer_id == person_john_main.id assert single_relationship.action is DiffAction.REMOVED - assert len(single_relationship.properties) == 2 + assert len(single_relationship.properties) == 3 assert before_main_change < single_relationship.changed_at < after_main_change property_diff_by_type = {p.property_type: p for p in single_relationship.properties} property_diff = property_diff_by_type[DatabaseEdgeType.IS_RELATED] @@ -709,6 +709,12 @@ async def test_relationship_one_property_branch_update( assert property_diff.new_value is None assert property_diff.action is DiffAction.REMOVED assert before_main_change < property_diff.changed_at < after_main_change + property_diff = property_diff_by_type[DatabaseEdgeType.IS_PROTECTED] + assert property_diff.property_type == DatabaseEdgeType.IS_PROTECTED + assert property_diff.previous_value is False + assert property_diff.new_value is None + assert property_diff.action is DiffAction.REMOVED + assert before_main_change < property_diff.changed_at < after_main_change async def test_add_node_branch(