Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from __future__ import annotations

from collections.abc import Callable, Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
Expand All @@ -52,6 +53,7 @@
if TYPE_CHECKING:
from google.api_core.retry import Retry

from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.utils.context import Context


Expand Down Expand Up @@ -359,15 +361,18 @@ def __init__(
self.timeout = timeout
self.metadata = metadata
self.impersonation_chain = impersonation_chain
self._resolved_subscription_name: str | None = None

def execute(self, context: Context) -> str:
hook = PubSubHook(
@cached_property
def pubsub_hook(self):
return PubSubHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)

def execute(self, context: Context) -> str:
self.log.info("Creating subscription for topic %s", self.topic)
result = hook.create_subscription(
result = self.pubsub_hook.create_subscription(
project_id=self.project_id,
topic=self.topic,
subscription=self.subscription,
Expand All @@ -389,13 +394,29 @@ def execute(self, context: Context) -> str:
)

self.log.info("Created subscription for topic %s", self.topic)

# Store resolved subscription for Open Lineage
self._resolved_subscription_name = self.subscription or result

PubSubSubscriptionLink.persist(
context=context,
subscription_id=self.subscription or result, # result returns subscription name
project_id=self.project_id or hook.project_id,
subscription_id=self._resolved_subscription_name, # result returns subscription name
project_id=self.project_id or self.pubsub_hook.project_id,
)
return result

def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.openlineage.extractors import OperatorLineage

project_id = self.subscription_project_id or self.project_id or self.pubsub_hook.project_id

output_dataset = [
Dataset(namespace="pubsub", name=f"subscription:{project_id}:{self._resolved_subscription_name}")
]

return OperatorLineage(outputs=output_dataset)


class PubSubDeleteTopicOperator(GoogleCloudBaseOperator):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,80 @@ def test_execute_no_subscription(self, mock_hook):
)
assert response == TEST_SUBSCRIPTION

@pytest.mark.parametrize(
"subscription, project_id, subscription_project_id, expected_dataset",
[
# 1. Subscription provided, project_id provided, subscription_project_id not provided
(TEST_SUBSCRIPTION, TEST_PROJECT, None, f"subscription:{TEST_PROJECT}:{TEST_SUBSCRIPTION}"),
# 2. Subscription provided, subscription_project_id provided
(
TEST_SUBSCRIPTION,
TEST_PROJECT,
"another-project",
f"subscription:another-project:{TEST_SUBSCRIPTION}",
),
# 3. Subscription not provided (generated), project_id provided
(None, TEST_PROJECT, None, f"subscription:{TEST_PROJECT}:generated"),
# 4. Subscription not provided, subscription_project_id provided
(None, TEST_PROJECT, "another-project", "subscription:another-project:generated"),
# 5. Neither subscription nor project_id provided (use project_id from connection)
(None, None, None, "subscription:connection-project:generated"),
],
)
@mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
def test_get_openlineage_facets(
self,
mock_hook,
subscription,
project_id,
subscription_project_id,
expected_dataset,
):
operator = PubSubCreateSubscriptionOperator(
task_id=TASK_ID,
project_id=project_id,
topic=TEST_TOPIC,
subscription=subscription,
subscription_project_id=subscription_project_id,
)
mock_hook.return_value.create_subscription.return_value = subscription or "generated"
mock_hook.return_value.project_id = subscription_project_id or project_id or "connection-project"
context = mock.MagicMock()
response = operator.execute(context=context)
mock_hook.return_value.create_subscription.assert_called_once_with(
project_id=project_id,
topic=TEST_TOPIC,
subscription=subscription,
subscription_project_id=subscription_project_id,
ack_deadline_secs=10,
fail_if_exists=False,
push_config=None,
retain_acked_messages=None,
message_retention_duration=None,
labels=None,
enable_message_ordering=False,
expiration_policy=None,
filter_=None,
dead_letter_policy=None,
retry_policy=None,
retry=DEFAULT,
timeout=None,
metadata=(),
)

if subscription:
assert response == TEST_SUBSCRIPTION
else:
assert response == "generated"

result = operator.get_openlineage_facets_on_complete(operator)
assert not result.run_facets
assert not result.job_facets
assert len(result.inputs) == 0
assert len(result.outputs) == 1
assert result.outputs[0].namespace == "pubsub"
assert result.outputs[0].name == expected_dataset


class TestPubSubSubscriptionDeleteOperator:
@mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
Expand Down
Loading