airflow/providers/google/cloud/openlineage/utils.py
@@ -214,14 +214,28 @@ def extract_ds_name_from_gcs_path(path: str) -> str:

def get_facets_from_bq_table(table: Table) -> dict[str, DatasetFacet]:
    """Get facets from BigQuery table object."""
    return get_facets_from_bq_table_for_given_fields(table, selected_fields=None)


def get_facets_from_bq_table_for_given_fields(
    table: Table, selected_fields: list[str] | None
) -> dict[str, DatasetFacet]:
    """
    Get facets from BigQuery table object for selected fields only.

    If selected_fields is None, include all fields.
    """
    facets: dict[str, DatasetFacet] = {}
    selected_fields_set = set(selected_fields) if selected_fields else None

    if table.schema:
        facets["schema"] = SchemaDatasetFacet(
            fields=[
                SchemaDatasetFacetFields(
                    name=schema_field.name, type=schema_field.field_type, description=schema_field.description
                )
                for schema_field in table.schema
                if selected_fields_set is None or schema_field.name in selected_fields_set
            ]
        )
    if table.description:
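For context, a hedged usage sketch of the new helper (table name and fields are illustrative, not from this PR; assumes google-cloud-bigquery and the Google provider are installed):

from google.cloud.bigquery import SchemaField, Table

from airflow.providers.google.cloud.openlineage.utils import (
    get_facets_from_bq_table,
    get_facets_from_bq_table_for_given_fields,
)

table = Table(
    "proj.dataset.table",
    schema=[
        SchemaField("id", "INTEGER", description="primary key"),
        SchemaField("name", "STRING"),
        SchemaField("value", "FLOAT"),
    ],
)

# All three fields land in the schema facet.
all_fields = get_facets_from_bq_table(table)

# Only "id" and "name" survive the filter; unknown names are silently dropped.
subset = get_facets_from_bq_table_for_given_fields(table, selected_fields=["id", "name"])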
airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -21,14 +21,17 @@

import warnings
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook

if TYPE_CHECKING:
    from airflow.providers.openlineage.extractors import OperatorLineage
    from airflow.utils.context import Context


@@ -94,9 +97,13 @@ def __init__(
        self.mssql_conn_id = mssql_conn_id
        self.source_project_dataset_table = source_project_dataset_table

    @cached_property
    def mssql_hook(self) -> MsSqlHook:
        return MsSqlHook(schema=self.database, mssql_conn_id=self.mssql_conn_id)

    def get_sql_hook(self) -> MsSqlHook:
        return self.mssql_hook

    def persist_links(self, context: Context) -> None:
        project_id, dataset_id, table_id = self.source_project_dataset_table.split(".")
        BigQueryTableLink.persist(
@@ -105,3 +112,67 @@ def persist_links(self, context: Context) -> None:
            project_id=project_id,
            table_id=table_id,
        )

    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
        from airflow.providers.common.compat.openlineage.facet import Dataset
        from airflow.providers.google.cloud.openlineage.utils import (
            BIGQUERY_NAMESPACE,
            get_facets_from_bq_table_for_given_fields,
            get_identity_column_lineage_facet,
        )
        from airflow.providers.openlineage.extractors import OperatorLineage

        if not self.bigquery_hook:
            self.bigquery_hook = BigQueryHook(
                gcp_conn_id=self.gcp_conn_id,
                location=self.location,
                impersonation_chain=self.impersonation_chain,
            )

        try:
            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
        except Exception:
            self.log.debug(
                "OpenLineage: could not fetch BigQuery table %s",
                self.source_project_dataset_table,
                exc_info=True,
            )
            return OperatorLineage()

        if self.selected_fields:
            if isinstance(self.selected_fields, str):
                # A string value holds comma-separated column names, e.g. "id,name";
                # list() would split it into single characters.
                bigquery_field_names = self.selected_fields.split(",")
            else:
                bigquery_field_names = self.selected_fields
        else:
            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]

        input_dataset = Dataset(
            namespace=BIGQUERY_NAMESPACE,
            name=self.source_project_dataset_table,
            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
        )

        db_info = self.mssql_hook.get_openlineage_database_info(self.mssql_hook.get_conn())
        default_schema = self.mssql_hook.get_openlineage_default_schema()
        namespace = f"{db_info.scheme}://{db_info.authority}"

        if self.target_table_name and "." in self.target_table_name:
            schema_name, table_name = self.target_table_name.split(".", 1)
        else:
            schema_name = default_schema or ""
            table_name = self.target_table_name or ""

        if self.database:
            output_name = f"{self.database}.{schema_name}.{table_name}"
        else:
            output_name = f"{schema_name}.{table_name}"

        column_lineage_facet = get_identity_column_lineage_facet(
            bigquery_field_names, input_datasets=[input_dataset]
        )

        output_facets = column_lineage_facet or {}
        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)

        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
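A hedged sketch of how the operator is wired up and what the new method reports (connection ids and names are illustrative, not from this PR):

from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator

transfer = BigQueryToMsSqlOperator(
    task_id="bq_to_mssql",
    source_project_dataset_table="proj.dataset.table",
    target_table_name="dbo.destination",
    selected_fields="id,name",  # a comma-separated string or a list of column names
    mssql_conn_id="mssql_default",
    gcp_conn_id="google_cloud_default",
    database="mydb",
)

# After execute(), the OpenLineage listener calls roughly
# transfer.get_openlineage_facets_on_complete(ti), which reports an input
# dataset (namespace "bigquery", name "proj.dataset.table") and an output
# dataset (namespace "mssql://<host>:<port>", name "mydb.dbo.destination"),
# with identity column lineage for the selected fields.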
airflow/providers/google/cloud/transfers/bigquery_to_sql.py
@@ -21,6 +21,7 @@

import abc
from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING

from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
@@ -113,19 +114,22 @@ def get_sql_hook(self) -> DbApiHook:
    def persist_links(self, context: Context) -> None:
        """Persist the connection to the SQL provider."""

    @cached_property
    def bigquery_hook(self) -> BigQueryHook:
        return BigQueryHook(
            gcp_conn_id=self.gcp_conn_id,
            location=self.location,
            impersonation_chain=self.impersonation_chain,
        )

    def execute(self, context: Context) -> None:
        self.persist_links(context)
        sql_hook = self.get_sql_hook()
        for rows in bigquery_get_data(
            self.log,
            self.dataset_id,
            self.table_id,
            self.bigquery_hook,
            self.batch_size,
            self.selected_fields,
        ):
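Moving hook creation from execute() into a cached_property means one hook instance serves both execute() and get_openlineage_facets_on_complete(), and tests can swap in a mock by plain attribute assignment. A minimal standalone sketch of that behavior (illustrative, not provider code):

from functools import cached_property


class _Example:
    @cached_property
    def hook(self) -> object:
        print("creating hook")  # runs only on first access
        return object()


ex = _Example()
first = ex.hook
assert ex.hook is first  # cached: every later access returns the same object
ex.hook = "stub"  # cached_property is a non-data descriptor, so assignment just works
assert ex.hook == "stub"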
test_bigquery_to_mssql.py
@@ -18,6 +18,7 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock

from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator

@@ -28,6 +29,24 @@
TEST_PROJECT = "test-project"


def _make_bq_table(schema_names: list[str]):
    class TableObj:
        def __init__(self, schema):
            self.schema = []
            for n in schema:
                # MagicMock(name=...) would set the mock's internal name,
                # so assign the .name attribute after construction.
                field = MagicMock()
                field.name = n
                self.schema.append(field)
            self.description = "table description"
            self.external_data_configuration = None
            self.labels = {}
            self.num_rows = 0
            self.num_bytes = 0
            self.table_type = "TABLE"

    return TableObj(schema_names)


class TestBigQueryToMsSqlOperator:
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryTableLink")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
@@ -85,3 +104,94 @@ def test_persist_links(self, mock_link):
            project_id=TEST_PROJECT,
            table_id=TEST_TABLE_ID,
        )

@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
def test_get_openlineage_facets_on_complete_no_selected_fields(self, mock_bq_hook, mock_mssql_hook):
mock_bq_client = MagicMock()
table_obj = _make_bq_table(["id", "name", "value"])
mock_bq_client.get_table.return_value = table_obj
mock_bq_hook.get_client.return_value = mock_bq_client
mock_bq_hook.return_value = mock_bq_hook

db_info = MagicMock(scheme="mssql", authority="localhost:1433", database="mydb")
mock_mssql_hook.get_openlineage_database_info.return_value = db_info
mock_mssql_hook.get_openlineage_default_schema.return_value = "dbo"
mock_mssql_hook.return_value = mock_mssql_hook

op = BigQueryToMsSqlOperator(
task_id="test",
source_project_dataset_table="proj.dataset.table",
target_table_name="dbo.destination",
selected_fields=None,
database="mydb",
)
op.bigquery_hook = mock_bq_hook
op.mssql_hook = mock_mssql_hook
context = mock.MagicMock()
op.execute(context=context)

result = op.get_openlineage_facets_on_complete(task_instance=MagicMock())
assert len(result.inputs) == 1
assert len(result.outputs) == 1

input_ds = result.inputs[0]
assert input_ds.namespace == "bigquery"
assert input_ds.name == "proj.dataset.table"

assert "schema" in input_ds.facets
schema_fields = [f.name for f in input_ds.facets["schema"].fields]
assert set(schema_fields) == {"id", "name", "value"}

output_ds = result.outputs[0]
assert output_ds.namespace == "mssql://localhost:1433"
assert output_ds.name == "mydb.dbo.destination"

assert "columnLineage" in output_ds.facets
col_lineage = output_ds.facets["columnLineage"]
assert set(col_lineage.fields.keys()) == {"id", "name", "value"}

@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
@mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
def test_get_openlineage_facets_on_complete_selected_fields(self, mock_bq_hook, mock_mssql_hook):
mock_bq_client = MagicMock()
table_obj = _make_bq_table(["id", "name", "value"])
mock_bq_client.get_table.return_value = table_obj
mock_bq_hook.get_client.return_value = mock_bq_client
mock_bq_hook.return_value = mock_bq_hook

db_info = MagicMock(scheme="mssql", authority="server.example:1433", database="mydb")
mock_mssql_hook.get_openlineage_database_info.return_value = db_info
mock_mssql_hook.get_openlineage_default_schema.return_value = "dbo"
mock_mssql_hook.return_value = mock_mssql_hook

op = BigQueryToMsSqlOperator(
task_id="test",
source_project_dataset_table="proj.dataset.table",
target_table_name="dbo.destination",
selected_fields=["id", "name"],
database="mydb",
)
op.bigquery_hook = mock_bq_hook
op.mssql_hook = mock_mssql_hook
context = mock.MagicMock()
op.execute(context=context)

result = op.get_openlineage_facets_on_complete(task_instance=MagicMock())
assert len(result.inputs) == 1
assert len(result.outputs) == 1

input_ds = result.inputs[0]
assert input_ds.namespace == "bigquery"
assert "schema" in input_ds.facets

schema_fields = [f.name for f in input_ds.facets["schema"].fields]
assert set(schema_fields) == {"id", "name"}

output_ds = result.outputs[0]
assert output_ds.namespace == "mssql://server.example:1433"
assert output_ds.name == "mydb.dbo.destination"

assert "columnLineage" in output_ds.facets
col_lineage = output_ds.facets["columnLineage"]
assert set(col_lineage.fields.keys()) == {"id", "name"}
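For reference, a hedged sketch of the column-lineage shape these tests assert, built with the same helpers the operator uses (names are illustrative; the input dataset needs a schema facet for the identity mapping to find its columns):

from google.cloud.bigquery import SchemaField, Table

from airflow.providers.common.compat.openlineage.facet import Dataset
from airflow.providers.google.cloud.openlineage.utils import (
    get_facets_from_bq_table_for_given_fields,
    get_identity_column_lineage_facet,
)

table = Table(
    "proj.dataset.table",
    schema=[SchemaField("id", "INTEGER"), SchemaField("name", "STRING")],
)
source = Dataset(
    namespace="bigquery",
    name="proj.dataset.table",
    facets=get_facets_from_bq_table_for_given_fields(table, ["id", "name"]),
)

facet = get_identity_column_lineage_facet(["id", "name"], input_datasets=[source])
# One entry per destination column, each pointing back at the same-named source column.
assert set(facet["columnLineage"].fields.keys()) == {"id", "name"}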