diff --git a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py index b55706e2ab87f..25fb715bb8105 100644 --- a/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py +++ b/providers/google/src/airflow/providers/google/cloud/operators/pubsub.py @@ -885,3 +885,13 @@ def _default_message_callback( messages_json = [ReceivedMessage.to_dict(m) for m in pulled_messages] return messages_json + + def get_openlineage_facets_on_complete(self, _) -> OperatorLineage: + from airflow.providers.common.compat.openlineage.facet import Dataset + from airflow.providers.openlineage.extractors import OperatorLineage + + output_dataset = [ + Dataset(namespace="pubsub", name=f"subscription:{self.project_id}:{self.subscription}") + ] + + return OperatorLineage(outputs=output_dataset) diff --git a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py index ea2865aefa192..8efd1935d7260 100644 --- a/providers/google/tests/unit/google/cloud/operators/test_pubsub.py +++ b/providers/google/tests/unit/google/cloud/operators/test_pubsub.py @@ -463,3 +463,28 @@ def test_execute_deferred(self, mock_hook, create_task_instance_of_operator): ) with pytest.raises(TaskDeferred) as _: ti.task.execute(mock.MagicMock()) + + @mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook") + def test_get_openlineage_facets(self, mock_hook): + operator = PubSubPullOperator( + task_id=TASK_ID, + project_id=TEST_PROJECT, + subscription=TEST_SUBSCRIPTION, + ) + + generated_messages = self._generate_messages(5) + generated_dicts = self._generate_dicts(5) + mock_hook.return_value.pull.return_value = generated_messages + + assert generated_dicts == operator.execute({}) + mock_hook.return_value.pull.assert_called_once_with( + project_id=TEST_PROJECT, subscription=TEST_SUBSCRIPTION, max_messages=5, return_immediately=True + ) + + result = operator.get_openlineage_facets_on_complete(operator) + assert not result.run_facets + assert not result.job_facets + assert len(result.inputs) == 0 + assert len(result.outputs) == 1 + assert result.outputs[0].namespace == "pubsub" + assert result.outputs[0].name == f"subscription:{TEST_PROJECT}:{TEST_SUBSCRIPTION}"