feat: Add CLL to OpenLineage in BigQueryInsertJobOperator
Signed-off-by: Kacper Muda <mudakacper@gmail.com>
kacpermuda committed Dec 12, 2024
1 parent 67cff15 commit 79c6ad4
Showing 1 changed file with 238 additions and 17 deletions.
255 changes: 238 additions & 17 deletions providers/src/airflow/providers/google/cloud/openlineage/mixins.py
@@ -19,11 +19,13 @@

import copy
import json
import logging
import traceback
from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    from airflow.providers.common.compat.openlineage.facet import (
        ColumnLineageDatasetFacet,
        Dataset,
        InputDataset,
        OutputDataset,
@@ -36,8 +38,16 @@

BIGQUERY_NAMESPACE = "bigquery"

log = logging.getLogger(__name__)


class _BigQueryOpenLineageMixin:
    @property
    def _safe_log(self) -> logging.Logger:
        if hasattr(self, "log"):
            return self.log
        return log

    def get_openlineage_facets_on_complete(self, _):
        """
        Retrieve OpenLineage data for a COMPLETE BigQuery job.

@@ -70,8 +80,7 @@ def get_openlineage_facets_on_complete(self, _):
        from airflow.providers.openlineage.sqlparser import SQLParser

        if not self.job_id:
            self._safe_log.warning("No BigQuery job_id was found by OpenLineage.")
            return OperatorLineage()

        if not self.hook:

@@ -113,8 +122,7 @@ def get_facets(self, job_id: str):
        inputs = []
        outputs = []
        run_facets: dict[str, RunFacet] = {}
        self._safe_log.debug("Extracting data from bigquery job: `%s`", job_id)
        try:
            job = self.client.get_job(job_id=job_id)  # type: ignore
            props = job._properties

@@ -125,8 +133,7 @@
            run_facets["bigQueryJob"] = self._get_bigquery_job_run_facet(props)

            if get_from_nullable_chain(props, ["statistics", "numChildJobs"]):
                self._safe_log.debug("Found SCRIPT job. Extracting lineage from child jobs instead.")
                # SCRIPT job type has no input / output information but spawns child jobs that have one
                # https://cloud.google.com/bigquery/docs/information-schema-jobs#multi-statement_query_job
                for child_job_id in self.client.list_jobs(parent_job=job_id):

@@ -138,8 +145,7 @@
                inputs, _output = self._get_inputs_outputs_from_job(props)
                outputs.append(_output)
        except Exception as e:
            self._safe_log.warning("Cannot retrieve job details from BigQuery.Client. %s", e, exc_info=True)
            exception_msg = traceback.format_exc()
            run_facets.update(
                {

@@ -178,14 +184,22 @@ def _get_inputs_outputs_from_job(

        input_tables = get_from_nullable_chain(properties, ["statistics", "query", "referencedTables"]) or []
        output_table = get_from_nullable_chain(properties, ["configuration", "query", "destinationTable"])
        inputs = [
            self._get_input_dataset(input_table)
            for input_table in input_tables
            if input_table != output_table  # Output table is in `referencedTables` and needs to be removed
        ]

        if not output_table:
            return inputs, None

        output = self._get_output_dataset(output_table)
        if dataset_stat_facet := self._get_statistics_dataset_facet(properties):
            output.outputFacets = output.outputFacets or {}
            output.outputFacets["outputStatistics"] = dataset_stat_facet
        if cll_facet := self._get_column_level_lineage_facet(properties, output, inputs):
            output.facets = output.facets or {}
            output.facets["columnLineage"] = cll_facet
        return inputs, output

    @staticmethod

@@ -225,6 +239,70 @@ def _get_statistics_dataset_facet(
            return OutputStatisticsOutputDatasetFacet(rowCount=int(out_rows), size=int(out_bytes))
        return None

    def _get_column_level_lineage_facet(
        self, properties: dict, output: OutputDataset, inputs: list[InputDataset]
    ) -> ColumnLineageDatasetFacet | None:
        """
        Extract column-level lineage information from a BigQuery job and return it as a facet.

        The Column Level Lineage Facet will NOT be returned if any of the following conditions is met:
        - The parsed result does not contain column lineage information.
        - The parsed result does not contain exactly one output table.
        - The parsed result has a different output table than the output table from the BQ job.
        - The parsed result has at least one input table not present in the input tables from the BQ job.
        - The parsed result has a column not present in the schema of a given dataset from the BQ job.

        Args:
            properties: The properties of the BigQuery job.
            output: The output dataset for which the column lineage is being extracted.
            inputs: The input datasets of the BQ job, used to validate the parsed lineage.

        Returns:
            The extracted Column Lineage Dataset Facet, or None if conditions are not met.
        """
        from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain
        from airflow.providers.openlineage.sqlparser import SQLParser

        # Extract SQL query and parse it
        self._safe_log.debug("Extracting column-level lineage facet from BigQuery query.")
        query = get_from_nullable_chain(properties, ["configuration", "query", "query"]) or ""
        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string(query))

        if not parse_result or not parse_result.column_lineage:
            self._safe_log.debug("No column-level lineage found in the SQL query. Facet generation skipped.")
            return None

        default_dataset, default_project = self._extract_default_dataset_and_project(
            properties, self.project_id
        )

        # Validate parsed output table against the actual output dataset
        if not self._validate_output_table_id(
            parse_result,
            output,
            default_project,
            default_dataset,
        ):
            return None

        # Validate parsed output columns against the output dataset schema
        if not self._validate_output_columns(parse_result, output):
            return None

        input_tables_from_parse_result = self._extract_parsed_input_tables(
            parse_result, default_project, default_dataset
        )
        input_tables_from_bq = self._map_input_tables_to_columns(inputs)

        # Validate parsed input tables against the input datasets
        if not self._validate_input_tables(input_tables_from_parse_result, input_tables_from_bq):
            return None

        # Validate parsed input table columns against the input dataset schemas
        if not self._validate_input_columns(input_tables_from_parse_result, input_tables_from_bq):
            return None

        return self._generate_column_lineage_facet(parse_result, default_project, default_dataset)

    def _get_input_dataset(self, table: dict) -> InputDataset:
        from airflow.providers.common.compat.openlineage.facet import InputDataset

@@ -273,8 +351,7 @@ def _get_table_schema_safely(self, table_name: str) -> SchemaDatasetFacet | None
        try:
            return self._get_table_schema(table_name)
        except Exception as e:
            self._safe_log.warning("Could not extract output schema from bigquery. %s", e)
            return None

    def _get_table_schema(self, table: str) -> SchemaDatasetFacet | None:

@@ -303,3 +380,147 @@ def _get_table_schema(self, table: str) -> SchemaDatasetFacet | None:
                for field in fields
            ]
        )

    @staticmethod
    def _get_qualified_name_from_parse_result(table, default_project: str, default_dataset: str) -> str:
        """Get the qualified name of a table from the parse result."""
        return ".".join(
            (
                table.database or default_project,
                table.schema or default_dataset,
                table.name,
            )
        )

    @staticmethod
    def _extract_default_dataset_and_project(properties: dict, default_project: str) -> tuple[str, str]:
        """Extract the default dataset and project from the BigQuery job properties."""
        from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain

        default_dataset_obj = get_from_nullable_chain(
            properties, ["configuration", "query", "defaultDataset"]
        )
        default_dataset = default_dataset_obj.get("datasetId") if default_dataset_obj else ""
        default_project = default_dataset_obj.get("projectId") if default_dataset_obj else default_project
        return default_dataset, default_project

    def _validate_output_table_id(
        self, parse_result, output: OutputDataset, default_project: str, default_dataset: str
    ) -> bool:
        """Check if the output table id from the parse result matches the BQ job output table."""
        if len(parse_result.out_tables) != 1:
            self._safe_log.debug(
                "Invalid output tables in the parse result: `%s`. Expected exactly one output table.",
                parse_result.out_tables,
            )
            return False

        parsed_output_table = self._get_qualified_name_from_parse_result(
            parse_result.out_tables[0], default_project, default_dataset
        )
        if parsed_output_table != output.name:
            self._safe_log.debug(
                "Mismatch between parsed output table `%s` and BQ job output table `%s`.",
                parsed_output_table,
                output.name,
            )
            return False
        return True

    @staticmethod
    def _extract_column_names(dataset: Dataset) -> list[str]:
        """Extract column names from a dataset's schema."""
        from openlineage.client.facet_v2 import schema_dataset

        return [
            f.name for f in dataset.facets.get("schema", schema_dataset.SchemaDatasetFacet(fields=[])).fields
        ]

    def _validate_output_columns(self, parse_result, output: OutputDataset) -> bool:
        """Validate if all descendant columns in the parse result exist in the output dataset schema."""
        output_column_names = self._extract_column_names(output)
        missing_columns = [
            lineage.descendant.name
            for lineage in parse_result.column_lineage
            if lineage.descendant.name not in output_column_names
        ]
        if missing_columns:
            self._safe_log.debug(
                "Output dataset schema is missing columns from the parse result: `%s`.", missing_columns
            )
            return False
        return True

    def _extract_parsed_input_tables(
        self, parse_result, default_project: str, default_dataset: str | None
    ) -> dict[str, list[str]]:
        """Extract input tables and their columns from the parse result."""
        input_tables: dict[str, list[str]] = {}
        for lineage in parse_result.column_lineage:
            for column_meta in lineage.lineage:
                if not column_meta.origin:
                    self._safe_log.debug(
                        "Column `%s` lacks origin information. Skipping facet generation.", column_meta.name
                    )
                    return {}

                input_table_id = self._get_qualified_name_from_parse_result(
                    column_meta.origin, default_project, default_dataset
                )
                input_tables.setdefault(input_table_id, []).append(column_meta.name)
        return input_tables

    def _map_input_tables_to_columns(self, inputs: list[InputDataset]) -> dict[str, list[str]]:
        """Map input tables to their columns from the BQ job."""
        return {input_ds.name: self._extract_column_names(input_ds) for input_ds in inputs}

    def _validate_input_tables(
        self, parsed_input_tables: dict[str, list[str]], input_tables_from_bq: dict[str, list[str]]
    ) -> bool:
        """Validate if all parsed input tables exist in the BQ job's input datasets."""
        missing_tables = [table for table in parsed_input_tables if table not in input_tables_from_bq]
        if missing_tables:
            self._safe_log.debug(
                "Parsed input tables not found in the BQ job's input datasets: `%s`.", missing_tables
            )
            return False
        return True

    def _validate_input_columns(
        self, parsed_input_tables: dict[str, list[str]], input_tables_from_bq: dict[str, list[str]]
    ) -> bool:
        """Validate if all parsed input columns exist in their respective BQ job input table schemas."""
        for table, columns in parsed_input_tables.items():
            missing_columns = [col for col in columns if col not in input_tables_from_bq.get(table, [])]
            if missing_columns:
                self._safe_log.debug(
                    "Input table `%s` is missing columns from the parse result: `%s`.", table, missing_columns
                )
                return False
        return True

    def _generate_column_lineage_facet(
        self, parse_result, default_project: str, default_dataset: str | None
    ) -> ColumnLineageDatasetFacet:
        """Generate the ColumnLineageDatasetFacet based on the parsed result."""
        from openlineage.client.facet_v2 import column_lineage_dataset

        return column_lineage_dataset.ColumnLineageDatasetFacet(
            fields={
                lineage.descendant.name: column_lineage_dataset.Fields(
                    inputFields=[
                        column_lineage_dataset.InputField(
                            namespace=BIGQUERY_NAMESPACE,
                            name=self._get_qualified_name_from_parse_result(
                                column_meta.origin, default_project, default_dataset
                            ),
                            field=column_meta.name,
                        )
                        for column_meta in lineage.lineage
                    ],
                    transformationType="",
                    transformationDescription="",
                )
                for lineage in parse_result.column_lineage
            }
        )
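
For illustration, a minimal sketch of the facet this change produces; the table and column names are hypothetical and not part of the commit. Assuming a job whose query is `INSERT INTO proj.ds.out_t SELECT a, b FROM proj.ds.in_t` passes all the validations above, the mixin would attach a columnLineage facet to the output dataset roughly equivalent to:

# Hypothetical illustration (not part of the commit); `proj.ds.in_t`,
# `proj.ds.out_t`, and the columns `a`/`b` are made-up names.
from openlineage.client.facet_v2 import column_lineage_dataset

expected_facet = column_lineage_dataset.ColumnLineageDatasetFacet(
    fields={
        # Each output column maps back to the input column(s) it derives from.
        "a": column_lineage_dataset.Fields(
            inputFields=[
                column_lineage_dataset.InputField(
                    namespace="bigquery", name="proj.ds.in_t", field="a"
                )
            ],
            transformationType="",
            transformationDescription="",
        ),
        "b": column_lineage_dataset.Fields(
            inputFields=[
                column_lineage_dataset.InputField(
                    namespace="bigquery", name="proj.ds.in_t", field="b"
                )
            ],
            transformationType="",
            transformationDescription="",
        ),
    }
)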
