diff --git a/providers/google/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/google/src/airflow/providers/google/cloud/openlineage/utils.py
index aabeb581d7aba..d0012f2301bde 100644
--- a/providers/google/src/airflow/providers/google/cloud/openlineage/utils.py
+++ b/providers/google/src/airflow/providers/google/cloud/openlineage/utils.py
@@ -214,7 +214,20 @@ def extract_ds_name_from_gcs_path(path: str) -> str:
 
 def get_facets_from_bq_table(table: Table) -> dict[str, DatasetFacet]:
     """Get facets from BigQuery table object."""
+    return get_facets_from_bq_table_for_given_fields(table, selected_fields=None)
+
+
+def get_facets_from_bq_table_for_given_fields(
+    table: Table, selected_fields: list[str] | None
+) -> dict[str, DatasetFacet]:
+    """
+    Get facets from BigQuery table object for selected fields only.
+
+    If selected_fields is None, include all fields.
+    """
     facets: dict[str, DatasetFacet] = {}
+    selected_fields_set = set(selected_fields) if selected_fields else None
+
     if table.schema:
         facets["schema"] = SchemaDatasetFacet(
             fields=[
@@ -222,6 +235,7 @@ def get_facets_from_bq_table(table: Table) -> dict[str, DatasetFacet]:
                     name=schema_field.name, type=schema_field.field_type, description=schema_field.description
                 )
                 for schema_field in table.schema
+                if selected_fields_set is None or schema_field.name in selected_fields_set
             ]
         )
     if table.description:
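A quick usage sketch of the new helper for reviewers (the table and column names here are illustrative, not taken from the change itself):

    from google.cloud.bigquery import SchemaField, Table

    from airflow.providers.google.cloud.openlineage.utils import get_facets_from_bq_table_for_given_fields

    table = Table(
        "proj.dataset.table",
        schema=[SchemaField("id", "STRING"), SchemaField("name", "STRING"), SchemaField("value", "INT64")],
    )

    # selected_fields=None keeps the whole schema in the facet...
    facets = get_facets_from_bq_table_for_given_fields(table, selected_fields=None)
    assert [f.name for f in facets["schema"].fields] == ["id", "name", "value"]

    # ...while passing a subset filters the schema facet down to just those columns.
    facets = get_facets_from_bq_table_for_given_fields(table, selected_fields=["id", "name"])
    assert [f.name for f in facets["schema"].fields] == ["id", "name"]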
diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
index ec63aeee5f6cd..0464fd7920749 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_mssql.py
@@ -21,14 +21,17 @@
 import warnings
 from collections.abc import Sequence
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
 from airflow.providers.google.cloud.links.bigquery import BigQueryTableLink
 from airflow.providers.google.cloud.transfers.bigquery_to_sql import BigQueryToSqlBaseOperator
 from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook
 
 if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -94,9 +97,13 @@ def __init__(
         self.mssql_conn_id = mssql_conn_id
         self.source_project_dataset_table = source_project_dataset_table
 
-    def get_sql_hook(self) -> MsSqlHook:
+    @cached_property
+    def mssql_hook(self) -> MsSqlHook:
         return MsSqlHook(schema=self.database, mssql_conn_id=self.mssql_conn_id)
 
+    def get_sql_hook(self) -> MsSqlHook:
+        return self.mssql_hook
+
     def persist_links(self, context: Context) -> None:
         project_id, dataset_id, table_id = self.source_project_dataset_table.split(".")
         BigQueryTableLink.persist(
@@ -105,3 +112,67 @@ def persist_links(self, context: Context) -> None:
             project_id=project_id,
             table_id=table_id,
         )
+
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.facet import Dataset
+        from airflow.providers.google.cloud.openlineage.utils import (
+            BIGQUERY_NAMESPACE,
+            get_facets_from_bq_table_for_given_fields,
+            get_identity_column_lineage_facet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        if not self.bigquery_hook:
+            self.bigquery_hook = BigQueryHook(
+                gcp_conn_id=self.gcp_conn_id,
+                location=self.location,
+                impersonation_chain=self.impersonation_chain,
+            )
+
+        try:
+            table_obj = self.bigquery_hook.get_client().get_table(self.source_project_dataset_table)
+        except Exception:
+            self.log.debug(
+                "OpenLineage: could not fetch BigQuery table %s",
+                self.source_project_dataset_table,
+                exc_info=True,
+            )
+            return OperatorLineage()
+
+        if self.selected_fields:
+            if isinstance(self.selected_fields, str):
+                bigquery_field_names = self.selected_fields.split(",")
+            else:
+                bigquery_field_names = self.selected_fields
+        else:
+            bigquery_field_names = [f.name for f in getattr(table_obj, "schema", [])]
+
+        input_dataset = Dataset(
+            namespace=BIGQUERY_NAMESPACE,
+            name=self.source_project_dataset_table,
+            facets=get_facets_from_bq_table_for_given_fields(table_obj, bigquery_field_names),
+        )
+
+        db_info = self.mssql_hook.get_openlineage_database_info(self.mssql_hook.get_conn())
+        default_schema = self.mssql_hook.get_openlineage_default_schema()
+        namespace = f"{db_info.scheme}://{db_info.authority}"
+
+        if self.target_table_name and "." in self.target_table_name:
+            schema_name, table_name = self.target_table_name.split(".", 1)
+        else:
+            schema_name = default_schema or ""
+            table_name = self.target_table_name or ""
+
+        if self.database:
+            output_name = f"{self.database}.{schema_name}.{table_name}"
+        else:
+            output_name = f"{schema_name}.{table_name}"
+
+        column_lineage_facet = get_identity_column_lineage_facet(
+            bigquery_field_names, input_datasets=[input_dataset]
+        )
+
+        output_facets = column_lineage_facet or {}
+        output_dataset = Dataset(namespace=namespace, name=output_name, facets=output_facets)
+
+        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
diff --git a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
index dc3ad68fb81ce..20a9f8edc3017 100644
--- a/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
+++ b/providers/google/src/airflow/providers/google/cloud/transfers/bigquery_to_sql.py
@@ -21,6 +21,7 @@
 import abc
 from collections.abc import Sequence
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
@@ -113,19 +114,22 @@ def get_sql_hook(self) -> DbApiHook:
     def persist_links(self, context: Context) -> None:
         """Persist the connection to the SQL provider."""
 
-    def execute(self, context: Context) -> None:
-        big_query_hook = BigQueryHook(
+    @cached_property
+    def bigquery_hook(self) -> BigQueryHook:
+        return BigQueryHook(
             gcp_conn_id=self.gcp_conn_id,
             location=self.location,
             impersonation_chain=self.impersonation_chain,
         )
+
+    def execute(self, context: Context) -> None:
         self.persist_links(context)
         sql_hook = self.get_sql_hook()
         for rows in bigquery_get_data(
             self.log,
             self.dataset_id,
             self.table_id,
-            big_query_hook,
+            self.bigquery_hook,
             self.batch_size,
             self.selected_fields,
         ):
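Moving hook construction into a `cached_property` means `execute()` and the OpenLineage method share a single hook instance per task, while tests can still inject a replacement: plain attribute assignment wins over a `cached_property` because both the cache and the assignment live in the instance `__dict__`. A standalone sketch of that mechanism (`Operatorish` is an illustrative stand-in, not code from this change), which is exactly how the tests below inject their mocks:

    from functools import cached_property


    class Operatorish:
        @cached_property
        def hook(self) -> object:
            # Built once on first access, then cached on the instance.
            print("building hook")
            return object()


    op = Operatorish()
    first = op.hook          # prints "building hook"
    assert op.hook is first  # cached; no second build

    test_op = Operatorish()
    test_op.hook = "mock-hook"  # assignment bypasses the property entirely
    assert test_op.hook == "mock-hook"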
diff --git a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
index fb6023e61916b..db26444c1fddd 100644
--- a/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
+++ b/providers/google/tests/unit/google/cloud/transfers/test_bigquery_to_mssql.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import MagicMock
 
 from airflow.providers.google.cloud.transfers.bigquery_to_mssql import BigQueryToMsSqlOperator
 
@@ -28,6 +29,24 @@
 TEST_PROJECT = "test-project"
 
 
+def _make_bq_table(schema_names: list[str]):
+    class TableObj:
+        def __init__(self, schema):
+            self.schema = []
+            for n in schema:
+                field = MagicMock()
+                field.name = n
+                self.schema.append(field)
+            self.description = "table description"
+            self.external_data_configuration = None
+            self.labels = {}
+            self.num_rows = 0
+            self.num_bytes = 0
+            self.table_type = "TABLE"
+
+    return TableObj(schema_names)
+
+
 class TestBigQueryToMsSqlOperator:
     @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryTableLink")
     @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_sql.BigQueryHook")
@@ -85,3 +104,94 @@ def test_persist_links(self, mock_link):
             project_id=TEST_PROJECT,
             table_id=TEST_TABLE_ID,
         )
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
+    def test_get_openlineage_facets_on_complete_no_selected_fields(self, mock_bq_hook, mock_mssql_hook):
+        mock_bq_client = MagicMock()
+        table_obj = _make_bq_table(["id", "name", "value"])
+        mock_bq_client.get_table.return_value = table_obj
+        mock_bq_hook.get_client.return_value = mock_bq_client
+        mock_bq_hook.return_value = mock_bq_hook
+
+        db_info = MagicMock(scheme="mssql", authority="localhost:1433", database="mydb")
+        mock_mssql_hook.get_openlineage_database_info.return_value = db_info
+        mock_mssql_hook.get_openlineage_default_schema.return_value = "dbo"
+        mock_mssql_hook.return_value = mock_mssql_hook
+
+        op = BigQueryToMsSqlOperator(
+            task_id="test",
+            source_project_dataset_table="proj.dataset.table",
+            target_table_name="dbo.destination",
+            selected_fields=None,
+            database="mydb",
+        )
+        op.bigquery_hook = mock_bq_hook
+        op.mssql_hook = mock_mssql_hook
+        context = mock.MagicMock()
+        op.execute(context=context)
+
+        result = op.get_openlineage_facets_on_complete(task_instance=MagicMock())
+        assert len(result.inputs) == 1
+        assert len(result.outputs) == 1
+
+        input_ds = result.inputs[0]
+        assert input_ds.namespace == "bigquery"
+        assert input_ds.name == "proj.dataset.table"
+
+        assert "schema" in input_ds.facets
+        schema_fields = [f.name for f in input_ds.facets["schema"].fields]
+        assert set(schema_fields) == {"id", "name", "value"}
+
+        output_ds = result.outputs[0]
+        assert output_ds.namespace == "mssql://localhost:1433"
+        assert output_ds.name == "mydb.dbo.destination"
+
+        assert "columnLineage" in output_ds.facets
+        col_lineage = output_ds.facets["columnLineage"]
+        assert set(col_lineage.fields.keys()) == {"id", "name", "value"}
+
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.MsSqlHook")
+    @mock.patch("airflow.providers.google.cloud.transfers.bigquery_to_mssql.BigQueryHook")
+    def test_get_openlineage_facets_on_complete_selected_fields(self, mock_bq_hook, mock_mssql_hook):
+        mock_bq_client = MagicMock()
+        table_obj = _make_bq_table(["id", "name", "value"])
+        mock_bq_client.get_table.return_value = table_obj
+        mock_bq_hook.get_client.return_value = mock_bq_client
+        mock_bq_hook.return_value = mock_bq_hook
+
+        db_info = MagicMock(scheme="mssql", authority="server.example:1433", database="mydb")
+        mock_mssql_hook.get_openlineage_database_info.return_value = db_info
+        mock_mssql_hook.get_openlineage_default_schema.return_value = "dbo"
+        mock_mssql_hook.return_value = mock_mssql_hook
+
+        op = BigQueryToMsSqlOperator(
task_id="test", + source_project_dataset_table="proj.dataset.table", + target_table_name="dbo.destination", + selected_fields=["id", "name"], + database="mydb", + ) + op.bigquery_hook = mock_bq_hook + op.mssql_hook = mock_mssql_hook + context = mock.MagicMock() + op.execute(context=context) + + result = op.get_openlineage_facets_on_complete(task_instance=MagicMock()) + assert len(result.inputs) == 1 + assert len(result.outputs) == 1 + + input_ds = result.inputs[0] + assert input_ds.namespace == "bigquery" + assert "schema" in input_ds.facets + + schema_fields = [f.name for f in input_ds.facets["schema"].fields] + assert set(schema_fields) == {"id", "name"} + + output_ds = result.outputs[0] + assert output_ds.namespace == "mssql://server.example:1433" + assert output_ds.name == "mydb.dbo.destination" + + assert "columnLineage" in output_ds.facets + col_lineage = output_ds.facets["columnLineage"] + assert set(col_lineage.fields.keys()) == {"id", "name"}