Skip to content

Commit 6e27404

Browse files
authored
Merge pull request #7696 from opsmill/pog-prefect-trigger-refactoring
Refactor branch defined triggers
2 parents 2f92abc + 3d9aa65 commit 6e27404

File tree

11 files changed

+236
-59
lines changed

11 files changed

+236
-59
lines changed

backend/infrahub/computed_attribute/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,12 @@ class ComputedAttrJinja2TriggerDefinition(TriggerBranchDefinition):
119119
type: TriggerType = TriggerType.COMPUTED_ATTR_JINJA2
120120
computed_attribute: ComputedAttributeTarget
121121
template_hash: str
122+
trigger_kind: str
123+
124+
@property
125+
def targets_self(self) -> bool:
126+
"""Determine if the specific trigger definition targets the actual node kind of the computed attribute."""
127+
return self.trigger_kind == self.computed_attribute.kind
122128

123129
def get_description(self) -> str:
124130
return f"{super().get_description()} | hash:{self.template_hash}"
@@ -190,6 +196,7 @@ def from_computed_attribute(
190196
definition = cls(
191197
name=f"{computed_attribute.key_name}{NAME_SEPARATOR}kind{NAME_SEPARATOR}{trigger_node.kind}",
192198
template_hash=template_hash,
199+
trigger_kind=trigger_node.kind,
193200
branch=branch,
194201
computed_attribute=computed_attribute,
195202
trigger=event_trigger,

backend/infrahub/computed_attribute/tasks.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from .models import (
3030
ComputedAttrJinja2GraphQL,
3131
ComputedAttrJinja2GraphQLResponse,
32+
ComputedAttrJinja2TriggerDefinition,
3233
PythonTransformTarget,
3334
)
3435

@@ -312,21 +313,46 @@ async def computed_attribute_setup_jinja2(
312313
) # type: ignore[misc]
313314
# Configure all ComputedAttrJinja2Trigger in Prefect
314315

316+
all_triggers = report.triggers_with_type(trigger_type=ComputedAttrJinja2TriggerDefinition)
317+
315318
# Since we can have multiple trigger per NodeKind
316-
# we need to extract the list of unique node that should be processed
317-
unique_nodes: set[tuple[str, str, str]] = {
318-
(trigger.branch, trigger.computed_attribute.kind, trigger.computed_attribute.attribute.name) # type: ignore[attr-defined]
319-
for trigger in report.updated + report.created
320-
}
321-
for branch, kind, attribute_name in unique_nodes:
322-
if event_name != BranchDeletedEvent.event_name and branch == branch_name:
319+
# we need to extract the list of unique node that should be processed, this is done by filtering the triggers that targets_self
320+
modified_triggers = [
321+
trigger
322+
for trigger in report.modified_triggers_with_type(trigger_type=ComputedAttrJinja2TriggerDefinition)
323+
if trigger.targets_self
324+
]
325+
326+
for modified_trigger in modified_triggers:
327+
if event_name != BranchDeletedEvent.event_name and modified_trigger.branch == branch_name:
328+
if branch_name != registry.default_branch:
329+
default_branch_triggers = [
330+
trigger
331+
for trigger in all_triggers
332+
if trigger.branch == registry.default_branch
333+
and trigger.targets_self
334+
and trigger.computed_attribute.kind == modified_trigger.computed_attribute.kind
335+
and trigger.computed_attribute.attribute.name
336+
== modified_trigger.computed_attribute.attribute.name
337+
]
338+
if (
339+
default_branch_triggers
340+
and len(default_branch_triggers) == 1
341+
and default_branch_triggers[0].template_hash == modified_trigger.template_hash
342+
):
343+
log.debug(
344+
f"Skipping computed attribute updates for {modified_trigger.computed_attribute.kind}."
345+
f"{modified_trigger.computed_attribute.attribute.name} [{branch_name}], schema is identical to default branch"
346+
)
347+
continue
348+
323349
await get_workflow().submit_workflow(
324350
workflow=TRIGGER_UPDATE_JINJA_COMPUTED_ATTRIBUTES,
325351
context=context,
326352
parameters={
327-
"branch_name": branch,
328-
"computed_attribute_name": attribute_name,
329-
"computed_attribute_kind": kind,
353+
"branch_name": modified_trigger.branch,
354+
"computed_attribute_name": modified_trigger.computed_attribute.attribute.name,
355+
"computed_attribute_kind": modified_trigger.computed_attribute.kind,
330356
},
331357
)
332358

backend/infrahub/display_labels/tasks.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
from typing import cast
4-
53
from infrahub_sdk.exceptions import URLNotFoundError
64
from infrahub_sdk.template import Jinja2Template
75
from prefect import flow
@@ -139,11 +137,32 @@ async def display_labels_setup_jinja2(
139137
) # type: ignore[misc]
140138

141139
# Configure all DisplayLabelTriggerDefinitions in Prefect
142-
display_reports = [cast(DisplayLabelTriggerDefinition, entry) for entry in report.updated + report.created]
143-
direct_target_triggers = [display_report for display_report in display_reports if display_report.target_kind]
140+
all_triggers = report.triggers_with_type(trigger_type=DisplayLabelTriggerDefinition)
141+
direct_target_triggers = [
142+
display_report
143+
for display_report in report.modified_triggers_with_type(trigger_type=DisplayLabelTriggerDefinition)
144+
if display_report.target_kind
145+
]
144146

145147
for display_report in direct_target_triggers:
146148
if event_name != BranchDeletedEvent.event_name and display_report.branch == branch_name:
149+
if branch_name != registry.default_branch:
150+
default_branch_triggers = [
151+
trigger
152+
for trigger in all_triggers
153+
if trigger.branch == registry.default_branch
154+
and trigger.target_kind == display_report.target_kind
155+
]
156+
if (
157+
default_branch_triggers
158+
and len(default_branch_triggers) == 1
159+
and default_branch_triggers[0].template_hash == display_report.template_hash
160+
):
161+
log.debug(
162+
f"Skipping display label updates for {display_report.target_kind} [{branch_name}], schema is identical to default branch"
163+
)
164+
continue
165+
147166
await get_workflow().submit_workflow(
148167
workflow=TRIGGER_UPDATE_DISPLAY_LABELS,
149168
context=context,

backend/infrahub/hfid/tasks.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
from typing import cast
4-
53
from infrahub_sdk.exceptions import URLNotFoundError
64
from prefect import flow
75
from prefect.logging import get_run_logger
@@ -138,11 +136,32 @@ async def hfid_setup(context: InfrahubContext, branch_name: str | None = None, e
138136
) # type: ignore[misc]
139137

140138
# Configure all DisplayLabelTriggerDefinitions in Prefect
141-
hfid_reports = [cast(HFIDTriggerDefinition, entry) for entry in report.updated + report.created]
142-
direct_target_triggers = [hfid_report for hfid_report in hfid_reports if hfid_report.target_kind]
139+
all_triggers = report.triggers_with_type(trigger_type=HFIDTriggerDefinition)
140+
direct_target_triggers = [
141+
hfid_report
142+
for hfid_report in report.modified_triggers_with_type(trigger_type=HFIDTriggerDefinition)
143+
if hfid_report.target_kind
144+
]
143145

144146
for display_report in direct_target_triggers:
145147
if event_name != BranchDeletedEvent.event_name and display_report.branch == branch_name:
148+
if branch_name != registry.default_branch:
149+
default_branch_triggers = [
150+
trigger
151+
for trigger in all_triggers
152+
if trigger.branch == registry.default_branch
153+
and trigger.target_kind == display_report.target_kind
154+
]
155+
if (
156+
default_branch_triggers
157+
and len(default_branch_triggers) == 1
158+
and default_branch_triggers[0].hfid_hash == display_report.hfid_hash
159+
):
160+
log.debug(
161+
f"Skipping HFID updates for {display_report.target_kind} [{branch_name}], schema is identical to default branch"
162+
)
163+
continue
164+
146165
await get_workflow().submit_workflow(
147166
workflow=TRIGGER_UPDATE_HFID,
148167
context=context,

backend/infrahub/telemetry/task_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from infrahub.events.utils import get_all_events
99
from infrahub.trigger.constants import NAME_SEPARATOR
1010
from infrahub.trigger.models import TriggerType
11+
from infrahub.trigger.setup import gather_all_automations
1112

1213
from .models import TelemetryPrefectData, TelemetryWorkPoolData
1314

@@ -53,7 +54,7 @@ async def count_events(event_name: str) -> int:
5354

5455
@task(name="telemetry-gather-automations", task_run_name="Gather Automations", cache_policy=NONE)
5556
async def gather_prefect_automations(client: PrefectClient) -> dict[str, Any]:
56-
automations = await client.read_automations()
57+
automations = await gather_all_automations(client=client)
5758

5859
data: dict[str, Any] = {}
5960

backend/infrahub/trigger/models.py

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from __future__ import annotations
22

33
from datetime import timedelta
4-
from enum import Enum
5-
from typing import TYPE_CHECKING, Any
4+
from enum import Enum, StrEnum
5+
from typing import TYPE_CHECKING, Any, TypeVar
66

77
from prefect.events.actions import RunDeployment
88
from prefect.events.schemas.automations import Automation, Posture
@@ -18,16 +18,78 @@
1818
if TYPE_CHECKING:
1919
from uuid import UUID
2020

21+
T = TypeVar("T", bound="TriggerDefinition")
22+
23+
24+
class TriggerComparison(StrEnum):
25+
MATCH = "match" # Expected trigger and actual trigger is identical
26+
REFRESH = "refresh" # The branch parameters doesn't match, the hash does, refresh in Prefect but don't run triggers
27+
UPDATE = "update" # Neither branch or other data points match, update in Prefect and run triggers
28+
29+
@property
30+
def update_prefect(self) -> bool:
31+
return self in {TriggerComparison.REFRESH, TriggerComparison.UPDATE}
32+
2133

2234
class TriggerSetupReport(BaseModel):
2335
created: list[TriggerDefinition] = Field(default_factory=list)
36+
refreshed: list[TriggerDefinition] = Field(default_factory=list)
2437
updated: list[TriggerDefinition] = Field(default_factory=list)
2538
deleted: list[Automation] = Field(default_factory=list)
2639
unchanged: list[TriggerDefinition] = Field(default_factory=list)
2740

2841
@property
2942
def in_use_count(self) -> int:
30-
return len(self.created + self.updated + self.unchanged)
43+
return len(self.created + self.updated + self.unchanged + self.refreshed)
44+
45+
def add_with_comparison(self, trigger: TriggerDefinition, comparison: TriggerComparison) -> None:
46+
match comparison:
47+
case TriggerComparison.UPDATE:
48+
self.updated.append(trigger)
49+
case TriggerComparison.REFRESH:
50+
self.refreshed.append(trigger)
51+
case TriggerComparison.MATCH:
52+
self.unchanged.append(trigger)
53+
54+
def _created_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
55+
return [trigger for trigger in self.created if isinstance(trigger, trigger_type)]
56+
57+
def _refreshed_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
58+
return [trigger for trigger in self.refreshed if isinstance(trigger, trigger_type)]
59+
60+
def _unchanged_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
61+
return [trigger for trigger in self.unchanged if isinstance(trigger, trigger_type)]
62+
63+
def _updated_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
64+
return [trigger for trigger in self.updated if isinstance(trigger, trigger_type)]
65+
66+
def triggers_with_type(self, trigger_type: type[T]) -> list[T]:
67+
"""Return all triggers that match the specified type.
68+
69+
Args:
70+
trigger_type: A TriggerDefinition class or subclass to filter by
71+
72+
Returns:
73+
List of triggers of the specified type from all categories
74+
"""
75+
created = self._created_triggers_with_type(trigger_type=trigger_type)
76+
updated = self._updated_triggers_with_type(trigger_type=trigger_type)
77+
refreshed = self._refreshed_triggers_with_type(trigger_type=trigger_type)
78+
unchanged = self._unchanged_triggers_with_type(trigger_type=trigger_type)
79+
return created + updated + refreshed + unchanged
80+
81+
def modified_triggers_with_type(self, trigger_type: type[T]) -> list[T]:
82+
"""Return all created and updated triggers that match the specified type.
83+
84+
Args:
85+
trigger_type: A TriggerDefinition class or subclass to filter by
86+
87+
Returns:
88+
List of triggers of the specified type from both created and updated lists
89+
"""
90+
created = self._created_triggers_with_type(trigger_type=trigger_type)
91+
updated = self._updated_triggers_with_type(trigger_type=trigger_type)
92+
return created + updated
3193

3294

3395
class TriggerType(str, Enum):
@@ -41,6 +103,16 @@ class TriggerType(str, Enum):
41103
HUMAN_FRIENDLY_ID = "human_friendly_id"
42104
# OBJECT = "object"
43105

106+
@property
107+
def is_branch_specific(self) -> bool:
108+
return self in {
109+
TriggerType.COMPUTED_ATTR_JINJA2,
110+
TriggerType.COMPUTED_ATTR_PYTHON,
111+
TriggerType.COMPUTED_ATTR_PYTHON_QUERY,
112+
TriggerType.DISPLAY_LABEL_JINJA2,
113+
TriggerType.HUMAN_FRIENDLY_ID,
114+
}
115+
44116

45117
def _match_related_dict() -> dict:
46118
# Make Mypy happy as match related is a dict[str, Any] | list[dict[str, Any]]

0 commit comments

Comments
 (0)