Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.common.consts import CLIENT_INFO
from airflow.providers.google.common.hooks.base_google import GoogleBaseHook, get_field
from airflow.providers.openlineage.sqlparser import DatabaseInfo

if TYPE_CHECKING:
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
from google.cloud.spanner_v1.transaction import Transaction
from google.longrunning.operations_grpc_pb2 import Operation

from airflow.models.connection import Connection


class SpannerConnectionParams(NamedTuple):
"""Information about Google Spanner connection parameters."""
Expand Down Expand Up @@ -427,3 +430,45 @@ def _execute_sql_in_transaction(transaction: Transaction, queries: list[str]) ->
rc = transaction.execute_update(sql)
counts[sql] = rc
return counts

def _get_openlineage_authority_part(self, connection: Connection) -> str | None:
"""Build Spanner-specific authority part for OpenLineage. Returns {project}/{instance}."""
extras = connection.extra_dejson
project_id = extras.get("project_id")
instance_id = extras.get("instance_id")

if not project_id or not instance_id:
return None

return f"{project_id}/{instance_id}"

def get_openlineage_database_dialect(self, connection: Connection) -> str:
    """Return the SQL dialect name used in OpenLineage naming for Spanner."""
    return "spanner"

def get_openlineage_database_info(self, connection: Connection) -> DatabaseInfo:
    """Assemble the Spanner-specific metadata consumed by the OpenLineage SQL parser."""
    # Columns read from information_schema when building dataset schema facets.
    schema_columns = [
        "table_schema",
        "table_name",
        "column_name",
        "ordinal_position",
        "spanner_type",
    ]
    return DatabaseInfo(
        scheme=self.get_openlineage_database_dialect(connection),
        authority=self._get_openlineage_authority_part(connection),
        database=connection.extra_dejson.get("database_id"),
        information_schema_columns=schema_columns,
    )

def get_openlineage_default_schema(self) -> str | None:
    """
    Return the default schema for OpenLineage, which is unknown for Spanner.

    Spanner exposes either a 'public' or an empty schema depending on the
    database dialect (PostgreSQL vs GoogleSQL). The SQLAlchemy dialect for
    Spanner does not expose a default schema either, so None is returned to
    follow the same approach.
    """
    return None
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import annotations

from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING

from airflow.exceptions import AirflowException
Expand All @@ -29,6 +30,7 @@
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.utils.context import Context


Expand Down Expand Up @@ -254,6 +256,13 @@ def __init__(
self.impersonation_chain = impersonation_chain
super().__init__(**kwargs)

@cached_property
def hook(self) -> SpannerHook:
    """Build (once per operator instance) the SpannerHook used for all API calls."""
    hook_kwargs = {
        "gcp_conn_id": self.gcp_conn_id,
        "impersonation_chain": self.impersonation_chain,
    }
    return SpannerHook(**hook_kwargs)

def _validate_inputs(self) -> None:
if self.project_id == "":
raise AirflowException("The required parameter 'project_id' is empty")
Expand All @@ -265,10 +274,6 @@ def _validate_inputs(self) -> None:
raise AirflowException("The required parameter 'query' is empty")

def execute(self, context: Context):
hook = SpannerHook(
gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain,
)
if isinstance(self.query, str):
queries = [x.strip() for x in self.query.split(";")]
self.sanitize_queries(queries)
Expand All @@ -281,7 +286,7 @@ def execute(self, context: Context):
self.database_id,
)
self.log.info("Executing queries: %s", queries)
result_rows_count_per_query = hook.execute_dml(
result_rows_count_per_query = self.hook.execute_dml(
project_id=self.project_id,
instance_id=self.instance_id,
database_id=self.database_id,
Expand All @@ -291,7 +296,7 @@ def execute(self, context: Context):
context=context,
instance_id=self.instance_id,
database_id=self.database_id,
project_id=self.project_id or hook.project_id,
project_id=self.project_id or self.hook.project_id,
)
return result_rows_count_per_query

Expand All @@ -305,6 +310,17 @@ def sanitize_queries(queries: list[str]) -> None:
if queries and queries[-1] == "":
queries.pop()

def get_openlineage_facets_on_complete(self, task_instance) -> OperatorLineage | None:
    """Build a generic OpenLineage facet, aligned with SQL-based operators."""
    # Imported lazily so the openlineage compat layer is only needed when
    # lineage collection actually runs.
    from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql

    lineage_kwargs = {
        "hook": self.hook,
        "sql": self.query,
        "conn_id": self.gcp_conn_id,
        "database": self.database_id,
    }
    return get_openlineage_facets_with_sql(**lineage_kwargs)


class SpannerDeployDatabaseInstanceOperator(GoogleCloudBaseOperator):
"""
Expand Down
157 changes: 157 additions & 0 deletions providers/google/tests/unit/google/cloud/operators/test_spanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,13 @@
from __future__ import annotations

from unittest import mock
from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowException
from airflow.models import Connection
from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.operators.spanner import (
SpannerDeleteDatabaseInstanceOperator,
SpannerDeleteInstanceOperator,
Expand All @@ -30,6 +33,7 @@
SpannerQueryDatabaseInstanceOperator,
SpannerUpdateDatabaseInstanceOperator,
)
from airflow.providers.openlineage.sqlparser import DatabaseInfo

PROJECT_ID = "project-id"
INSTANCE_ID = "instance-id"
Expand All @@ -42,6 +46,32 @@
CREATE_QUERY = "CREATE TABLE my_table1 (id INT64, name STRING(100))"
CREATE_QUERY_2 = "CREATE TABLE my_table2 (id INT64, name STRING(100))"
DDL_STATEMENTS = [CREATE_QUERY, CREATE_QUERY_2]
TASK_ID = "task-id"

# Fake rows for Spanner's information_schema.columns, keyed by "{schema}.{table}".
# Each tuple mirrors the columns the hook requests:
# (table_schema, table_name, column_name, ordinal_position, spanner_type).
SCHEMA_ROWS = {
    "public.orders": [
        ("public", "orders", "id", 1, "INT64"),
        ("public", "orders", "amount", 2, "FLOAT64"),
    ],
    "public.staging": [
        ("public", "staging", "id", 1, "INT64"),
        ("public", "staging", "amount", 2, "FLOAT64"),
    ],
    "public.customers": [
        ("public", "customers", "id", 1, "INT64"),
        ("public", "customers", "name", 2, "STRING(100)"),
        ("public", "customers", "customer_id", 3, "INT64"),
    ],
    "public.logs": [
        ("public", "logs", "id", 1, "INT64"),
        ("public", "logs", "message", 2, "STRING(100)"),
    ],
    "public.t1": [("public", "t1", "col1", 1, "STRING(100)")],
    "public.t2": [("public", "t2", "col1", 1, "STRING(100)")],
    "public.t3": [("public", "t3", "id", 1, "INT64")],
    # example of an explicit non-default schema
    "myschema.orders": [("myschema", "orders", "id", 1, "INT64")],
}


class TestCloudSpanner:
Expand Down Expand Up @@ -353,6 +383,133 @@ def test_instance_query_dml_list(self, mock_hook):
queries=[INSERT_QUERY, INSERT_QUERY_2],
)

@pytest.mark.parametrize(
    "sql, expected_inputs, expected_outputs, expected_lineage",
    [
        ("SELECT id, amount FROM public.orders", ["db1.public.orders"], [], {}),
        (
            "INSERT INTO public.orders (id, amount) SELECT id, amount FROM public.staging",
            ["db1.public.staging", "db1.public.orders"],
            [],
            {},
        ),
        ("DELETE FROM public.logs WHERE id=1", [], ["db1.public.logs"], {}),
        (
            "SELECT o.id, c.name FROM public.orders o JOIN public.customers c ON o.customer_id = c.id",
            ["db1.public.orders", "db1.public.customers"],
            [],
            {},
        ),
        (
            "UPDATE public.customers SET name='x' WHERE id IN (SELECT id FROM public.staging)",
            ["db1.public.customers", "db1.public.staging"],
            [],
            {},
        ),
        (
            ["INSERT INTO public.t1 SELECT * FROM public.t2;", "DELETE FROM public.t3 WHERE id=1;"],
            ["db1.public.t1", "db1.public.t2", "db1.public.t3"],
            [],
            {},
        ),
        ("SELECT id, amount FROM myschema.orders", ["db1.myschema.orders"], [], {}),
    ],
)
def test_spannerquerydatabaseinstanceoperator_get_openlineage_facets(
    self, sql, expected_inputs, expected_outputs, expected_lineage
):
    """
    Check that get_openlineage_facets_on_complete parses the operator's SQL
    into input/output datasets, namespaces, a SQL job facet, and (optionally)
    column lineage, without hitting real Spanner APIs.
    """
    # Arrange
    # Stub hook: a DbApiHook subclass with mocked connection/cursor access and
    # a hard-coded Spanner DatabaseInfo, so the SQL parser runs offline.
    class SpannerHookForTests(DbApiHook):
        conn_name_attr = "gcp_conn_id"
        get_conn = MagicMock(name="conn")
        get_connection = MagicMock()
        database = DB_ID

        def get_openlineage_database_info(self, connection):
            return DatabaseInfo(
                scheme="spanner",
                authority=f"{PROJECT_ID}/{INSTANCE_ID}",
                database=DB_ID,
                information_schema_columns=[
                    "table_schema",
                    "table_name",
                    "column_name",
                    "ordinal_position",
                    "spanner_type",
                ],
                information_schema_table_name="information_schema.columns",
                use_flat_cross_db_query=False,
                is_information_schema_cross_db=False,
                is_uppercase_names=False,
            )

    dbapi_hook = SpannerHookForTests()

    # Operator subclass that swaps the real cached SpannerHook for the stub.
    class SpannerOperatorForTest(SpannerQueryDatabaseInstanceOperator):
        @property
        def hook(self):
            return dbapi_hook

    op = SpannerOperatorForTest(
        task_id=TASK_ID,
        instance_id=INSTANCE_ID,
        database_id=DB_ID,
        gcp_conn_id="spanner_conn",
        query=sql,
    )

    dbapi_hook.get_connection.return_value = Connection(
        conn_id="spanner_conn", conn_type="spanner", host="spanner-host"
    )

    # Collect the fake information_schema rows for every dataset this case
    # touches; dataset names are "db.schema.table" so strip the db prefix.
    combined_rows = []
    for ds in expected_inputs + expected_outputs:
        tbl = ds.split(".", 1)[1]
        combined_rows.extend(SCHEMA_ROWS.get(tbl, []))

    # First fetchall feeds the schema query; the second returns nothing.
    dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [combined_rows, []]

    # Act
    lineage = op.get_openlineage_facets_on_complete(task_instance=None)
    assert lineage is not None

    # Assert inputs
    input_names = {ds.name for ds in lineage.inputs}
    assert input_names == set(expected_inputs)
    for ds in lineage.inputs:
        assert ds.namespace == f"spanner://{PROJECT_ID}/{INSTANCE_ID}"

    # Assert outputs
    output_names = {ds.name for ds in lineage.outputs}
    assert output_names == set(expected_outputs)
    for ds in lineage.outputs:
        assert ds.namespace == f"spanner://{PROJECT_ID}/{INSTANCE_ID}"

    # Assert SQLJobFacet
    sql_job = lineage.job_facets["sql"]
    if isinstance(sql, list):
        # Multi-statement queries are joined into one facet; compare ignoring
        # the trailing semicolons.
        for q in sql:
            assert q.replace(";", "").strip() in sql_job.query.replace(";", "")
    else:
        assert sql_job.query == sql

    # Assert column lineage
    # Map each output column to the "dataset.column" sources feeding it;
    # getattr fallbacks cover both facet attribute spellings (field/name).
    found_lineage = {
        getattr(field, "field", None) or getattr(field, "name", None): [
            f"{inp.dataset.name}.{getattr(inp, 'field', getattr(inp, 'name', None))}"
            for inp in getattr(field, "inputFields", [])
        ]
        for ds in lineage.outputs + lineage.inputs
        for cl_facet in [ds.facets.get("columnLineage")]
        if cl_facet
        for field in cl_facet.fields
    }

    for col, sources in expected_lineage.items():
        assert col in found_lineage
        for src in sources:
            assert any(src in s for s in found_lineage[col])

@mock.patch("airflow.providers.google.cloud.operators.spanner.SpannerHook")
def test_database_create(self, mock_hook):
mock_hook.return_value.get_database.return_value = None
Expand Down
Loading