Skip to content

Commit

Permalink
Consider updated fields for Jinja2 based computed attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
ogenstad committed Dec 23, 2024
1 parent aee97a8 commit ac64cdc
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 13 deletions.
12 changes: 8 additions & 4 deletions backend/infrahub/computed_attribute/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any

import ujson
from infrahub_sdk.protocols import CoreNode # noqa: TCH002
from prefect import flow
from prefect.automations import AutomationCore
Expand Down Expand Up @@ -227,12 +228,15 @@ async def process_jinja2(
object_id: str,
computed_attribute_name: str,
computed_attribute_kind: str,
updated_fields: list[str] | None = None,
updated_fields: str | None = None,
) -> None:
log = get_run_logger()
service = services.service

await add_tags(branches=[branch_name])
updates: list[str] = []
if isinstance(updated_fields, str):
updates = ujson.loads(updated_fields)

target_branch_schema = (
branch_name if branch_name in registry.get_altered_schema_branches() else registry.default_branch
Expand All @@ -242,9 +246,7 @@ async def process_jinja2(

computed_macros = [
attrib
for attrib in schema_branch.computed_attributes.get_impacted_jinja2_targets(
kind=node_kind, updates=updated_fields
)
for attrib in schema_branch.computed_attributes.get_impacted_jinja2_targets(kind=node_kind, updates=updates)
if attrib.kind == computed_attribute_kind and attrib.attribute.name == computed_attribute_name
]
for computed_macro in computed_macros:
Expand Down Expand Up @@ -369,6 +371,7 @@ async def computed_attribute_setup(branch_name: str | None = None) -> None: # p
"object_id": "{{ event.resource['infrahub.node.id'] }}",
"computed_attribute_name": computed_attribute.attribute.name,
"computed_attribute_kind": computed_attribute.kind,
"updated_fields": "{{ event.payload['fields'] | tojson }}",
},
job_variables={},
)
Expand Down Expand Up @@ -432,6 +435,7 @@ async def computed_attribute_setup(branch_name: str | None = None) -> None: # p
"object_id": "{{ event.resource['infrahub.node.id'] }}",
"computed_attribute_name": computed_attribute.attribute.name,
"computed_attribute_kind": computed_attribute.kind,
"updated_fields": "{{ event.payload['fields'] | tojson }}",
},
job_variables={},
)
Expand Down
3 changes: 2 additions & 1 deletion backend/infrahub/events/node_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class NodeMutatedEvent(InfrahubBranchEvent):
node_id: str = Field(..., description="The ID of the mutated node")
action: MutationAction = Field(..., description="The action taken on the node")
data: dict[str, Any] = Field(..., description="Data on modified object")
fields: list[str] = Field(default_factory=list, description="Fields provided in tha mutation")

def get_name(self) -> str:
return f"{self.get_event_namespace()}.node.{self.action.value}"
Expand All @@ -29,7 +30,7 @@ def get_resource(self) -> dict[str, str]:
}

def get_payload(self) -> dict[str, Any]:
return self.data
return {"data": self.data, "fields": self.fields}

def get_messages(self) -> list[InfrahubMessage]:
return [
Expand Down
10 changes: 6 additions & 4 deletions backend/infrahub/graphql/mutations/computed_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,20 @@ async def mutate(
)
if attribute_field.value != str(data.value):
attribute_field.value = str(data.value)
await target_node.save(db=context.db, fields=[str(data.attribute)])
async with context.db.start_transaction() as dbt:
await target_node.save(db=dbt, fields=[str(data.attribute)])

log_data = get_log_data()
request_id = log_data.get("request_id", "")
log_data = get_log_data()
request_id = log_data.get("request_id", "")

graphql_payload = await target_node.to_graphql(db=context.db, filter_sensitive=True)
graphql_payload = await target_node.to_graphql(db=dbt, filter_sensitive=True)

event = NodeMutatedEvent(
branch=context.branch.name,
kind=node_schema.kind,
node_id=target_node.get_id(),
data=graphql_payload,
fields=[str(data.attribute)],
action=MutationAction.UPDATED,
meta=EventMeta(initiator_id=WORKER_IDENTITY, request_id=request_id),
)
Expand Down
8 changes: 6 additions & 2 deletions backend/infrahub/graphql/mutations/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,17 @@ async def mutate(cls, root: dict, info: GraphQLResolveInfo, data: InputObjectTyp
request_id = log_data.get("request_id", "")

graphql_payload = await obj.to_graphql(db=context.db, filter_sensitive=True)

event = NodeMutatedEvent(
branch=context.branch.name,
kind=obj._schema.kind,
node_id=obj.id,
data=graphql_payload,
action=action,
fields=_get_data_fields(data),
meta=EventMeta(initiator_id=WORKER_IDENTITY, request_id=request_id),
)

context.background.add_task(context.service.event.send, event)
context.background.add_task(context.active_service.event.send, event)

return mutation

Expand Down Expand Up @@ -457,3 +457,7 @@ def _get_kind_lock_names_on_object_mutation(kind: str, branch: Branch, schema_br
lock_kinds = _get_kinds_to_lock_on_object_mutation(kind, schema_branch)
lock_names = [build_object_lock_name(kind) for kind in lock_kinds]
return lock_names


def _get_data_fields(data: InputObjectType) -> list[str]:
return [field for field in data.keys() if field not in ["id", "hfid"]]
2 changes: 1 addition & 1 deletion backend/infrahub/webhook/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ async def configure_webhooks() -> None:
deployment_id=deployment_id_webhook_trigger,
parameters={
"event_type": "{{ event.resource['infrahub.node.kind'] }}.{{ event.resource['infrahub.node.action'] }}",
"event_data": "{{ event.payload | tojson }}",
"event_data": "{{ event.payload['data'] | tojson }}",
},
job_variables={},
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def test_description_after_color_change_jinja2(
object_id=tshirt_1.id,
computed_attribute_kind="TestingTShirt",
computed_attribute_name="description",
updated_fields=["color"],
updated_fields='"color"',
)

tshirt_updated = await client.get(kind="TestingTShirt", id=data["t1"].id)
Expand Down

0 comments on commit ac64cdc

Please sign in to comment.