diff --git a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py index c24cacae49368..100c479593b7e 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py @@ -26,6 +26,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence +from functools import cached_property from typing import TYPE_CHECKING, Any from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault @@ -52,6 +53,7 @@ if TYPE_CHECKING: from google.api_core.retry import Retry + from airflow.providers.openlineage.extractors import OperatorLineage from airflow.utils.context import Context @@ -359,15 +361,18 @@ def __init__( self.timeout = timeout self.metadata = metadata self.impersonation_chain = impersonation_chain + self._resolved_subscription_name: str | None = None - def execute(self, context: Context) -> str: - hook = PubSubHook( + @cached_property + def pubsub_hook(self): + return PubSubHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, ) + def execute(self, context: Context) -> str: self.log.info("Creating subscription for topic %s", self.topic) - result = hook.create_subscription( + result = self.pubsub_hook.create_subscription( project_id=self.project_id, topic=self.topic, subscription=self.subscription, @@ -389,13 +394,29 @@ def execute(self, context: Context) -> str: ) self.log.info("Created subscription for topic %s", self.topic) + + # Store resolved subscription for Open Lineage + self._resolved_subscription_name = self.subscription or result + PubSubSubscriptionLink.persist( context=context, - subscription_id=self.subscription or result, # result returns subscription name - project_id=self.project_id or hook.project_id, + subscription_id=self._resolved_subscription_name, # result returns subscription name + project_id=self.project_id or self.pubsub_hook.project_id, ) return result + def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: + from airflow.providers.common.compat.openlineage.facet import Dataset + from airflow.providers.openlineage.extractors import OperatorLineage + + project_id = self.subscription_project_id or self.project_id or self.pubsub_hook.project_id + + output_dataset = [ + Dataset(namespace="pubsub", name=f"subscription:{project_id}:{self._resolved_subscription_name}") + ] + + return OperatorLineage(outputs=output_dataset) + class PubSubDeleteTopicOperator(GoogleCloudBaseOperator): """ diff --git a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py index 6f2a13ca09e8d..d62ef1ae47580 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py +++ b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py @@ -206,6 +206,80 @@ def test_execute_no_subscription(self, mock_hook): ) assert response == TEST_SUBSCRIPTION + @pytest.mark.parametrize( + "subscription, project_id, subscription_project_id, expected_dataset", + [ + # 1. Subscription provided, project_id provided, subscription_project_id not provided + (TEST_SUBSCRIPTION, TEST_PROJECT, None, f"subscription:{TEST_PROJECT}:{TEST_SUBSCRIPTION}"), + # 2. Subscription provided, subscription_project_id provided + ( + TEST_SUBSCRIPTION, + TEST_PROJECT, + "another-project", + f"subscription:another-project:{TEST_SUBSCRIPTION}", + ), + # 3. Subscription not provided (generated), project_id provided + (None, TEST_PROJECT, None, f"subscription:{TEST_PROJECT}:generated"), + # 4. Subscription not provided, subscription_project_id provided + (None, TEST_PROJECT, "another-project", "subscription:another-project:generated"), + # 5. Neither subscription nor project_id provided (use project_id from connection) + (None, None, None, "subscription:connection-project:generated"), + ], + ) + @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook") + def test_get_openlineage_facets( + self, + mock_hook, + subscription, + project_id, + subscription_project_id, + expected_dataset, + ): + operator = PubSubCreateSubscriptionOperator( + task_id=TASK_ID, + project_id=project_id, + topic=TEST_TOPIC, + subscription=subscription, + subscription_project_id=subscription_project_id, + ) + mock_hook.return_value.create_subscription.return_value = subscription or "generated" + mock_hook.return_value.project_id = subscription_project_id or project_id or "connection-project" + context = mock.MagicMock() + response = operator.execute(context=context) + mock_hook.return_value.create_subscription.assert_called_once_with( + project_id=project_id, + topic=TEST_TOPIC, + subscription=subscription, + subscription_project_id=subscription_project_id, + ack_deadline_secs=10, + fail_if_exists=False, + push_config=None, + retain_acked_messages=None, + message_retention_duration=None, + labels=None, + enable_message_ordering=False, + expiration_policy=None, + filter_=None, + dead_letter_policy=None, + retry_policy=None, + retry=DEFAULT, + timeout=None, + metadata=(), + ) + + if subscription: + assert response == TEST_SUBSCRIPTION + else: + assert response == "generated" + + result = operator.get_openlineage_facets_on_complete(operator) + assert not result.run_facets + assert not result.job_facets + assert len(result.inputs) == 0 + assert len(result.outputs) == 1 + assert result.outputs[0].namespace == "pubsub" + assert result.outputs[0].name == expected_dataset + class TestPubSubSubscriptionDeleteOperator: @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")