Commit
feat: add OpenLineage support for BigQueryToBigQueryOperator
Signed-off-by: Kacper Muda <mudakacper@gmail.com>
kacpermuda committed Nov 20, 2024
1 parent aa7a3b2 commit eca79ba
Showing 8 changed files with 594 additions and 104 deletions.
75 changes: 54 additions & 21 deletions providers/src/airflow/providers/google/cloud/openlineage/utils.py
@@ -27,6 +27,7 @@
     from airflow.providers.common.compat.openlineage.facet import Dataset

 from airflow.providers.common.compat.openlineage.facet import (
+    BaseFacet,
     ColumnLineageDatasetFacet,
     DocumentationDatasetFacet,
     Fields,
@@ -41,50 +42,82 @@
 BIGQUERY_URI = "bigquery"


-def get_facets_from_bq_table(table: Table) -> dict[Any, Any]:
+def get_facets_from_bq_table(table: Table) -> dict[str, BaseFacet]:
     """Get facets from BigQuery table object."""
-    facets = {
-        "schema": SchemaDatasetFacet(
+    facets: dict[str, BaseFacet] = {}
+    if table.schema:
+        facets["schema"] = SchemaDatasetFacet(
             fields=[
                 SchemaDatasetFacetFields(
-                    name=field.name, type=field.field_type, description=field.description
+                    name=schema_field.name, type=schema_field.field_type, description=schema_field.description
                 )
-                for field in table.schema
+                for schema_field in table.schema
             ]
-        ),
-        "documentation": DocumentationDatasetFacet(description=table.description or ""),
-    }
+        )
+    if table.description:
+        facets["documentation"] = DocumentationDatasetFacet(description=table.description)

     return facets
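
For illustration, a rough sketch of what the updated get_facets_from_bq_table returns for a hypothetical table that has a schema but no description (the project, dataset, table, and column names below are made up, and this assumes google-cloud-bigquery plus the Google provider are importable):

from google.cloud import bigquery

from airflow.providers.google.cloud.openlineage.utils import get_facets_from_bq_table

# Hypothetical table: two columns, no table-level description.
table = bigquery.Table(
    "my-project.my_dataset.source_table",
    schema=[
        bigquery.SchemaField("id", "INTEGER"),
        bigquery.SchemaField("name", "STRING"),
    ],
)

facets = get_facets_from_bq_table(table)
# Only the "schema" facet is built; "documentation" is skipped because table.description is empty.
assert set(facets) == {"schema"}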


 def get_identity_column_lineage_facet(
-    field_names: list[str],
+    dest_field_names: list[str],
     input_datasets: list[Dataset],
-) -> ColumnLineageDatasetFacet:
+) -> dict[str, ColumnLineageDatasetFacet]:
     """
-    Get column lineage facet.
-
-    Simple lineage will be created, where each source column corresponds to single destination column
-    in each input dataset and there are no transformations made.
+    Get column lineage facet for identity transformations.
+
+    This function generates a simple column lineage facet, where each destination column
+    consists of source columns of the same name from all input datasets that have that column.
+    The lineage assumes there are no transformations applied, meaning the columns retain their
+    identity between the source and destination datasets.
+
+    Args:
+        dest_field_names: A list of destination column names for which lineage should be determined.
+        input_datasets: A list of input datasets with schema facets.
+
+    Returns:
+        A dictionary containing a single key, "columnLineage", mapped to a `ColumnLineageDatasetFacet`.
+        If no lineage can be determined, an empty dictionary is returned.
+
+    Notes:
+        - If any input dataset lacks a schema facet, the function immediately returns an empty dictionary.
+        - If any field in the source dataset's schema is not present in the destination table,
+          the function returns an empty dictionary. The destination table can contain extra fields, but all
+          source columns should be present in the destination table.
+        - If none of the destination columns can be matched to input dataset columns, an empty
+          dictionary is returned.
+        - Extra columns in the destination table that do not exist in the input datasets are ignored and
+          skipped in the lineage facet, as they cannot be traced back to a source column.
+        - The function assumes there are no transformations applied, meaning the columns retain their
+          identity between the source and destination datasets.
     """
-    if field_names and not input_datasets:
-        raise ValueError("When providing `field_names` You must provide at least one `input_dataset`.")
+    fields_sources: dict[str, list[Dataset]] = {}
+    for ds in input_datasets:
+        if not ds.facets or "schema" not in ds.facets:
+            return {}
+        for schema_field in ds.facets["schema"].fields:  # type: ignore[attr-defined]
+            if schema_field.name not in dest_field_names:
+                return {}
+            fields_sources[schema_field.name] = fields_sources.get(schema_field.name, []) + [ds]
+
+    if not fields_sources:
+        return {}

     column_lineage_facet = ColumnLineageDatasetFacet(
         fields={
-            field: Fields(
+            field_name: Fields(
                 inputFields=[
-                    InputField(namespace=dataset.namespace, name=dataset.name, field=field)
-                    for dataset in input_datasets
+                    InputField(namespace=dataset.namespace, name=dataset.name, field=field_name)
+                    for dataset in datasets
                 ],
                 transformationType="IDENTITY",
                 transformationDescription="identical",
             )
-            for field in field_names
+            for field_name, datasets in fields_sources.items()
         }
     )
-    return column_lineage_facet
+    return {"columnLineage": column_lineage_facet}


 @define
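
To make the new contract concrete, here is a minimal sketch of how the reworked get_identity_column_lineage_facet is expected to behave; the dataset names and columns below are hypothetical:

from airflow.providers.common.compat.openlineage.facet import (
    Dataset,
    SchemaDatasetFacet,
    SchemaDatasetFacetFields,
)
from airflow.providers.google.cloud.openlineage.utils import get_identity_column_lineage_facet

# Two hypothetical source tables exposing the same two columns via a schema facet.
schema = SchemaDatasetFacet(
    fields=[
        SchemaDatasetFacetFields(name="id", type="INTEGER"),
        SchemaDatasetFacetFields(name="name", type="STRING"),
    ]
)
sources = [
    Dataset(namespace="bigquery", name="my-project.my_dataset.table_a", facets={"schema": schema}),
    Dataset(namespace="bigquery", name="my-project.my_dataset.table_b", facets={"schema": schema}),
]

# The destination has an extra "created_at" column; it cannot be traced back to a source
# column, so it is simply absent from the resulting lineage fields.
lineage = get_identity_column_lineage_facet(
    dest_field_names=["id", "name", "created_at"],
    input_datasets=sources,
)

# Each destination column points at the same-named column in both source datasets.
assert sorted(lineage["columnLineage"].fields) == ["id", "name"]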
@@ -110,6 +110,7 @@ def __init__(
         self.location = location
         self.impersonation_chain = impersonation_chain
         self.hook: BigQueryHook | None = None
+        self._job_conf: dict = {}

     def _prepare_job_configuration(self):
         self.source_project_dataset_tables = (

@@ -154,39 +155,94 @@ def _prepare_job_configuration(self):

         return configuration

-    def _submit_job(
-        self,
-        hook: BigQueryHook,
-        configuration: dict,
-    ) -> str:
-        job = hook.insert_job(configuration=configuration, project_id=hook.project_id)
-        return job.job_id
-
     def execute(self, context: Context) -> None:
         self.log.info(
             "Executing copy of %s into: %s",
             self.source_project_dataset_tables,
             self.destination_project_dataset_table,
         )
-        hook = BigQueryHook(
+        self.hook = BigQueryHook(
             gcp_conn_id=self.gcp_conn_id,
             location=self.location,
             impersonation_chain=self.impersonation_chain,
         )
-        self.hook = hook

-        if not hook.project_id:
+        if not self.hook.project_id:
             raise ValueError("The project_id should be set")

         configuration = self._prepare_job_configuration()
-        job_id = self._submit_job(hook=hook, configuration=configuration)
+        self._job_conf = self.hook.insert_job(
+            configuration=configuration, project_id=self.hook.project_id
+        ).to_api_repr()

-        job = hook.get_job(job_id=job_id, location=self.location).to_api_repr()
-        conf = job["configuration"]["copy"]["destinationTable"]
+        dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"]
         BigQueryTableLink.persist(
             context=context,
             task_instance=self,
-            dataset_id=conf["datasetId"],
-            project_id=conf["projectId"],
-            table_id=conf["tableId"],
+            dataset_id=dest_table_info["datasetId"],
+            project_id=dest_table_info["projectId"],
+            table_id=dest_table_info["tableId"],
         )

+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Implement on_complete as we will include final BQ job id."""
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
+            ExternalQueryRunFacet,
+        )
+        from airflow.providers.google.cloud.openlineage.utils import (
+            BIGQUERY_NAMESPACE,
+            get_facets_from_bq_table,
+            get_identity_column_lineage_facet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        if not self.hook:
+            self.hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        if not self._job_conf:
+            self.log.debug("OpenLineage could not find BQ job configuration.")
+            return OperatorLineage()
+
+        bq_job_id = self._job_conf["jobReference"]["jobId"]
+        source_tables_info = self._job_conf["configuration"]["copy"]["sourceTables"]
+        dest_table_info = self._job_conf["configuration"]["copy"]["destinationTable"]
+
+        run_facets = {
+            "externalQuery": ExternalQueryRunFacet(externalQueryId=bq_job_id, source="bigquery"),
+        }
+
+        input_datasets = []
+        for in_table_info in source_tables_info:
+            table_id = ".".join(
+                (in_table_info["projectId"], in_table_info["datasetId"], in_table_info["tableId"])
+            )
+            table_object = self.hook.get_client().get_table(table_id)
+            input_datasets.append(
+                Dataset(
+                    namespace=BIGQUERY_NAMESPACE, name=table_id, facets=get_facets_from_bq_table(table_object)
+                )
+            )
+
+        out_table_id = ".".join(
+            (dest_table_info["projectId"], dest_table_info["datasetId"], dest_table_info["tableId"])
+        )
+        out_table_object = self.hook.get_client().get_table(out_table_id)
+        output_dataset_facets = {
+            **get_facets_from_bq_table(out_table_object),
+            **get_identity_column_lineage_facet(
+                dest_field_names=[field.name for field in out_table_object.schema],
+                input_datasets=input_datasets,
+            ),
+        }
+        output_dataset = Dataset(
+            namespace=BIGQUERY_NAMESPACE,
+            name=out_table_id,
+            facets=output_dataset_facets,
+        )
+
+        return OperatorLineage(inputs=input_datasets, outputs=[output_dataset], run_facets=run_facets)
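
For reference, a rough sketch of the copy-job API representation that the new get_openlineage_facets_on_complete reads from, i.e. self._job_conf as returned by insert_job(...).to_api_repr(); only the keys used above are shown and every value is a placeholder:

# Hypothetical shape of self._job_conf for a copy job.
job_conf = {
    "jobReference": {"projectId": "my-project", "jobId": "job_abc123", "location": "US"},
    "configuration": {
        "copy": {
            "sourceTables": [
                {"projectId": "my-project", "datasetId": "src_dataset", "tableId": "src_table"},
            ],
            "destinationTable": {
                "projectId": "my-project",
                "datasetId": "dst_dataset",
                "tableId": "dst_table",
            },
        }
    },
}

bq_job_id = job_conf["jobReference"]["jobId"]  # -> "job_abc123"
source_tables_info = job_conf["configuration"]["copy"]["sourceTables"]  # one entry per source table
dest_table_info = job_conf["configuration"]["copy"]["destinationTable"]  # single destination table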
@@ -294,6 +294,7 @@ def get_openlineage_facets_on_complete(self, task_instance):
         from pathlib import Path

         from airflow.providers.common.compat.openlineage.facet import (
+            BaseFacet,
             Dataset,
             ExternalQueryRunFacet,
             Identifier,
@@ -322,12 +323,12 @@ def get_openlineage_facets_on_complete(self, task_instance):
             facets=get_facets_from_bq_table(table_object),
         )

-        output_dataset_facets = {
-            "schema": input_dataset.facets["schema"],
-            "columnLineage": get_identity_column_lineage_facet(
-                field_names=[field.name for field in table_object.schema], input_datasets=[input_dataset]
-            ),
-        }
+        output_dataset_facets: dict[str, BaseFacet] = get_identity_column_lineage_facet(
+            dest_field_names=[field.name for field in table_object.schema], input_datasets=[input_dataset]
+        )
+        if "schema" in input_dataset.facets:
+            output_dataset_facets["schema"] = input_dataset.facets["schema"]
+
         output_datasets = []
         for uri in sorted(self.destination_cloud_storage_uris):
             bucket, blob = _parse_gcs_url(uri)
@@ -784,9 +784,10 @@ def get_openlineage_facets_on_complete(self, task_instance):
         source_objects = (
             self.source_objects if isinstance(self.source_objects, list) else [self.source_objects]
         )
-        input_dataset_facets = {
-            "schema": output_dataset_facets["schema"],
-        }
+        input_dataset_facets = {}
+        if "schema" in output_dataset_facets:
+            input_dataset_facets["schema"] = output_dataset_facets["schema"]
+
         input_datasets = []
         for blob in sorted(source_objects):
             additional_facets = {}

@@ -811,14 +812,16 @@ def get_openlineage_facets_on_complete(self, task_instance):
             )
             input_datasets.append(dataset)

-        output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet(
-            field_names=[field.name for field in table_object.schema], input_datasets=input_datasets
-        )
-
         output_dataset = Dataset(
             namespace="bigquery",
             name=str(table_object.reference),
-            facets=output_dataset_facets,
+            facets={
+                **output_dataset_facets,
+                **get_identity_column_lineage_facet(
+                    dest_field_names=[field.name for field in table_object.schema],
+                    input_datasets=input_datasets,
+                ),
+            },
         )

         run_facets = {}