diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
index d364dd17673a0..7b3fefaa18980 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/spanner.py
@@ -31,6 +31,7 @@
 from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
+from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
 if TYPE_CHECKING:
     from google.cloud.spanner_v1.database import Database
@@ -38,6 +39,8 @@
     from google.cloud.spanner_v1.transaction import Transaction
     from google.longrunning.operations_grpc_pb2 import Operation
 
+    from airflow.models.connection import Connection
+
 
 class SpannerConnectionParams(NamedTuple):
     """Information about Google Spanner connection parameters."""
@@ -427,3 +430,45 @@ def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]) ->
             rc = transaction.execute_update(sql)
             counts[sql] = rc
         return counts
+
+    def _get_openlineage_authority_part(self, connection: Connection) -> str | None:
+        """Build Spanner-specific authority part for OpenLineage. Returns {project}/{instance}."""
+        extras = connection.extra_dejson
+        project_id = extras.get("project_id")
+        instance_id = extras.get("instance_id")
+
+        if not project_id or not instance_id:
+            return None
+
+        return f"{project_id}/{instance_id}"
+
+    def get_openlineage_database_dialect(self, connection: Connection) -> str:
+        """Return database dialect for OpenLineage."""
+        return "spanner"
+
+    def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
+        """Return Spanner specific information for OpenLineage."""
+        extras = connection.extra_dejson
+        database_id = extras.get("database_id")
+
+        return DatabaseInfo(
+            scheme=self.get_openlineage_database_dialect(connection),
+            authority=self._get_openlineage_authority_part(connection),
+            database=database_id,
+            information_schema_columns=[
+                "table_schema",
+                "table_name",
+                "column_name",
+                "ordinal_position",
+                "spanner_type",
+            ],
+        )
+
+    def get_openlineage_default_schema(self) -> str | None:
+        """
+        Spanner exposes 'public' or '' schema depending on dialect (Postgres vs GoogleSQL).
+
+        SQLAlchemy dialect for Spanner does not expose default schema, so we return None
+        to follow the same approach.
+        """
+        return None
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
index 732b2e19b7c1b..97d9fe797982d 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/spanner.py
@@ -20,6 +20,7 @@
 from __future__ import annotations
 
 from collections.abc import Sequence
+from functools import cached_property
 from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowException
@@ -29,6 +30,7 @@
 from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
 
 if TYPE_CHECKING:
+    from airflow.providers.openlineage.extractors import OperatorLineage
     from airflow.utils.context import Context
 
 
@@ -254,6 +256,13 @@ def __init__(
         self.impersonation_chain = impersonation_chain
         super().__init__(**kwargs)
 
+    @cached_property
+    def hook(self) -> SpannerHook:
+        return SpannerHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
     def _validate_inputs(self) -> None:
         if self.project_id == "":
             raise AirflowException("The required parameter 'project_id' is empty")
@@ -265,10 +274,6 @@ def _validate_inputs(self) -> None:
             raise AirflowException("The required parameter 'query' is empty")
 
     def execute(self, context: Context):
-        hook = SpannerHook(
-            gcp_conn_id=self.gcp_conn_id,
-            impersonation_chain=self.impersonation_chain,
-        )
         if isinstance(self.query, str):
             queries = [x.strip() for x in self.query.split(";")]
             self.sanitize_queries(queries)
@@ -281,7 +286,7 @@ def execute(self, context: Context):
             self.database_id,
         )
         self.log.info("Executing queries: %s", queries)
-        result_rows_count_per_query = hook.execute_dml(
+        result_rows_count_per_query = self.hook.execute_dml(
             project_id=self.project_id,
             instance_id=self.instance_id,
             database_id=self.database_id,
@@ -291,7 +296,7 @@ def execute(self, context: Context):
             context=context,
             instance_id=self.instance_id,
             database_id=self.database_id,
-            project_id=self.project_id or hook.project_id,
+            project_id=self.project_id or self.hook.project_id,
         )
         return result_rows_count_per_query
 
@@ -305,6 +310,17 @@ def sanitize_queries(queries: list[str]) -> None:
         if queries and queries[-1] == "":
             queries.pop()
 
+    def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
+        """Build a generic OpenLineage facet, aligned with SQL-based operators."""
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+
+        return get_openlineage_facets_with_sql(
+            hook=self.hook,
+            sql=self.query,
+            conn_id=self.gcp_conn_id,
+            database=self.database_id,
+        )
+
 
 class SpannerDeployDatabaseInstanceOperator(GoogleCloudBaseOperator):
     """
diff --git a/providers/google/tests/unit/google/cloud/operators/test_spanner.py b/providers/google/tests/unit/google/cloud/operators/test_spanner.py
index 1784a0499aab0..7a5b33429b9e7 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_spanner.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_spanner.py
@@ -18,10 +18,13 @@
 from __future__ import annotations
 
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 
 from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.providers.common.sql.hooks.sql import DbApiHook
 from airflow.providers.google.cloud.operators.spanner import (
     SpannerDeleteDatabaseInstanceOperator,
     SpannerDeleteInstanceOperator,
@@ -30,6 +33,7 @@
     SpannerQueryDatabaseInstanceOperator,
     SpannerUpdateDatabaseInstanceOperator,
 )
+from airflow.providers.openlineage.sqlparser import DatabaseInfo
 
 PROJECT_ID = "project-id"
 INSTANCE_ID = "instance-id"
@@ -42,6 +46,32 @@
 CREATE_QUERY = "CREATE TABLE my_table1 (id INT64, name STRING(100))"
 CREATE_QUERY_2 = "CREATE TABLE my_table2 (id INT64, name STRING(100))"
 DDL_STATEMENTS = [CREATE_QUERY, CREATE_QUERY_2]
+TASK_ID = "task-id"
+
+SCHEMA_ROWS = {
+    "public.orders": [
+        ("public", "orders", "id", 1, "INT64"),
+        ("public", "orders", "amount", 2, "FLOAT64"),
+    ],
+    "public.staging": [
+        ("public", "staging", "id", 1, "INT64"),
+        ("public", "staging", "amount", 2, "FLOAT64"),
+    ],
+    "public.customers": [
+        ("public", "customers", "id", 1, "INT64"),
+        ("public", "customers", "name", 2, "STRING(100)"),
+        ("public", "customers", "customer_id", 3, "INT64"),
+    ],
+    "public.logs": [
+        ("public", "logs", "id", 1, "INT64"),
+        ("public", "logs", "message", 2, "STRING(100)"),
+    ],
+    "public.t1": [("public", "t1", "col1", 1, "STRING(100)")],
+    "public.t2": [("public", "t2", "col1", 1, "STRING(100)")],
+    "public.t3": [("public", "t3", "id", 1, "INT64")],
+    # example of explicit non-default schema
+    "myschema.orders": [("myschema", "orders", "id", 1, "INT64")],
+}
 
 
 class TestCloudSpanner:
@@ -353,6 +383,133 @@ def test_instance_query_dml_list(self, mock_hook):
             queries=[INSERT_QUERY, INSERT_QUERY_2],
         )
 
+    @pytest.mark.parametrize(
+        "sql, expected_inputs, expected_outputs, expected_lineage",
+        [
+            ("SELECT id, amount FROM public.orders", ["db1.public.orders"], [], {}),
+            (
+                "INSERT INTO public.orders (id, amount) SELECT id, amount FROM public.staging",
+                ["db1.public.staging", "db1.public.orders"],
+                [],
+                {},
+            ),
+            ("DELETE FROM public.logs WHERE id=1", [], ["db1.public.logs"], {}),
+            (
+                "SELECT o.id, c.name FROM public.orders o JOIN public.customers c ON o.customer_id = c.id",
+                ["db1.public.orders", "db1.public.customers"],
+                [],
+                {},
+            ),
+            (
+                "UPDATE public.customers SET name='x' WHERE id IN (SELECT id FROM public.staging)",
+                ["db1.public.customers", "db1.public.staging"],
+                [],
+                {},
+            ),
+            (
+                ["INSERT INTO public.t1 SELECT * FROM public.t2;", "DELETE FROM public.t3 WHERE id=1;"],
+                ["db1.public.t1", "db1.public.t2", "db1.public.t3"],
+                [],
+                {},
+            ),
+            ("SELECT id, amount FROM myschema.orders", ["db1.myschema.orders"], [], {}),
+        ],
+    )
+    def test_spannerquerydatabaseinstanceoperator_get_openlineage_facets(
+        self, sql, expected_inputs, expected_outputs, expected_lineage
+    ):
+        # Arrange
+        class SpannerHookForTests(DbApiHook):
+            conn_name_attr = "gcp_conn_id"
+            get_conn = MagicMock(name="conn")
+            get_connection = MagicMock()
+            database = DB_ID
+
+            def get_openlineage_database_info(self, connection):
+                return DatabaseInfo(
+                    scheme="spanner",
+                    authority=f"{PROJECT_ID}/{INSTANCE_ID}",
+                    database=DB_ID,
+                    information_schema_columns=[
+                        "table_schema",
+                        "table_name",
+                        "column_name",
+                        "ordinal_position",
+                        "spanner_type",
+                    ],
+                    information_schema_table_name="information_schema.columns",
+                    use_flat_cross_db_query=False,
+                    is_information_schema_cross_db=False,
+                    is_uppercase_names=False,
+                )
+
+        dbapi_hook = SpannerHookForTests()
+
+        class SpannerOperatorForTest(SpannerQueryDatabaseInstanceOperator):
+            @property
+            def hook(self):
+                return dbapi_hook
+
+        op = SpannerOperatorForTest(
+            task_id=TASK_ID,
+            instance_id=INSTANCE_ID,
+            database_id=DB_ID,
+            gcp_conn_id="spanner_conn",
+            query=sql,
+        )
+
+        dbapi_hook.get_connection.return_value = Connection(
+            conn_id="spanner_conn", conn_type="spanner", host="spanner-host"
+        )
+
+        combined_rows = []
+        for ds in expected_inputs + expected_outputs:
+            tbl = ds.split(".", 1)[1]
+            combined_rows.extend(SCHEMA_ROWS.get(tbl, []))
+
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [combined_rows, []]
+
+        # Act
+        lineage = op.get_openlineage_facets_on_complete(task_instance=None)
+        assert lineage is not None
+
+        # Assert inputs
+        input_names = {ds.name for ds in lineage.inputs}
+        assert input_names == set(expected_inputs)
+        for ds in lineage.inputs:
+            assert ds.namespace == f"spanner://{PROJECT_ID}/{INSTANCE_ID}"
+
+        # Assert outputs
+        output_names = {ds.name for ds in lineage.outputs}
+        assert output_names == set(expected_outputs)
+        for ds in lineage.outputs:
+            assert ds.namespace == f"spanner://{PROJECT_ID}/{INSTANCE_ID}"
+
+        # Assert SQLJobFacet
+        sql_job = lineage.job_facets["sql"]
+        if isinstance(sql, list):
+            for q in sql:
+                assert q.replace(";", "").strip() in sql_job.query.replace(";", "")
+        else:
+            assert sql_job.query == sql
+
+        # Assert column lineage
+        found_lineage = {
+            getattr(field, "field", None) or getattr(field, "name", None): [
+                f"{inp.dataset.name}.{getattr(inp, 'field', getattr(inp, 'name', None))}"
+                for inp in getattr(field, "inputFields", [])
+            ]
+            for ds in lineage.outputs + lineage.inputs
+            for cl_facet in [ds.facets.get("columnLineage")]
+            if cl_facet
+            for field in cl_facet.fields
+        }
+
+        for col, sources in expected_lineage.items():
+            assert col in found_lineage
+            for src in sources:
+                assert any(src in s for s in found_lineage[col])
+
     @mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
     def test_database_create(self, mock_hook):
         mock_hook.return_value.get_database.return_value = None