diff --git a/backend/infrahub/computed_attribute/models.py b/backend/infrahub/computed_attribute/models.py index 213fd15918..3847497744 100644 --- a/backend/infrahub/computed_attribute/models.py +++ b/backend/infrahub/computed_attribute/models.py @@ -119,6 +119,12 @@ class ComputedAttrJinja2TriggerDefinition(TriggerBranchDefinition): type: TriggerType = TriggerType.COMPUTED_ATTR_JINJA2 computed_attribute: ComputedAttributeTarget template_hash: str + trigger_kind: str + + @property + def targets_self(self) -> bool: + """Determine if the specific trigger definition targets the actual node kind of the computed attribute.""" + return self.trigger_kind == self.computed_attribute.kind def get_description(self) -> str: return f"{super().get_description()} | hash:{self.template_hash}" @@ -190,6 +196,7 @@ def from_computed_attribute( definition = cls( name=f"{computed_attribute.key_name}{NAME_SEPARATOR}kind{NAME_SEPARATOR}{trigger_node.kind}", template_hash=template_hash, + trigger_kind=trigger_node.kind, branch=branch, computed_attribute=computed_attribute, trigger=event_trigger, diff --git a/backend/infrahub/computed_attribute/tasks.py b/backend/infrahub/computed_attribute/tasks.py index 403b0ced58..d5a74e3dc2 100644 --- a/backend/infrahub/computed_attribute/tasks.py +++ b/backend/infrahub/computed_attribute/tasks.py @@ -29,6 +29,7 @@ from .models import ( ComputedAttrJinja2GraphQL, ComputedAttrJinja2GraphQLResponse, + ComputedAttrJinja2TriggerDefinition, PythonTransformTarget, ) @@ -312,21 +313,46 @@ async def computed_attribute_setup_jinja2( ) # type: ignore[misc] # Configure all ComputedAttrJinja2Trigger in Prefect + all_triggers = report.triggers_with_type(trigger_type=ComputedAttrJinja2TriggerDefinition) + # Since we can have multiple trigger per NodeKind - # we need to extract the list of unique node that should be processed - unique_nodes: set[tuple[str, str, str]] = { - (trigger.branch, trigger.computed_attribute.kind, trigger.computed_attribute.attribute.name) # type: ignore[attr-defined] - for trigger in report.updated + report.created - } - for branch, kind, attribute_name in unique_nodes: - if event_name != BranchDeletedEvent.event_name and branch == branch_name: + # we need to extract the list of unique node that should be processed, this is done by filtering the triggers that targets_self + modified_triggers = [ + trigger + for trigger in report.modified_triggers_with_type(trigger_type=ComputedAttrJinja2TriggerDefinition) + if trigger.targets_self + ] + + for modified_trigger in modified_triggers: + if event_name != BranchDeletedEvent.event_name and modified_trigger.branch == branch_name: + if branch_name != registry.default_branch: + default_branch_triggers = [ + trigger + for trigger in all_triggers + if trigger.branch == registry.default_branch + and trigger.targets_self + and trigger.computed_attribute.kind == modified_trigger.computed_attribute.kind + and trigger.computed_attribute.attribute.name + == modified_trigger.computed_attribute.attribute.name + ] + if ( + default_branch_triggers + and len(default_branch_triggers) == 1 + and default_branch_triggers[0].template_hash == modified_trigger.template_hash + ): + log.debug( + f"Skipping computed attribute updates for {modified_trigger.computed_attribute.kind}." + f"{modified_trigger.computed_attribute.attribute.name} [{branch_name}], schema is identical to default branch" + ) + continue + await get_workflow().submit_workflow( workflow=TRIGGER_UPDATE_JINJA_COMPUTED_ATTRIBUTES, context=context, parameters={ - "branch_name": branch, - "computed_attribute_name": attribute_name, - "computed_attribute_kind": kind, + "branch_name": modified_trigger.branch, + "computed_attribute_name": modified_trigger.computed_attribute.attribute.name, + "computed_attribute_kind": modified_trigger.computed_attribute.kind, }, ) diff --git a/backend/infrahub/display_labels/tasks.py b/backend/infrahub/display_labels/tasks.py index 2ce703d02c..d99baa83cb 100644 --- a/backend/infrahub/display_labels/tasks.py +++ b/backend/infrahub/display_labels/tasks.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import cast - from infrahub_sdk.exceptions import URLNotFoundError from infrahub_sdk.template import Jinja2Template from prefect import flow @@ -139,11 +137,32 @@ async def display_labels_setup_jinja2( ) # type: ignore[misc] # Configure all DisplayLabelTriggerDefinitions in Prefect - display_reports = [cast(DisplayLabelTriggerDefinition, entry) for entry in report.updated + report.created] - direct_target_triggers = [display_report for display_report in display_reports if display_report.target_kind] + all_triggers = report.triggers_with_type(trigger_type=DisplayLabelTriggerDefinition) + direct_target_triggers = [ + display_report + for display_report in report.modified_triggers_with_type(trigger_type=DisplayLabelTriggerDefinition) + if display_report.target_kind + ] for display_report in direct_target_triggers: if event_name != BranchDeletedEvent.event_name and display_report.branch == branch_name: + if branch_name != registry.default_branch: + default_branch_triggers = [ + trigger + for trigger in all_triggers + if trigger.branch == registry.default_branch + and trigger.target_kind == display_report.target_kind + ] + if ( + default_branch_triggers + and len(default_branch_triggers) == 1 + and default_branch_triggers[0].template_hash == display_report.template_hash + ): + log.debug( + f"Skipping display label updates for {display_report.target_kind} [{branch_name}], schema is identical to default branch" + ) + continue + await get_workflow().submit_workflow( workflow=TRIGGER_UPDATE_DISPLAY_LABELS, context=context, diff --git a/backend/infrahub/hfid/tasks.py b/backend/infrahub/hfid/tasks.py index 9e0071a1a1..1683676b1a 100644 --- a/backend/infrahub/hfid/tasks.py +++ b/backend/infrahub/hfid/tasks.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import cast - from infrahub_sdk.exceptions import URLNotFoundError from prefect import flow from prefect.logging import get_run_logger @@ -138,11 +136,32 @@ async def hfid_setup(context: InfrahubContext, branch_name: str | None = None, e ) # type: ignore[misc] # Configure all DisplayLabelTriggerDefinitions in Prefect - hfid_reports = [cast(HFIDTriggerDefinition, entry) for entry in report.updated + report.created] - direct_target_triggers = [hfid_report for hfid_report in hfid_reports if hfid_report.target_kind] + all_triggers = report.triggers_with_type(trigger_type=HFIDTriggerDefinition) + direct_target_triggers = [ + hfid_report + for hfid_report in report.modified_triggers_with_type(trigger_type=HFIDTriggerDefinition) + if hfid_report.target_kind + ] for display_report in direct_target_triggers: if event_name != BranchDeletedEvent.event_name and display_report.branch == branch_name: + if branch_name != registry.default_branch: + default_branch_triggers = [ + trigger + for trigger in all_triggers + if trigger.branch == registry.default_branch + and trigger.target_kind == display_report.target_kind + ] + if ( + default_branch_triggers + and len(default_branch_triggers) == 1 + and default_branch_triggers[0].hfid_hash == display_report.hfid_hash + ): + log.debug( + f"Skipping HFID updates for {display_report.target_kind} [{branch_name}], schema is identical to default branch" + ) + continue + await get_workflow().submit_workflow( workflow=TRIGGER_UPDATE_HFID, context=context, diff --git a/backend/infrahub/telemetry/task_manager.py b/backend/infrahub/telemetry/task_manager.py index 7588d24da3..d408754c94 100644 --- a/backend/infrahub/telemetry/task_manager.py +++ b/backend/infrahub/telemetry/task_manager.py @@ -8,6 +8,7 @@ from infrahub.events.utils import get_all_events from infrahub.trigger.constants import NAME_SEPARATOR from infrahub.trigger.models import TriggerType +from infrahub.trigger.setup import gather_all_automations from .models import TelemetryPrefectData, TelemetryWorkPoolData @@ -53,7 +54,7 @@ async def count_events(event_name: str) -> int: @task(name="telemetry-gather-automations", task_run_name="Gather Automations", cache_policy=NONE) async def gather_prefect_automations(client: PrefectClient) -> dict[str, Any]: - automations = await client.read_automations() + automations = await gather_all_automations(client=client) data: dict[str, Any] = {} diff --git a/backend/infrahub/trigger/models.py b/backend/infrahub/trigger/models.py index 72aec78a5e..f257556c65 100644 --- a/backend/infrahub/trigger/models.py +++ b/backend/infrahub/trigger/models.py @@ -1,8 +1,8 @@ from __future__ import annotations from datetime import timedelta -from enum import Enum -from typing import TYPE_CHECKING, Any +from enum import Enum, StrEnum +from typing import TYPE_CHECKING, Any, TypeVar from prefect.events.actions import RunDeployment from prefect.events.schemas.automations import Automation, Posture @@ -18,16 +18,78 @@ if TYPE_CHECKING: from uuid import UUID +T = TypeVar("T", bound="TriggerDefinition") + + +class TriggerComparison(StrEnum): + MATCH = "match" # Expected trigger and actual trigger is identical + REFRESH = "refresh" # The branch parameters doesn't match, the hash does, refresh in Prefect but don't run triggers + UPDATE = "update" # Neither branch or other data points match, update in Prefect and run triggers + + @property + def update_prefect(self) -> bool: + return self in {TriggerComparison.REFRESH, TriggerComparison.UPDATE} + class TriggerSetupReport(BaseModel): created: list[TriggerDefinition] = Field(default_factory=list) + refreshed: list[TriggerDefinition] = Field(default_factory=list) updated: list[TriggerDefinition] = Field(default_factory=list) deleted: list[Automation] = Field(default_factory=list) unchanged: list[TriggerDefinition] = Field(default_factory=list) @property def in_use_count(self) -> int: - return len(self.created + self.updated + self.unchanged) + return len(self.created + self.updated + self.unchanged + self.refreshed) + + def add_with_comparison(self, trigger: TriggerDefinition, comparison: TriggerComparison) -> None: + match comparison: + case TriggerComparison.UPDATE: + self.updated.append(trigger) + case TriggerComparison.REFRESH: + self.refreshed.append(trigger) + case TriggerComparison.MATCH: + self.unchanged.append(trigger) + + def _created_triggers_with_type(self, trigger_type: type[T]) -> list[T]: + return [trigger for trigger in self.created if isinstance(trigger, trigger_type)] + + def _refreshed_triggers_with_type(self, trigger_type: type[T]) -> list[T]: + return [trigger for trigger in self.refreshed if isinstance(trigger, trigger_type)] + + def _unchanged_triggers_with_type(self, trigger_type: type[T]) -> list[T]: + return [trigger for trigger in self.unchanged if isinstance(trigger, trigger_type)] + + def _updated_triggers_with_type(self, trigger_type: type[T]) -> list[T]: + return [trigger for trigger in self.updated if isinstance(trigger, trigger_type)] + + def triggers_with_type(self, trigger_type: type[T]) -> list[T]: + """Return all triggers that match the specified type. + + Args: + trigger_type: A TriggerDefinition class or subclass to filter by + + Returns: + List of triggers of the specified type from all categories + """ + created = self._created_triggers_with_type(trigger_type=trigger_type) + updated = self._updated_triggers_with_type(trigger_type=trigger_type) + refreshed = self._refreshed_triggers_with_type(trigger_type=trigger_type) + unchanged = self._unchanged_triggers_with_type(trigger_type=trigger_type) + return created + updated + refreshed + unchanged + + def modified_triggers_with_type(self, trigger_type: type[T]) -> list[T]: + """Return all created and updated triggers that match the specified type. + + Args: + trigger_type: A TriggerDefinition class or subclass to filter by + + Returns: + List of triggers of the specified type from both created and updated lists + """ + created = self._created_triggers_with_type(trigger_type=trigger_type) + updated = self._updated_triggers_with_type(trigger_type=trigger_type) + return created + updated class TriggerType(str, Enum): @@ -41,6 +103,16 @@ class TriggerType(str, Enum): HUMAN_FRIENDLY_ID = "human_friendly_id" # OBJECT = "object" + @property + def is_branch_specific(self) -> bool: + return self in { + TriggerType.COMPUTED_ATTR_JINJA2, + TriggerType.COMPUTED_ATTR_PYTHON, + TriggerType.COMPUTED_ATTR_PYTHON_QUERY, + TriggerType.DISPLAY_LABEL_JINJA2, + TriggerType.HUMAN_FRIENDLY_ID, + } + def _match_related_dict() -> dict: # Make Mypy happy as match related is a dict[str, Any] | list[dict[str, Any]] diff --git a/backend/infrahub/trigger/setup.py b/backend/infrahub/trigger/setup.py index f9d33f8972..01b8e24424 100644 --- a/backend/infrahub/trigger/setup.py +++ b/backend/infrahub/trigger/setup.py @@ -12,22 +12,36 @@ from infrahub.database import InfrahubDatabase from infrahub.trigger.models import TriggerDefinition -from .models import TriggerSetupReport, TriggerType +from .models import TriggerComparison, TriggerSetupReport, TriggerType if TYPE_CHECKING: from uuid import UUID -def compare_automations(target: AutomationCore, existing: Automation) -> bool: - """Compare an AutomationCore with an existing Automation object to identify if they are identical or not - - Return True if the target is identical to the existing automation +def compare_automations( + target: AutomationCore, existing: Automation, trigger_type: TriggerType | None, force_update: bool = False +) -> TriggerComparison: + """Compare an AutomationCore with an existing Automation object to identify if they are identical, + if it's a branch specific automation and the branch filter may be different, or if they are different. """ + if force_update: + return TriggerComparison.UPDATE + target_dump = target.model_dump(exclude_defaults=True, exclude_none=True) existing_dump = existing.model_dump(exclude_defaults=True, exclude_none=True, exclude={"id"}) - return target_dump == existing_dump + if target_dump == existing_dump: + return TriggerComparison.MATCH + + if not trigger_type or not trigger_type.is_branch_specific: + return TriggerComparison.UPDATE + + if target.description == existing.description: + # If only the branch related info is different, we consider it a refresh + return TriggerComparison.REFRESH + + return TriggerComparison.UPDATE @task(name="trigger-setup-specific", task_run_name="Setup triggers of a specific kind", cache_policy=NONE) # type: ignore[arg-type] @@ -63,10 +77,8 @@ async def setup_triggers( report = TriggerSetupReport() - if trigger_type: - log.debug(f"Setting up triggers of type {trigger_type.value}") - else: - log.debug("Setting up all triggers") + trigger_log_message = f"triggers of type {trigger_type.value}" if trigger_type else "all triggers" + log.debug(f"Setting up {trigger_log_message}") # ------------------------------------------------------------- # Retrieve existing Deployments and Automation from the server @@ -80,16 +92,14 @@ async def setup_triggers( } deployments_mapping: dict[str, UUID] = {name: item.id for name, item in deployments.items()} - # If a trigger type is provided, narrow down the list of existing triggers to know which one to delete - existing_automations: dict[str, Automation] = {} + existing_automations = {item.name: item for item in await gather_all_automations(client=client)} if trigger_type: + # If a trigger type is provided, narrow down the list of existing triggers to know which one to delete existing_automations = { - item.name: item - for item in await client.read_automations() - if item.name.startswith(f"{trigger_type.value}::") + automation_name: automation + for automation_name, automation in existing_automations.items() + if automation_name.startswith(f"{trigger_type.value}::") } - else: - existing_automations = {item.name: item for item in await client.read_automations()} trigger_names = [trigger.generate_name() for trigger in triggers] automation_names = list(existing_automations.keys()) @@ -115,12 +125,13 @@ async def setup_triggers( existing_automation = existing_automations.get(trigger.generate_name(), None) if existing_automation: - if force_update or not compare_automations(target=automation, existing=existing_automation): + trigger_comparison = compare_automations( + target=automation, existing=existing_automation, trigger_type=trigger_type, force_update=force_update + ) + if trigger_comparison.update_prefect: await client.update_automation(automation_id=existing_automation.id, automation=automation) log.info(f"{trigger.generate_name()} Updated") - report.updated.append(trigger) - else: - report.unchanged.append(trigger) + report.add_with_comparison(trigger, trigger_comparison) else: await client.create_automation(automation=automation) log.info(f"{trigger.generate_name()} Created") @@ -145,15 +156,34 @@ async def setup_triggers( else: raise - if trigger_type: - log.info( - f"Processed triggers of type {trigger_type.value}: " - f"{len(report.created)} created, {len(report.updated)} updated, {len(report.unchanged)} unchanged, {len(report.deleted)} deleted" - ) - else: - log.info( - f"Processed all triggers: " - f"{len(report.created)} created, {len(report.updated)} updated, {len(report.unchanged)} unchanged, {len(report.deleted)} deleted" - ) + log.info( + f"Processed {trigger_log_message}: {len(report.created)} created, {len(report.updated)} updated, " + f"{len(report.refreshed)} refreshed, {len(report.unchanged)} unchanged, {len(report.deleted)} deleted" + ) return report + + +async def gather_all_automations(client: PrefectClient) -> list[Automation]: + """Gather all automations from the Prefect server + + By default the Prefect client only retrieves a limited number of automations, this function + retrieves them all by paginating through the results. The default within Prefect is 200 items, + and client.read_automations() doesn't support pagination parameters. + """ + automation_count_response = await client.request("POST", "/automations/count") + automation_count_response.raise_for_status() + automation_count: int = automation_count_response.json() + offset = 0 + limit = 200 + missing_automations = True + automations: list[Automation] = [] + while missing_automations: + response = await client.request("POST", "/automations/filter", json={"limit": limit, "offset": offset}) + response.raise_for_status() + automations.extend(Automation.model_validate_list(response.json())) + if len(automations) >= automation_count: + missing_automations = False + offset += limit + + return automations diff --git a/backend/tests/functional/webhook/test_task.py b/backend/tests/functional/webhook/test_task.py index b3d5716b50..4c6ef14ca9 100644 --- a/backend/tests/functional/webhook/test_task.py +++ b/backend/tests/functional/webhook/test_task.py @@ -10,6 +10,7 @@ from infrahub.core.constants import InfrahubKind from infrahub.core.node import Node +from infrahub.trigger.setup import gather_all_automations from infrahub.webhook.models import EventContext, WebhookTriggerDefinition from infrahub.webhook.tasks import ( configure_webhook_all, @@ -209,7 +210,7 @@ async def test_configure_all( ) -> None: await configure_webhook_all() - automations = await prefect_client.read_automations() + automations = await gather_all_automations(client=prefect_client) automations_by_name = {automation.name: automation for automation in automations} assert f"webhook::{webhook1.id}" in automations_by_name.keys() diff --git a/backend/tests/integration_docker/test_triggered_actions.py b/backend/tests/integration_docker/test_triggered_actions.py index 13aaf2daef..19122a650d 100644 --- a/backend/tests/integration_docker/test_triggered_actions.py +++ b/backend/tests/integration_docker/test_triggered_actions.py @@ -27,6 +27,7 @@ from infrahub.core.constants import InfrahubKind from infrahub.trigger.constants import NAME_SEPARATOR +from infrahub.trigger.setup import gather_all_automations from tests.helpers.fixtures import get_fixtures_dir CURRENT_DIRECTORY = Path(__file__).parent.resolve() @@ -97,7 +98,7 @@ async def wait_until_automations_are_configured(self, automation_names: list[str retry = 0 while continue_waiting: - automations = await client.read_automations() + automations = await gather_all_automations(client=client) observed_automation_names = [automation.name.split(NAME_SEPARATOR)[-1] for automation in automations] if set(automation_names).issubset(observed_automation_names): continue_waiting = False diff --git a/backend/tests/unit/trigger/test_tasks.py b/backend/tests/unit/trigger/test_tasks.py index a183649ce1..05b2e9edd5 100644 --- a/backend/tests/unit/trigger/test_tasks.py +++ b/backend/tests/unit/trigger/test_tasks.py @@ -3,7 +3,7 @@ from infrahub.trigger.catalogue import builtin_triggers from infrahub.trigger.models import EventTrigger, TriggerType -from infrahub.trigger.setup import setup_triggers +from infrahub.trigger.setup import gather_all_automations, setup_triggers from infrahub.workflows.initialization import setup_deployments, setup_worker_pools @@ -15,7 +15,7 @@ async def prefect_client(prefect_test_fixture): @pytest.fixture async def cleanup_automation(prefect_client: PrefectClient) -> None: - automations = await prefect_client.read_automations() + automations = await gather_all_automations(client=prefect_client) for automation in automations: await prefect_client.delete_automation(automation.id) @@ -34,7 +34,7 @@ async def test_setup_triggers(prefect_client: PrefectClient, init_prefect, clean assert len(report.unchanged) == 0 assert len(report.created) == len(builtin_triggers) - automations = await prefect_client.read_automations() + automations = await gather_all_automations(client=prefect_client) assert len(automations) == len(builtin_triggers) # Update 1 Trigger and remove 2 to ensure that setup_triggers is working as expected @@ -48,7 +48,7 @@ async def test_setup_triggers(prefect_client: PrefectClient, init_prefect, clean assert len(report_after.unchanged) == len(builtin_triggers) - 3 assert len(report_after.created) == 0 - automations = await prefect_client.read_automations() + automations = await gather_all_automations(client=prefect_client) assert len(automations) == len(builtin_triggers[:-2]) # Ensure force_update is working properly diff --git a/changelog/7692.fixed.md b/changelog/7692.fixed.md new file mode 100644 index 0000000000..415317ac98 --- /dev/null +++ b/changelog/7692.fixed.md @@ -0,0 +1 @@ +Refactor task setup to avoid excessive tasks being scheduled for branches that previously didn't contain tasks. The updated behaviour is that the task will only be triggered on the branch if the task signature differs from that of the default branch.