Skip to content

Commit

Permalink
finish adding more labels and testing
Browse files Browse the repository at this point in the history
  • Loading branch information
ajtmccarty committed Aug 29, 2024
1 parent aab5672 commit 3c94d04
Show file tree
Hide file tree
Showing 8 changed files with 131 additions and 18 deletions.
6 changes: 6 additions & 0 deletions backend/infrahub/core/diff/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from .conflicts_enricher import ConflictsEnricher
from .data_check_synchronizer import DiffDataCheckSynchronizer
from .enricher.aggregated import AggregatedDiffEnricher
from .enricher.labels import DiffLabelsEnricher
from .enricher.summary_counts import DiffSummaryCountsEnricher
from .repository.repository import DiffRepository

Expand All @@ -40,6 +41,7 @@ def __init__(
diff_enricher: AggregatedDiffEnricher,
diff_combiner: DiffCombiner,
conflicts_enricher: ConflictsEnricher,
labels_enricher: DiffLabelsEnricher,
summary_counts_enricher: DiffSummaryCountsEnricher,
data_check_synchronizer: DiffDataCheckSynchronizer,
) -> None:
Expand All @@ -48,6 +50,7 @@ def __init__(
self.diff_enricher = diff_enricher
self.diff_combiner = diff_combiner
self.conflicts_enricher = conflicts_enricher
self.labels_enricher = labels_enricher
self.summary_counts_enricher = summary_counts_enricher
self.data_check_synchronizer = data_check_synchronizer
self._enriched_diff_cache: dict[EnrichedDiffRequest, EnrichedDiffRoot] = {}
Expand Down Expand Up @@ -162,6 +165,9 @@ async def _update_diffs(
base_diff_root=aggregated_diffs_by_branch_name[base_branch.name],
branch_diff_root=aggregated_diffs_by_branch_name[diff_branch.name],
)
await self.labels_enricher.enrich(
enriched_diff_root=aggregated_diffs_by_branch_name[diff_branch.name], conflicts_only=True
)

if tracking_id:
for enriched_diff in aggregated_diffs_by_branch_name.values():
Expand Down
27 changes: 18 additions & 9 deletions backend/infrahub/core/diff/enricher/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, db: InfrahubDatabase):
self.db = db
self._base_branch_name: str | None = None
self._diff_branch_name: str | None = None
self._conflicts_only = False

@property
def base_branch_name(self) -> str:
Expand All @@ -50,9 +51,10 @@ def diff_branch_name(self) -> str:

def _nodes_iterator(self, enriched_diff_root: EnrichedDiffRoot) -> Generator[DisplayLabelRequest, str | None, None]:
for node in enriched_diff_root.nodes:
label = yield DisplayLabelRequest(node_id=node.uuid, branch_name=self.diff_branch_name)
if label:
node.label = label
if not self._conflicts_only:
label = yield DisplayLabelRequest(node_id=node.uuid, branch_name=self.diff_branch_name)
if label:
node.label = label
for attribute_diff in node.attributes:
for property_diff in attribute_diff.properties:
property_iterator = self._property_iterator(property_diff=property_diff)
Expand All @@ -77,9 +79,10 @@ def _relationship_iterator(
self, relationship_diff: EnrichedDiffRelationship
) -> Generator[DisplayLabelRequest, str | None, None]:
for element_diff in relationship_diff.relationships:
peer_label = yield DisplayLabelRequest(node_id=element_diff.peer_id, branch_name=self.diff_branch_name)
if peer_label:
element_diff.peer_label = peer_label
if not self._conflicts_only:
peer_label = yield DisplayLabelRequest(node_id=element_diff.peer_id, branch_name=self.diff_branch_name)
if peer_label:
element_diff.peer_label = peer_label
if element_diff.conflict:
conflict_iterator = self._conflict_iterator(conflict_diff=element_diff.conflict)
label = None
Expand All @@ -103,13 +106,13 @@ def _property_iterator(
self, property_diff: EnrichedDiffProperty
) -> Generator[DisplayLabelRequest, str | None, None]:
if property_diff.property_type in PROPERTY_TYPES_WITH_LABELS:
if property_diff.previous_value:
if property_diff.previous_value and not self._conflicts_only:
label = yield DisplayLabelRequest(
node_id=property_diff.previous_value, branch_name=self.base_branch_name
)
if label:
property_diff.previous_label = label
if property_diff.new_value:
if property_diff.new_value and not self._conflicts_only:
label = yield DisplayLabelRequest(node_id=property_diff.new_value, branch_name=self.diff_branch_name)
if label:
property_diff.new_label = label
Expand Down Expand Up @@ -167,9 +170,15 @@ async def _get_display_label_map(
branch_map[node_kind].append(dlr.node_id)
return await get_display_labels(db=self.db, nodes=display_label_request_map, ignore_deleted=False)

async def enrich(self, enriched_diff_root: EnrichedDiffRoot, calculated_diffs: CalculatedDiffs) -> None:
async def enrich(
self,
enriched_diff_root: EnrichedDiffRoot,
calculated_diffs: CalculatedDiffs | None = None,
conflicts_only: bool = False,
) -> None:
self._base_branch_name = enriched_diff_root.base_branch_name
self._diff_branch_name = enriched_diff_root.diff_branch_name
self._conflicts_only = conflicts_only
display_label_requests = set(self._nodes_iterator(enriched_diff_root=enriched_diff_root))
display_label_map = await self._get_display_label_map(display_label_requests=display_label_requests)

Expand Down
4 changes: 4 additions & 0 deletions backend/infrahub/core/diff/query/save_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,13 @@ def _build_conflict_params(self, enriched_conflict: EnrichedDiffConflict) -> dic
"base_branch_changed_at": enriched_conflict.base_branch_changed_at.to_string()
if enriched_conflict.base_branch_changed_at
else None,
"base_branch_label": enriched_conflict.base_branch_label,
"diff_branch_action": enriched_conflict.diff_branch_action.value,
"diff_branch_value": enriched_conflict.diff_branch_value,
"diff_branch_changed_at": enriched_conflict.diff_branch_changed_at.to_string()
if enriched_conflict.diff_branch_changed_at
else None,
"diff_branch_label": enriched_conflict.diff_branch_label,
"selected_branch": enriched_conflict.selected_branch.value if enriched_conflict.selected_branch else None,
}

Expand All @@ -137,6 +139,8 @@ def _build_diff_property_params(self, enriched_property: EnrichedDiffProperty) -
"changed_at": enriched_property.changed_at.to_string(),
"previous_value": enriched_property.previous_value,
"new_value": enriched_property.new_value,
"previous_label": enriched_property.previous_label,
"new_label": enriched_property.new_label,
"action": enriched_property.action,
"path_identifier": enriched_property.path_identifier,
},
Expand Down
12 changes: 12 additions & 0 deletions backend/infrahub/core/diff/repository/deserializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,11 +288,15 @@ def _deserialize_diff_relationship_element(
def _property_node_to_enriched_property(self, property_node: Neo4jNode) -> EnrichedDiffProperty:
previous_value = self._get_str_or_none_property_value(node=property_node, property_name="previous_value")
new_value = self._get_str_or_none_property_value(node=property_node, property_name="new_value")
previous_label = self._get_str_or_none_property_value(node=property_node, property_name="previous_label")
new_label = self._get_str_or_none_property_value(node=property_node, property_name="new_label")
return EnrichedDiffProperty(
property_type=DatabaseEdgeType(str(property_node.get("property_type"))),
changed_at=Timestamp(str(property_node.get("changed_at"))),
previous_value=previous_value,
new_value=new_value,
previous_label=previous_label,
new_label=new_label,
action=DiffAction(str(property_node.get("action"))),
path_identifier=str(property_node.get("path_identifier")),
)
Expand Down Expand Up @@ -345,6 +349,12 @@ def deserialize_conflict(self, diff_conflict_node: Neo4jNode) -> EnrichedDiffCon
diff_branch_value = self._get_str_or_none_property_value(
node=diff_conflict_node, property_name="diff_branch_value"
)
base_branch_label = self._get_str_or_none_property_value(
node=diff_conflict_node, property_name="base_branch_label"
)
diff_branch_label = self._get_str_or_none_property_value(
node=diff_conflict_node, property_name="diff_branch_label"
)
base_timestamp_str = self._get_str_or_none_property_value(
node=diff_conflict_node, property_name="base_branch_changed_at"
)
Expand All @@ -357,8 +367,10 @@ def deserialize_conflict(self, diff_conflict_node: Neo4jNode) -> EnrichedDiffCon
base_branch_action=DiffAction(str(diff_conflict_node.get("base_branch_action"))),
base_branch_value=base_branch_value,
base_branch_changed_at=Timestamp(base_timestamp_str) if base_timestamp_str else None,
base_branch_label=base_branch_label,
diff_branch_action=DiffAction(str(diff_conflict_node.get("diff_branch_action"))),
diff_branch_value=diff_branch_value,
diff_branch_label=diff_branch_label,
diff_branch_changed_at=Timestamp(diff_timestamp_str) if diff_timestamp_str else None,
selected_branch=ConflictSelection(selected_branch) if selected_branch else None,
)
2 changes: 2 additions & 0 deletions backend/infrahub/dependencies/builder/diff/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .conflicts_enricher import DiffConflictsEnricherDependency
from .data_check_synchronizer import DiffDataCheckSynchronizerDependency
from .enricher.aggregated import DiffAggregatedEnricherDependency
from .enricher.labels import DiffLabelsEnricherDependency
from .enricher.summary_counts import DiffSummaryCountsEnricherDependency
from .repository import DiffRepositoryDependency

Expand All @@ -19,6 +20,7 @@ def build(cls, context: DependencyBuilderContext) -> DiffCoordinator:
diff_combiner=DiffCombinerDependency.build(context=context),
diff_enricher=DiffAggregatedEnricherDependency.build(context=context),
conflicts_enricher=DiffConflictsEnricherDependency.build(context=context),
labels_enricher=DiffLabelsEnricherDependency.build(context=context),
summary_counts_enricher=DiffSummaryCountsEnricherDependency.build(context=context),
data_check_synchronizer=DiffDataCheckSynchronizerDependency.build(context=context),
)
8 changes: 8 additions & 0 deletions backend/infrahub/graphql/queries/diff/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ class ConflictDetails(ObjectType):
base_branch_action = Field(GrapheneDiffActionEnum, required=True)
base_branch_value = String()
base_branch_changed_at = DateTime(required=True)
base_branch_label = String()
diff_branch_action = Field(GrapheneDiffActionEnum, required=True)
diff_branch_value = String()
diff_branch_changed_at = DateTime(required=True)
diff_branch_label = String()
selected_branch = Field(GraphQLConflictSelection)


Expand All @@ -57,6 +59,8 @@ class DiffProperty(ObjectType):
last_changed_at = DateTime(required=True)
previous_value = String(required=False)
new_value = String(required=False)
previous_label = String(required=False)
new_label = String(required=False)
status = Field(GrapheneDiffActionEnum, required=True)
path_identifier = String(required=True)
conflict = Field(ConflictDetails, required=False)
Expand Down Expand Up @@ -265,6 +269,8 @@ def to_diff_property(
last_changed_at=enriched_property.changed_at.obj,
previous_value=enriched_property.previous_value,
new_value=enriched_property.new_value,
previous_label=enriched_property.previous_label,
new_label=enriched_property.new_label,
status=enriched_property.action,
path_identifier=enriched_property.path_identifier,
conflict=conflict,
Expand All @@ -282,11 +288,13 @@ def to_diff_conflict(
base_branch_changed_at=enriched_conflict.base_branch_changed_at.obj
if enriched_conflict.base_branch_changed_at
else None,
base_branch_label=enriched_conflict.base_branch_label,
diff_branch_action=enriched_conflict.diff_branch_action,
diff_branch_value=enriched_conflict.diff_branch_value,
diff_branch_changed_at=enriched_conflict.diff_branch_changed_at.obj
if enriched_conflict.diff_branch_changed_at
else None,
diff_branch_label=enriched_conflict.diff_branch_label,
selected_branch=enriched_conflict.selected_branch.value if enriched_conflict.selected_branch else None,
)

Expand Down
Loading

0 comments on commit 3c94d04

Please sign in to comment.