Merged
airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -25,13 +25,11 @@
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.utils.context import Context


@@ -112,67 +110,3 @@ def persist_links(self, context: Context) -> None:
project_id=project_id,
table_id=table_id,
)

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.google.cloud.openlineage.utils import (
BIGQUERY_NAMESPACE,
get_facets_from_bq_table_for_given_fields,
get_identity_column_lineage_facet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

if not self.bigquery_hook:
self.bigquery_hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
)

try:
table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
except Exception:
self.log.debug(
"OpenLineage: could not fetch BigQuery table %s",
self.source_project_dataset_table,
exc_info=True,
)
return OperatorLineage()

if self.selected_fields:
    if isinstance(self.selected_fields, str):
        # selected_fields is a comma-separated string; split it into field
        # names rather than iterating over individual characters.
        bigquery_field_names = [field.strip() for field in self.selected_fields.split(",")]
    else:
        bigquery_field_names = self.selected_fields
else:
    bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]

input_dataset = Dataset(
namespace=BIGQUERY_NAMESPACE,
name=self.source_project_dataset_table,
facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
)

db_info = self.mssql_hook.get_openlineage_database_info(self.mssql_hook.get_conn())
default_schema = self.mssql_hook.get_openlineage_default_schema()
namespace = f"{db_info.scheme}://{db_info.authority}"

if self.target_table_name and "." in self.target_table_name:
schema_name, table_name = self.target_table_name.split(".", 1)
else:
schema_name = default_schema or ""
table_name = self.target_table_name or ""

if self.database:
output_name = f"{self.database}.{schema_name}.{table_name}"
else:
output_name = f"{schema_name}.{table_name}"

column_lineage_facet = get_identity_column_lineage_facet(
bigquery_field_names, input_datasets=[input_dataset]
)

output_facets = column_lineage_facet or {}
output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
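For orientation, here is a minimal, self-contained sketch of the names this removed MSSQL implementation computed. All concrete values (scheme, authority, database, table) are hypothetical and only illustrate the shape of the output:

    # Hypothetical connection details, not taken from this PR.
    scheme, authority = "mssql", "myserver:1433"      # from get_openlineage_database_info()
    database, target_table_name = "mydb", "dbo.orders"
    default_schema = "dbo"                            # from get_openlineage_default_schema()

    namespace = f"{scheme}://{authority}"             # -> "mssql://myserver:1433"
    if "." in target_table_name:
        schema_name, table_name = target_table_name.split(".", 1)
    else:
        schema_name, table_name = default_schema, target_table_name
    output_name = f"{database}.{schema_name}.{table_name}"  # -> "mydb.dbo.orders"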
airflow/providers/google/cloud/transfers/bigquery_to_mysql.py
@@ -22,16 +22,11 @@
import warnings
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage


class BigQueryToMySqlOperator(BigQueryToSqlBaseOperator):
"""
@@ -94,57 +89,3 @@ def execute(self, context):
project_id = self.bigquery_hook.project_id
self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"
return super().execute(context)

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.google.cloud.openlineage.utils import (
BIGQUERY_NAMESPACE,
get_facets_from_bq_table_for_given_fields,
get_identity_column_lineage_facet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

if not self.bigquery_hook:
self.bigquery_hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
)

try:
table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
except Exception:
self.log.debug(
"OpenLineage: could not fetch BigQuery table %s",
self.source_project_dataset_table,
exc_info=True,
)
return OperatorLineage()

if self.selected_fields:
    if isinstance(self.selected_fields, str):
        # selected_fields is a comma-separated string; split it into field
        # names rather than iterating over individual characters.
        bigquery_field_names = [field.strip() for field in self.selected_fields.split(",")]
    else:
        bigquery_field_names = self.selected_fields
else:
    bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]

input_dataset = Dataset(
namespace=BIGQUERY_NAMESPACE,
name=self.source_project_dataset_table,
facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
)

db_info = self.mysql_hook.get_openlineage_database_info(self.mysql_hook.get_conn())
namespace = f"{db_info.scheme}://{db_info.authority}"

output_name = f"{self.database}.{self.target_table_name}"

column_lineage_facet = get_identity_column_lineage_facet(
bigquery_field_names, input_datasets=[input_dataset]
)

output_facets = column_lineage_facet or {}
output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
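The MySQL variant differs only in output naming: MySQL has no schema level between database and table, so the removed code above built two-part names. A tiny sketch with hypothetical values:

    database, target_table_name = "mydb", "orders"    # hypothetical
    namespace = "mysql://myhost:3306"                 # scheme://authority, hypothetical
    output_name = f"{database}.{target_table_name}"   # -> "mydb.orders"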
airflow/providers/google/cloud/transfers/bigquery_to_postgres.py
@@ -19,6 +19,7 @@

from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING

from psycopg2.extensions import register_adapter
@@ -78,28 +79,36 @@ def __init__(
self.postgres_conn_id = postgres_conn_id
self.replace_index = replace_index

def get_sql_hook(self) -> PostgresHook:
@cached_property
def postgres_hook(self) -> PostgresHook:
register_adapter(list, Json)
register_adapter(dict, Json)
return PostgresHook(database=self.database, postgres_conn_id=self.postgres_conn_id)

def get_sql_hook(self) -> PostgresHook:
return self.postgres_hook

def execute(self, context: Context) -> None:
big_query_hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
)
if not self.bigquery_hook:
self.bigquery_hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
)
# Set source_project_dataset_table here, after hooks are initialized and project_id is available
project_id = self.bigquery_hook.project_id
self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"

self.persist_links(context)
sql_hook: PostgresHook = self.get_sql_hook()
for rows in bigquery_get_data(
self.log,
self.dataset_id,
self.table_id,
big_query_hook,
self.bigquery_hook,
self.batch_size,
self.selected_fields,
):
sql_hook.insert_rows(
self.postgres_hook.insert_rows(
table=self.target_table_name,
rows=rows,
target_fields=self.selected_fields,
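The register_adapter calls in the postgres_hook property above teach psycopg2 to serialize Python lists and dicts as JSON query parameters, which is what lets rows containing such values land in json/jsonb columns. A standalone sketch of the same pattern (the table and connection are hypothetical):

    from psycopg2.extensions import register_adapter
    from psycopg2.extras import Json

    # After registration, psycopg2 wraps lists/dicts in Json(...) automatically
    # whenever they appear as query parameters.
    register_adapter(list, Json)
    register_adapter(dict, Json)

    # with conn.cursor() as cur:  # hypothetical open connection
    #     cur.execute("INSERT INTO events (payload) VALUES (%s)", ({"a": [1, 2]},))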
airflow/providers/google/cloud/transfers/bigquery_to_sql.py
@@ -30,6 +30,7 @@

if TYPE_CHECKING:
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.utils.context import Context


@@ -140,3 +141,97 @@ def execute(self, context: Context) -> None:
replace=self.replace,
commit_every=self.batch_size,
)

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
"""
Build generic OpenLineage facets for BigQuery -> SQL transfers.

This consolidates nearly identical implementations from child
operators. Children still provide a concrete SQL hook via
``get_sql_hook()`` and may override behavior if needed.
"""
from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.google.cloud.openlineage.utils import (
BIGQUERY_NAMESPACE,
get_facets_from_bq_table_for_given_fields,
get_identity_column_lineage_facet,
)
from airflow.providers.openlineage.extractors import OperatorLineage

if not self.bigquery_hook:
self.bigquery_hook = BigQueryHook(
gcp_conn_id=self.gcp_conn_id,
location=self.location,
impersonation_chain=self.impersonation_chain,
)

try:
if not getattr(self, "source_project_dataset_table", None):
project_id = self.bigquery_hook.project_id
self.source_project_dataset_table = f"{project_id}.{self.dataset_id}.{self.table_id}"

table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
except Exception:
self.log.debug(
"OpenLineage: could not fetch BigQuery table %s",
getattr(self, "source_project_dataset_table", None),
exc_info=True,
)
return OperatorLineage()

if self.selected_fields:
    if isinstance(self.selected_fields, str):
        # selected_fields is a comma-separated string; split it into field
        # names rather than iterating over individual characters.
        bigquery_field_names = [field.strip() for field in self.selected_fields.split(",")]
    else:
        bigquery_field_names = self.selected_fields
else:
    bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]

input_dataset = Dataset(
namespace=BIGQUERY_NAMESPACE,
name=self.source_project_dataset_table,
facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
)

sql_hook = self.get_sql_hook()
db_info = sql_hook.get_openlineage_database_info(sql_hook.get_conn())
if db_info is None:
self.log.debug("OpenLineage: could not get database info from SQL hook %s", type(sql_hook))
return OperatorLineage()
namespace = f"{db_info.scheme}://{db_info.authority}"

schema_name = None
if hasattr(sql_hook, "get_openlineage_default_schema"):
try:
schema_name = sql_hook.get_openlineage_default_schema()
except Exception:
schema_name = None

if self.target_table_name and "." in self.target_table_name:
schema_part, table_part = self.target_table_name.split(".", 1)
else:
schema_part = schema_name or ""
table_part = self.target_table_name or ""

if db_info.scheme == "mysql":
    # MySQL has no separate schema level, so datasets are named database.table.
    output_name = f"{self.database}.{table_part}" if self.database else table_part
else:
    # Join the non-empty parts: database.schema.table, schema.table, or table.
    output_name = ".".join(part for part in (self.database, schema_part, table_part) if part)

column_lineage_facet = get_identity_column_lineage_facet(
bigquery_field_names, input_datasets=[input_dataset]
)

output_facets = column_lineage_facet or {}
output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)

return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
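With the lineage logic consolidated here, a new child operator only has to supply its SQL hook. A minimal hypothetical subclass, assuming SqliteHook as the target (this operator is not part of the PR; it only illustrates the contract):

    from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
    from airflow.providers.sqlite.hooks.sqlite import SqliteHook

    class BigQueryToSqliteOperator(BigQueryToSqlBaseOperator):
        """Hypothetical transfer operator; lineage comes from the base class."""

        def get_sql_hook(self) -> SqliteHook:
            # Called by execute() and by get_openlineage_facets_on_complete().
            return SqliteHook(sqlite_conn_id="sqlite_default")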
test_bigquery_to_mssql.py
@@ -106,7 +106,7 @@ def test_persist_links(self, mock_link):
)

@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
def test_get_openlineage_facets_on_complete_no_selected_fields(self, mock_bq_hook, mock_mssql_hook):
mock_bq_client = MagicMock()
table_obj = _make_bq_table(["id", "name", "value"])
@@ -152,7 +152,7 @@ def test_get_openlineage_facets_on_complete_no_selected_fields(self, mock_bq_hook, mock_mssql_hook):
assert set(col_lineage.fields.keys()) == {"id", "name", "value"}

@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
def test_get_openlineage_facets_on_complete_selected_fields(self, mock_bq_hook, mock_mssql_hook):
mock_bq_client = MagicMock()
table_obj = _make_bq_table(["id", "name", "value"])
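The patch target changes because the consolidated method now instantiates BigQueryHook inside bigquery_to_sql, and mock.patch must replace the name in the module where it is looked up, not where it is defined. A minimal sketch of the idiom (the project id is a hypothetical value):

    from unittest import mock

    # Patch the name in the module that uses it...
    with mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook") as bq:
        bq.return_value.project_id = "test-project"  # hypothetical
        # ...then construct the operator and call get_openlineage_facets_on_complete().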