Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -409,13 +409,18 @@ def get_openlineage_facets_on_complete(self, _) -> OperatorLineage:
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.openlineage.extractors import OperatorLineage

project_id = self.subscription_project_id or self.project_id or self.pubsub_hook.project_id

output_dataset = [
Dataset(namespace="pubsub", name=f"subscription:{project_id}:{self._resolved_subscription_name}")
]

return OperatorLineage(outputs=output_dataset)
topic_project_id = self.project_id or self.pubsub_hook.project_id
subscription_project_id = self.subscription_project_id or topic_project_id

return OperatorLineage(
inputs=[Dataset(namespace="pubsub", name=f"topic:{topic_project_id}:{self.topic}")],
outputs=[
Dataset(
namespace="pubsub",
name=f"subscription:{subscription_project_id}:{self._resolved_subscription_name}",
)
],
)


class PubSubDeleteTopicOperator(GoogleCloudBaseOperator):
Expand Down
53 changes: 38 additions & 15 deletions providers/google/tests/unit/google/cloud/operators/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,33 +207,54 @@ def test_execute_no_subscription(self, mock_hook):
assert response == TEST_SUBSCRIPTION

@pytest.mark.parametrize(
"subscription, project_id, subscription_project_id, expected_dataset",
"project_id, subscription, subscription_project_id, expected_input, expected_output",
[
# 1. Subscription provided, project_id provided, subscription_project_id not provided
(TEST_SUBSCRIPTION, TEST_PROJECT, None, f"subscription:{TEST_PROJECT}:{TEST_SUBSCRIPTION}"),
# 2. Subscription provided, subscription_project_id provided
(
TEST_PROJECT,
TEST_SUBSCRIPTION,
None,
f"topic:{TEST_PROJECT}:{TEST_TOPIC}",
f"subscription:{TEST_PROJECT}:{TEST_SUBSCRIPTION}",
),
(
TEST_PROJECT,
TEST_SUBSCRIPTION,
"another-project",
f"topic:{TEST_PROJECT}:{TEST_TOPIC}",
f"subscription:another-project:{TEST_SUBSCRIPTION}",
),
# 3. Subscription not provided (generated), project_id provided
(None, TEST_PROJECT, None, f"subscription:{TEST_PROJECT}:generated"),
# 4. Subscription not provided, subscription_project_id provided
(None, TEST_PROJECT, "another-project", "subscription:another-project:generated"),
# 5. Neither subscription nor project_id provided (use project_id from connection)
(None, None, None, "subscription:connection-project:generated"),
(
TEST_PROJECT,
None,
None,
f"topic:{TEST_PROJECT}:{TEST_TOPIC}",
f"subscription:{TEST_PROJECT}:generated",
),
(
TEST_PROJECT,
None,
"another-project",
f"topic:{TEST_PROJECT}:{TEST_TOPIC}",
"subscription:another-project:generated",
),
(
None,
None,
None,
f"topic:connection-project:{TEST_TOPIC}",
"subscription:connection-project:generated",
),
],
)
@mock.patch("airflow.providers.google.cloud.operators.pubsub.PubSubHook")
def test_get_openlineage_facets(
self,
mock_hook,
subscription,
project_id,
subscription,
subscription_project_id,
expected_dataset,
expected_input,
expected_output,
):
operator = PubSubCreateSubscriptionOperator(
task_id=TASK_ID,
Expand All @@ -243,7 +264,7 @@ def test_get_openlineage_facets(
subscription_project_id=subscription_project_id,
)
mock_hook.return_value.create_subscription.return_value = subscription or "generated"
mock_hook.return_value.project_id = subscription_project_id or project_id or "connection-project"
mock_hook.return_value.project_id = project_id or "connection-project"
context = mock.MagicMock()
response = operator.execute(context=context)
mock_hook.return_value.create_subscription.assert_called_once_with(
Expand Down Expand Up @@ -275,10 +296,12 @@ def test_get_openlineage_facets(
result = operator.get_openlineage_facets_on_complete(operator)
assert not result.run_facets
assert not result.job_facets
assert len(result.inputs) == 0
assert len(result.inputs) == 1
assert result.inputs[0].namespace == "pubsub"
assert result.inputs[0].name == expected_input
assert len(result.outputs) == 1
assert result.outputs[0].namespace == "pubsub"
assert result.outputs[0].name == expected_dataset
assert result.outputs[0].name == expected_output


class TestPubSubSubscriptionDeleteOperator:
Expand Down
Loading