diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py index 1963ff93a0b75..c882fa26dd782 100644 --- a/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py +++ b/providers/google/src/airflow/providers/google/cloud/hooks/bigquery.py @@ -66,6 +66,7 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.google.cloud.utils.bigquery import bq_cast from airflow.providers.google.cloud.utils.credentials_provider import _get_scopes +from airflow.providers.google.cloud.utils.lineage import send_hook_lineage_for_bq_job from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.deprecated import deprecated from airflow.providers.google.common.hooks.base_google import ( @@ -88,6 +89,7 @@ from google.api_core.retry import Retry from requests import Session + from airflow.providers.openlineage.sqlparser import DatabaseInfo from airflow.sdk import Context log = logging.getLogger(__name__) @@ -1330,19 +1332,10 @@ def insert_job( # Start the job and wait for it to complete and get the result. job_api_repr.result(timeout=timeout, retry=retry) - self._send_hook_level_lineage_for_bq_job(job=job_api_repr) + send_hook_lineage_for_bq_job(context=self, job=job_api_repr) return job_api_repr - def _send_hook_level_lineage_for_bq_job(self, job): - # TODO(kacpermuda) Add support for other job types and more params to sql job - if job.job_type == QueryJob.job_type: - send_sql_hook_lineage( - context=self, - sql=job.query, - job_id=job.job_id, - ) - def generate_job_id( self, job_id: str | None, @@ -1503,6 +1496,31 @@ def scopes(self) -> Sequence[str]: scope_value = self._get_field("scope", None) return _get_scopes(scope_value) + def get_openlineage_database_info(self, connection) -> DatabaseInfo: + """Return BigQuery specific information for OpenLineage.""" + from airflow.providers.openlineage.sqlparser import DatabaseInfo + + return DatabaseInfo( + scheme=self.get_openlineage_database_dialect(None), + authority=None, + database=self.project_id, + information_schema_columns=[ + "table_schema", + "table_name", + "column_name", + "ordinal_position", + "data_type", + "table_catalog", + ], + information_schema_table_name="INFORMATION_SCHEMA.COLUMNS", + ) + + def get_openlineage_database_dialect(self, _) -> str: + return "bigquery" + + def get_openlineage_default_schema(self) -> str | None: + return None + class BigQueryConnection: """ diff --git a/providers/google/src/airflow/providers/google/cloud/utils/lineage.py b/providers/google/src/airflow/providers/google/cloud/utils/lineage.py new file mode 100644 index 0000000000000..c5725976684a1 --- /dev/null +++ b/providers/google/src/airflow/providers/google/cloud/utils/lineage.py @@ -0,0 +1,94 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import logging + +from google.cloud.bigquery import CopyJob, ExtractJob, LoadJob, QueryJob + +from airflow.providers.common.compat.lineage.hook import get_hook_lineage_collector +from airflow.providers.common.sql.hooks.lineage import send_sql_hook_lineage + +log = logging.getLogger(__name__) + + +def _add_bq_table_to_lineage(collector, context, table_ref, *, is_input: bool): + method = collector.add_input_asset if is_input else collector.add_output_asset + method( + context=context, + scheme="bigquery", + asset_kwargs={ + "project_id": table_ref.project, + "dataset_id": table_ref.dataset_id, + "table_id": table_ref.table_id, + }, + ) + + +def _add_gcs_uris_to_lineage(collector, context, uris, *, is_input: bool): + method = collector.add_input_asset if is_input else collector.add_output_asset + for uri in uris or []: + method(context=context, uri=uri) + + +def send_hook_lineage_for_bq_job(context, job): + """ + Send hook-level lineage for a BigQuery job to the lineage collector. + + Handles all four BigQuery job types: + - QUERY: delegates to send_sql_hook_lineage for SQL parsing + - LOAD: source URIs (GCS) as inputs, destination table as output + - COPY: source tables as inputs, destination table as output + - EXTRACT: source table as input, destination URIs (GCS) as outputs + + :param context: The hook instance used as lineage context. + :param job: A BigQuery job object (QueryJob, LoadJob, CopyJob, or ExtractJob). + """ + collector = get_hook_lineage_collector() + + if isinstance(job, QueryJob): + log.debug("Sending Hook Level Lineage for Query job.") + send_sql_hook_lineage( + context=context, + sql=job.query, + job_id=job.job_id, + default_db=job.default_dataset.project if job.default_dataset else None, + default_schema=job.default_dataset.dataset_id if job.default_dataset else None, + ) + return + + try: + if isinstance(job, LoadJob): + log.debug("Sending Hook Level Lineage for Load job.") + _add_gcs_uris_to_lineage(collector, context, job.source_uris, is_input=True) + if job.destination: + _add_bq_table_to_lineage(collector, context, job.destination, is_input=False) + elif isinstance(job, CopyJob): + log.debug("Sending Hook Level Lineage for Copy job.") + for source_table in job.sources or []: + _add_bq_table_to_lineage(collector, context, source_table, is_input=True) + if job.destination: + _add_bq_table_to_lineage(collector, context, job.destination, is_input=False) + elif isinstance(job, ExtractJob): + log.debug("Sending Hook Level Lineage for Extract job.") + if job.source: + _add_bq_table_to_lineage(collector, context, job.source, is_input=True) + _add_gcs_uris_to_lineage(collector, context, job.destination_uris, is_input=False) + except Exception as e: + log.warning("Sending BQ job hook level lineage failed: %s", f"{e.__class__.__name__}: {str(e)}") + log.debug("Exception details:", exc_info=True) diff --git a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py index a2cd8894e4f23..bbd4f64bf4649 100644 --- a/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py +++ b/providers/google/tests/unit/google/cloud/hooks/test_bigquery.py @@ -2068,23 +2068,13 @@ def test_get_df_by_chunks_hook_lineage(self, mock_get_pandas_df_by_chunks, mock_ assert call_kw["sql"] == sql assert call_kw["sql_parameters"] == parameters - @mock.patch("airflow.providers.google.cloud.hooks.bigquery.send_sql_hook_lineage") + @mock.patch("airflow.providers.google.cloud.hooks.bigquery.send_hook_lineage_for_bq_job") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.QueryJob") @mock.patch("airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_client") def test_insert_job_hook_lineage(self, mock_client, mock_query_job, mock_send_lineage): - query_job_type = "query" - job_conf = { - query_job_type: { - query_job_type: "SELECT * FROM test", - "useLegacySql": "False", - } - } - mock_query_job._JOB_TYPE = query_job_type - mock_query_job.job_type = query_job_type + job_conf = {"query": {"query": "SELECT * FROM test", "useLegacySql": "False"}} + mock_query_job._JOB_TYPE = "query" mock_job_instance = mock.MagicMock() - mock_job_instance.job_id = JOB_ID - mock_job_instance.query = "SELECT * FROM test" - mock_job_instance.job_type = query_job_type mock_query_job.from_api_repr.return_value = mock_job_instance self.hook.insert_job( @@ -2095,8 +2085,4 @@ def test_insert_job_hook_lineage(self, mock_client, mock_query_job, mock_send_li nowait=True, ) - mock_send_lineage.assert_called_once() - call_kw = mock_send_lineage.call_args.kwargs - assert call_kw["context"] is self.hook - assert call_kw["sql"] == "SELECT * FROM test" - assert call_kw["job_id"] == JOB_ID + mock_send_lineage.assert_called_once_with(context=self.hook, job=mock_job_instance) diff --git a/providers/google/tests/unit/google/cloud/utils/test_lineage.py b/providers/google/tests/unit/google/cloud/utils/test_lineage.py new file mode 100644 index 0000000000000..fd9dc22640b4a --- /dev/null +++ b/providers/google/tests/unit/google/cloud/utils/test_lineage.py @@ -0,0 +1,237 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +from google.cloud.bigquery import CopyJob, DatasetReference, ExtractJob, LoadJob, QueryJob, TableReference + +from airflow.providers.common.compat.assets import Asset +from airflow.providers.google.cloud.utils.lineage import ( + _add_bq_table_to_lineage, + _add_gcs_uris_to_lineage, + send_hook_lineage_for_bq_job, +) + +PROJECT_ID = "test-project" +DATASET_ID = "test_dataset" +TABLE_ID = "test_table" +JOB_ID = "test-job-123" + +TABLE_REFERENCE = TableReference(DatasetReference(PROJECT_ID, DATASET_ID), TABLE_ID) + + +def _make_table_ref(project, dataset, table): + return TableReference(DatasetReference(project, dataset), table) + + +class TestAddBqTableToLineage: + def test_add_as_input(self): + collector = mock.MagicMock() + context = mock.sentinel.context + + _add_bq_table_to_lineage(collector, context, TABLE_REFERENCE, is_input=True) + + collector.add_input_asset.assert_called_once_with( + context=context, + scheme="bigquery", + asset_kwargs={ + "project_id": PROJECT_ID, + "dataset_id": DATASET_ID, + "table_id": TABLE_ID, + }, + ) + collector.add_output_asset.assert_not_called() + + def test_add_as_output(self): + collector = mock.MagicMock() + context = mock.sentinel.context + + _add_bq_table_to_lineage(collector, context, TABLE_REFERENCE, is_input=False) + + collector.add_output_asset.assert_called_once_with( + context=context, + scheme="bigquery", + asset_kwargs={ + "project_id": PROJECT_ID, + "dataset_id": DATASET_ID, + "table_id": TABLE_ID, + }, + ) + collector.add_input_asset.assert_not_called() + + +class TestAddGcsUrisToLineage: + def test_add_uris_as_input(self): + collector = mock.MagicMock() + context = mock.sentinel.context + uris = ["gs://bucket1/path/file.csv", "gs://bucket2/other.json"] + + _add_gcs_uris_to_lineage(collector, context, uris, is_input=True) + + assert collector.add_input_asset.call_count == 2 + collector.add_input_asset.assert_any_call(context=context, uri="gs://bucket1/path/file.csv") + collector.add_input_asset.assert_any_call(context=context, uri="gs://bucket2/other.json") + collector.add_output_asset.assert_not_called() + + def test_add_uris_as_output(self): + collector = mock.MagicMock() + context = mock.sentinel.context + uris = ["gs://bucket/export/data.csv"] + + _add_gcs_uris_to_lineage(collector, context, uris, is_input=False) + + collector.add_output_asset.assert_called_once_with(context=context, uri="gs://bucket/export/data.csv") + collector.add_input_asset.assert_not_called() + + def test_empty_uris(self): + collector = mock.MagicMock() + _add_gcs_uris_to_lineage(collector, mock.sentinel.context, [], is_input=True) + collector.add_input_asset.assert_not_called() + + def test_none_uris(self): + collector = mock.MagicMock() + _add_gcs_uris_to_lineage(collector, mock.sentinel.context, None, is_input=True) + collector.add_input_asset.assert_not_called() + + +class TestSendHookLineageForBqJob: + @mock.patch("airflow.providers.google.cloud.utils.lineage.send_sql_hook_lineage") + def test_query_job(self, mock_send_sql): + job = mock.MagicMock(spec=QueryJob) + job.query = "SELECT * FROM dataset.table" + job.job_id = JOB_ID + job.default_dataset = DatasetReference(PROJECT_ID, DATASET_ID) + context = mock.sentinel.context + + send_hook_lineage_for_bq_job(context=context, job=job) + + mock_send_sql.assert_called_once_with( + context=context, + sql="SELECT * FROM dataset.table", + job_id=JOB_ID, + default_db=PROJECT_ID, + default_schema=DATASET_ID, + ) + + @mock.patch("airflow.providers.google.cloud.utils.lineage.send_sql_hook_lineage") + def test_query_job_no_default_dataset(self, mock_send_sql): + job = mock.MagicMock(spec=QueryJob) + job.query = "SELECT 1" + job.job_id = JOB_ID + job.default_dataset = None + context = mock.sentinel.context + + send_hook_lineage_for_bq_job(context=context, job=job) + + mock_send_sql.assert_called_once_with( + context=context, + sql="SELECT 1", + job_id=JOB_ID, + default_db=None, + default_schema=None, + ) + + def test_load_job(self, hook_lineage_collector): + job = mock.MagicMock(spec=LoadJob) + job.source_uris = ["gs://bucket/data.csv", "gs://bucket/data2.csv"] + job.destination = TABLE_REFERENCE + context = mock.sentinel.context + + send_hook_lineage_for_bq_job(context=context, job=job) + + assert len(hook_lineage_collector.collected_assets.inputs) == 2 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"bigquery://{PROJECT_ID}/{DATASET_ID}/{TABLE_ID}" + ) + + def test_load_job_no_destination(self, hook_lineage_collector): + job = mock.MagicMock(spec=LoadJob) + job.source_uris = ["gs://bucket/data.csv"] + job.destination = None + context = mock.sentinel.context + + send_hook_lineage_for_bq_job(context=context, job=job) + + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert len(hook_lineage_collector.collected_assets.outputs) == 0 + + def test_copy_job(self, hook_lineage_collector): + source1 = _make_table_ref(PROJECT_ID, DATASET_ID, "source1") + source2 = _make_table_ref(PROJECT_ID, DATASET_ID, "source2") + dest = _make_table_ref(PROJECT_ID, DATASET_ID, "dest") + + job = mock.MagicMock(spec=CopyJob) + job.sources = [source1, source2] + job.destination = dest + context = mock.sentinel.context + + send_hook_lineage_for_bq_job(context=context, job=job) + + assert len(hook_lineage_collector.collected_assets.inputs) == 2 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri=f"bigquery://{PROJECT_ID}/{DATASET_ID}/source1" + ) + assert hook_lineage_collector.collected_assets.inputs[1].asset == Asset( + uri=f"bigquery://{PROJECT_ID}/{DATASET_ID}/source2" + ) + assert hook_lineage_collector.collected_assets.outputs[0].asset == Asset( + uri=f"bigquery://{PROJECT_ID}/{DATASET_ID}/dest" + ) + + def test_extract_job(self, hook_lineage_collector): + job = mock.MagicMock(spec=ExtractJob) + job.source = TABLE_REFERENCE + job.destination_uris = ["gs://bucket/export/file1.csv", "gs://bucket/export/file2.csv"] + context = mock.sentinel.context + + send_hook_lineage_for_bq_job(context=context, job=job) + + assert len(hook_lineage_collector.collected_assets.inputs) == 1 + assert len(hook_lineage_collector.collected_assets.outputs) == 2 + assert hook_lineage_collector.collected_assets.inputs[0].asset == Asset( + uri=f"bigquery://{PROJECT_ID}/{DATASET_ID}/{TABLE_ID}" + ) + + def test_extract_job_no_source(self, hook_lineage_collector): + job = mock.MagicMock(spec=ExtractJob) + job.source = None + job.destination_uris = ["gs://bucket/export/file.csv"] + context = mock.sentinel.context + + send_hook_lineage_for_bq_job(context=context, job=job) + + assert len(hook_lineage_collector.collected_assets.inputs) == 0 + assert len(hook_lineage_collector.collected_assets.outputs) == 1 + + @mock.patch("airflow.providers.google.cloud.utils.lineage.send_sql_hook_lineage") + def test_unknown_job_type_does_not_raise(self, mock_send_sql, hook_lineage_collector): + job = mock.MagicMock() + send_hook_lineage_for_bq_job(context=mock.sentinel.context, job=job) + mock_send_sql.assert_not_called() + assert len(hook_lineage_collector.collected_assets.inputs) == 0 + assert len(hook_lineage_collector.collected_assets.outputs) == 0 + + def test_exception_in_non_query_job_is_caught(self, hook_lineage_collector): + job = mock.MagicMock(spec=LoadJob) + type(job).source_uris = mock.PropertyMock(side_effect=RuntimeError("boom")) + context = mock.sentinel.context + + send_hook_lineage_for_bq_job(context=context, job=job)