@@ -0,0 +1,32 @@
+{
+  "$schema": "https://json-schema.org/draft/2020-12/schema",
+  "$defs": {
+    "DataFusionRunFacet": {
+      "allOf": [
+        {
+          "$ref": "https://openlineage.io/spec/2-0-2/OpenLineage.json#/$defs/RunFacet"
+        },
+        {
+          "type": "object",
+          "properties": {
+            "runId": {
+              "type": "string",
+              "description": "Pipeline run ID assigned by Cloud Data Fusion."
+            },
+            "runtimeArgs": {
+              "type": "object",
+              "description": "Runtime arguments provided when starting the pipeline."
+            }
+          }
+        }
+      ],
+      "type": "object"
+    }
+  },
+  "type": "object",
+  "properties": {
+    "dataFusionRun": {
+      "$ref": "#/$defs/DataFusionRunFacet"
+    }
+  }
+}
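
For reference, a minimal sketch of a payload this schema accepts, checked with the third-party jsonschema package; the inline schema here is a trimmed copy of the properties block above, with the allOf/$ref to the remote OpenLineage RunFacet spec deliberately omitted so the check runs offline:

    import jsonschema  # assumed available: pip install jsonschema

    # Trimmed copy of the DataFusionRunFacet properties defined above; the
    # $ref to the remote OpenLineage RunFacet spec is left out on purpose.
    schema = {
        "type": "object",
        "properties": {
            "runId": {"type": "string"},
            "runtimeArgs": {"type": "object"},
        },
    }

    # Example payload shaped like the facet this PR emits.
    payload = {"runId": "abc123", "runtimeArgs": {"arg1": "val1"}}
    jsonschema.validate(instance=payload, schema=schema)  # raises ValidationError on mismatch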

@@ -131,6 +131,26 @@ def _get_schema() -> str:
"openlineage/CloudStorageTransferRunFacet.json"
)

@define
class DataFusionRunFacet(RunFacet):
"""
Facet that represents relevant details of a Cloud Data Fusion pipeline run.

:param runId: The pipeline execution id.
:param runtimeArgs: Runtime arguments passed to the pipeline.
"""

runId: str | None = field(default=None)
runtimeArgs: dict[str, str] | None = field(default=None)

@staticmethod
def _get_schema() -> str:
return (
"https://raw.githubusercontent.com/apache/airflow/"
f"providers-google/{provider_version}/airflow/providers/google/"
"openlineage/DataFusionRunFacet.json"
)

except ImportError: # OpenLineage is not available

def create_no_op(*_, **__) -> None:
@@ -145,3 +165,4 @@ def create_no_op(*_, **__) -> None:
     BigQueryJobRunFacet = create_no_op  # type: ignore[misc, assignment]
     CloudStorageTransferJobFacet = create_no_op  # type: ignore[misc, assignment]
     CloudStorageTransferRunFacet = create_no_op  # type: ignore[misc, assignment]
+    DataFusionRunFacet = create_no_op  # type: ignore[misc, assignment]
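
A short usage sketch of the new facet, mirroring the unit tests further down; when OpenLineage is not installed, the name resolves to the create_no_op stub above and the constructor call simply returns None:

    from airflow.providers.google.cloud.openlineage.facets import DataFusionRunFacet

    facet = DataFusionRunFacet(runId="abc123", runtimeArgs={"arg1": "val1"})
    if facet is not None:  # None means OpenLineage is unavailable (no-op stub)
        assert facet.runId == "abc123"
        assert facet.runtimeArgs == {"arg1": "val1"}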

@@ -40,6 +40,7 @@
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 
 if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -777,6 +778,7 @@ def __init__(
         self.pipeline_timeout = pipeline_timeout
         self.deferrable = deferrable
         self.poll_interval = poll_interval
+        self.pipeline_id: str | None = None
 
         if success_states:
             self.success_states = success_states
@@ -796,14 +798,14 @@ def execute(self, context: Context) -> str:
             project_id=self.project_id,
         )
         api_url = instance["apiEndpoint"]
-        pipeline_id = hook.start_pipeline(
+        self.pipeline_id = hook.start_pipeline(
             pipeline_name=self.pipeline_name,
             pipeline_type=self.pipeline_type,
             instance_url=api_url,
             namespace=self.namespace,
             runtime_args=self.runtime_args,
         )
-        self.log.info("Pipeline %s submitted successfully.", pipeline_id)
+        self.log.info("Pipeline %s submitted successfully.", self.pipeline_id)
 
         DataFusionPipelineLink.persist(
             context=context,
@@ -824,7 +826,7 @@ def execute(self, context: Context) -> str:
                     namespace=self.namespace,
                     pipeline_name=self.pipeline_name,
                     pipeline_type=self.pipeline_type.value,
-                    pipeline_id=pipeline_id,
+                    pipeline_id=self.pipeline_id,
                     poll_interval=self.poll_interval,
                     gcp_conn_id=self.gcp_conn_id,
                     impersonation_chain=self.impersonation_chain,
@@ -834,19 +836,21 @@ def execute(self, context: Context) -> str:
         else:
             if not self.asynchronous:
                 # when NOT using asynchronous mode it will just wait for pipeline to finish and print message
-                self.log.info("Waiting when pipeline %s will be in one of the success states", pipeline_id)
+                self.log.info(
+                    "Waiting when pipeline %s will be in one of the success states", self.pipeline_id
+                )
                 hook.wait_for_pipeline_state(
                     success_states=self.success_states,
-                    pipeline_id=pipeline_id,
+                    pipeline_id=self.pipeline_id,
                     pipeline_name=self.pipeline_name,
                     pipeline_type=self.pipeline_type,
                     namespace=self.namespace,
                     instance_url=api_url,
                     timeout=self.pipeline_timeout,
                 )
-                self.log.info("Pipeline %s discovered success state.", pipeline_id)
+                self.log.info("Pipeline %s discovered success state.", self.pipeline_id)
             # otherwise, return pipeline_id so that sensor can use it later to check the pipeline state
-            return pipeline_id
+            return self.pipeline_id
 
     def execute_complete(self, context: Context, event: dict[str, Any]):
         """
@@ -863,6 +867,31 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
         )
         return event["pipeline_id"]
 
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        """Build and return OpenLineage facets and datasets for the completed pipeline start."""
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.facets import DataFusionRunFacet
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        pipeline_resource = f"{self.project_id}:{self.location}:{self.instance_name}:{self.pipeline_name}"
+
+        inputs = [Dataset(namespace="datafusion", name=pipeline_resource)]
+
+        if self.pipeline_id:
+            output_name = f"{pipeline_resource}:{self.pipeline_id}"
+        else:
+            output_name = f"{pipeline_resource}:unknown"
+        outputs = [Dataset(namespace="datafusion", name=output_name)]
+
+        run_facets = {
+            "dataFusionRun": DataFusionRunFacet(
+                runId=self.pipeline_id,
+                runtimeArgs=self.runtime_args,
+            )
+        }
+
+        return OperatorLineage(inputs=inputs, outputs=outputs, run_facets=run_facets, job_facets={})
+
 
 class CloudDataFusionStopPipelineOperator(GoogleCloudBaseOperator):
     """
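
To make the dataset naming scheme above concrete, here is a standalone sketch with hypothetical identifiers (my-project, us-west1, and the rest are illustrative only, not values from this PR):

    # Hypothetical identifiers, for illustration only.
    project_id, location, instance_name = "my-project", "us-west1", "airflow"
    pipeline_name, pipeline_id = "test_pipeline", "abc123"  # pipeline_id=None -> ":unknown" suffix

    pipeline_resource = f"{project_id}:{location}:{instance_name}:{pipeline_name}"
    output_name = f"{pipeline_resource}:{pipeline_id or 'unknown'}"

    print(pipeline_resource)  # input dataset name:  my-project:us-west1:airflow:test_pipeline
    print(output_name)        # output dataset name: my-project:us-west1:airflow:test_pipeline:abc123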

@@ -20,6 +20,7 @@
     BigQueryJobRunFacet,
     CloudStorageTransferJobFacet,
     CloudStorageTransferRunFacet,
+    DataFusionRunFacet,
 )
 
 
@@ -80,3 +81,10 @@ def test_cloud_storage_transfer_run_facet():
     assert facet.timeout == 3600
     assert facet.deferrable is False
     assert facet.deleteJobAfterCompletion is True
+
+
+def test_datafusion_run_facet():
+    facet = DataFusionRunFacet(runId="abc123", runtimeArgs={"arg1": "val1"})
+
+    assert facet.runId == "abc123"
+    assert facet.runtimeArgs == {"arg1": "val1"}

@@ -23,6 +23,7 @@
 from airflow import DAG
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, PipelineStates
+from airflow.providers.google.cloud.openlineage.facets import DataFusionRunFacet
 from airflow.providers.google.cloud.operators.datafusion import (
     CloudDataFusionCreateInstanceOperator,
     CloudDataFusionCreatePipelineOperator,
@@ -412,6 +413,65 @@ def test_execute_check_hook_call_asynch_param_should_execute_successfully(self,
         ):
             op.execute(context=mock.MagicMock())
 
+    @pytest.mark.parametrize(
+        "pipeline_id, runtime_args, expected_run_id, expected_runtime_args, expected_output_suffix",
+        [
+            ("abc123", {"arg1": "val1"}, "abc123", {"arg1": "val1"}, "abc123"),
+            (None, None, None, None, "unknown"),
+        ],
+    )
+    @mock.patch("airflow.providers.google.cloud.operators.datafusion.DataFusionPipelineLink.persist")
+    @mock.patch(HOOK_STR)
+    def test_openlineage_facets_with_mock(
+        self,
+        mock_hook,
+        mock_persist,
+        pipeline_id,
+        runtime_args,
+        expected_run_id,
+        expected_runtime_args,
+        expected_output_suffix,
+    ):
+        mock_persist.return_value = None
+
+        mock_instance = {"apiEndpoint": "https://mock-endpoint", "serviceEndpoint": "https://mock-service"}
+        mock_hook.return_value.get_instance.return_value = mock_instance
+        mock_hook.return_value.start_pipeline.return_value = pipeline_id
+
+        op = CloudDataFusionStartPipelineOperator(
+            task_id=TASK_ID,
+            pipeline_name=PIPELINE_NAME,
+            instance_name=INSTANCE_NAME,
+            namespace=NAMESPACE,
+            location=LOCATION,
+            project_id=PROJECT_ID,
+            runtime_args=runtime_args,
+        )
+
+        result_pipeline_id = op.execute(context={})
+        results = op.get_openlineage_facets_on_complete(task_instance=None)
+
+        assert result_pipeline_id == pipeline_id
+        assert op.pipeline_id == pipeline_id
+
+        expected_input_name = f"{PROJECT_ID}:{LOCATION}:{INSTANCE_NAME}:{PIPELINE_NAME}"
+
+        assert results is not None
+        assert len(results.inputs) == 1
+        assert results.inputs[0].namespace == "datafusion"
+        assert results.inputs[0].name == expected_input_name
+
+        assert len(results.outputs) == 1
+        assert results.outputs[0].namespace == "datafusion"
+        assert results.outputs[0].name == f"{expected_input_name}:{expected_output_suffix}"
+
+        facet = results.run_facets["dataFusionRun"]
+        assert isinstance(facet, DataFusionRunFacet)
+        assert facet.runId == expected_run_id
+        assert facet.runtimeArgs == expected_runtime_args
+
+        assert results.job_facets == {}
+
+
 class TestCloudDataFusionStopPipelineOperator:
     @mock.patch(HOOK_STR)