diff --git a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py index 100c479593b7e..b55706e2ab87f 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py @@ -713,17 +713,28 @@ def __init__( self.enable_message_ordering = enable_message_ordering self.impersonation_chain = impersonation_chain - def execute(self, context: Context) -> None: - hook = PubSubHook( + @cached_property + def pubsub_hook(self): + return PubSubHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain, enable_message_ordering=self.enable_message_ordering, ) + def execute(self, context: Context) -> None: self.log.info("Publishing to topic %s", self.topic) - hook.publish(project_id=self.project_id, topic=self.topic, messages=self.messages) + self.pubsub_hook.publish(project_id=self.project_id, topic=self.topic, messages=self.messages) self.log.info("Published to topic %s", self.topic) + def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: + from airflow.providers.common.compat.openlineage.facet import Dataset + from airflow.providers.openlineage.extractors import OperatorLineage + + project_id = self.project_id or self.pubsub_hook.project_id + output_dataset = [Dataset(namespace="pubsub", name=f"topic:{project_id}:{self.topic}")] + + return OperatorLineage(outputs=output_dataset) + class PubSubPullOperator(GoogleCloudBaseOperator): """ diff --git a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py index d62ef1ae47580..ea2865aefa192 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py +++ b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py @@ -329,6 +329,38 @@ def test_publish_with_ordering_key(self, mock_hook): project_id=TEST_PROJECT, topic=TEST_TOPIC, messages=TEST_MESSAGES_ORDERING_KEY ) + @pytest.mark.parametrize( + "project_id, expected_dataset", + [ + # 1. project_id provided + (TEST_PROJECT, f"topic:{TEST_PROJECT}:{TEST_TOPIC}"), + # 2. project_id not provided (use project_id from connection) + (None, f"topic:connection-project:{TEST_TOPIC}"), + ], + ) + @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook") + def test_get_openlineage_facets(self, mock_hook, project_id, expected_dataset): + operator = PubSubPublishMessageOperator( + task_id=TASK_ID, + project_id=project_id, + topic=TEST_TOPIC, + messages=TEST_MESSAGES, + ) + + operator.execute(None) + mock_hook.return_value.publish.assert_called_once_with( + project_id=project_id, topic=TEST_TOPIC, messages=TEST_MESSAGES + ) + mock_hook.return_value.project_id = project_id or "connection-project" + + result = operator.get_openlineage_facets_on_complete(operator) + assert not result.run_facets + assert not result.job_facets + assert len(result.inputs) == 0 + assert len(result.outputs) == 1 + assert result.outputs[0].namespace == "pubsub" + assert result.outputs[0].name == expected_dataset + class TestPubSubPullOperator: def _generate_messages(self, count):