@@ -277,8 +277,13 @@ def __init__(
self._client_parameters = client_parameters or {}
if force_copy is not None:
self._copy_options["force"] = "true" if force_copy else "false"
self._sql: str | None = None

def _get_hook(self) -> DatabricksSqlHook:
return self._hook

@cached_property
def _hook(self) -> DatabricksSqlHook:
return DatabricksSqlHook(
self.databricks_conn_id,
http_path=self._http_path,
@@ -354,12 +359,116 @@ def _create_sql_query(self) -> str:
return sql.strip()

def execute(self, context: Context) -> Any:
sql = self._create_sql_query()
self.log.info("Executing: %s", sql)
self._sql = self._create_sql_query()
self.log.info("Executing: %s", self._sql)
hook = self._get_hook()
hook.run(sql)
hook.run(self._sql)

def on_kill(self) -> None:
# NB: on_kill isn't required for this operator since query cancellation is
# handled by the `DatabricksSqlHook.run()` method, which is called in `execute()`
...

def _build_input_openlineage_dataset(self) -> tuple[Any, list[Any]]:
"""Parse file_location to build the OpenLineage input dataset."""
from urllib.parse import urlparse

from airflow.providers.common.compat.openlineage.facet import Dataset, Error

try:
uri = urlparse(self.file_location)

# Only process schemes we know produce valid OL datasets with current implementation
if uri.scheme not in ("s3", "s3a", "s3n", "gs", "abfss", "wasbs"):
raise ValueError(f"Unsupported scheme: `{uri.scheme}` in `{self.file_location}`")

namespace = f"{uri.scheme}://{uri.netloc}"
name = uri.path.strip("/")
if name in ("", "."):
name = "/"
return Dataset(namespace=namespace, name=name), []
except Exception as e:
self.log.debug("Failed to parse file_location: `%s`, error: %s", self.file_location, str(e))
extraction_errors = [
Error(errorMessage=str(e), stackTrace=None, task=self.file_location, taskNumber=None)
]
return None, extraction_errors

def _build_output_openlineage_dataset(self, namespace: str) -> tuple[Any, list[Any]]:
"""Build output OpenLineage dataset from table information."""
from airflow.providers.common.compat.openlineage.facet import Dataset, Error

try:
table_parts = self.table_name.split(".")
if len(table_parts) == 3: # catalog.schema.table
catalog, schema, table = table_parts
elif len(table_parts) == 2: # schema.table
catalog = None
schema, table = table_parts
else:
catalog = None
schema = None
table = self.table_name

hook = self._get_hook()
schema = schema or hook.get_openlineage_default_schema() # Fallback to default schema
catalog = catalog or hook.catalog # Fallback to default catalog, if provided

# Combine schema/table with optional catalog for final dataset name
fq_name = table
if schema:
fq_name = f"{schema}.{fq_name}"
if catalog:
fq_name = f"{catalog}.{fq_name}"

return Dataset(namespace=namespace, name=fq_name), []
except Exception as e:
self.log.debug("Failed to construct output dataset: `%s`, error: %s", self.table_name, str(e))
extraction_errors = [
Error(errorMessage=str(e), stackTrace=None, task=self.table_name, taskNumber=None)
]
return None, extraction_errors

def get_openlineage_facets_on_complete(self, _):
"""Implement _on_complete as we are attaching query id."""
from airflow.providers.common.compat.openlineage.facet import (
ExternalQueryRunFacet,
ExtractionErrorRunFacet,
SQLJobFacet,
)
from airflow.providers.openlineage.extractors import OperatorLineage
from airflow.providers.openlineage.sqlparser import SQLParser

if not self._sql:
self.log.warning("No SQL query found, returning empty OperatorLineage.")
return OperatorLineage()

hook = self._get_hook()
run_facets = {}

connection = hook.get_connection(self.databricks_conn_id)
database_info = hook.get_openlineage_database_info(connection)
dbx_namespace = SQLParser.create_namespace(database_info)

if hook.query_ids:
run_facets["externalQuery"] = ExternalQueryRunFacet(
externalQueryId=hook.query_ids[0], source=dbx_namespace
)

input_dataset, extraction_errors = self._build_input_openlineage_dataset()
output_dataset, output_errors = self._build_output_openlineage_dataset(dbx_namespace)
extraction_errors.extend(output_errors)

if extraction_errors:
run_facets["extractionError"] = ExtractionErrorRunFacet(
totalTasks=1,
failedTasks=len(extraction_errors),
errors=extraction_errors,
)

return OperatorLineage(
inputs=[input_dataset] if input_dataset else [],
outputs=[output_dataset] if output_dataset else [],
job_facets={"sql": SQLJobFacet(query=SQLParser.normalize_sql(self._sql))},
run_facets=run_facets,
)
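The input-dataset parsing added above reduces to a small amount of urllib logic. Below is a minimal standalone sketch, assuming the same supported schemes; the helper name `uri_to_dataset_parts` is made up for illustration and is not part of the provider:

from urllib.parse import urlparse

SUPPORTED_SCHEMES = ("s3", "s3a", "s3n", "gs", "abfss", "wasbs")


def uri_to_dataset_parts(file_location: str) -> tuple[str, str]:
    # Mirror _build_input_openlineage_dataset: namespace is scheme://netloc,
    # name is the path without surrounding slashes, or "/" for the bucket root.
    uri = urlparse(file_location)
    if uri.scheme not in SUPPORTED_SCHEMES:
        raise ValueError(f"Unsupported scheme: `{uri.scheme}` in `{file_location}`")
    namespace = f"{uri.scheme}://{uri.netloc}"
    name = uri.path.strip("/")
    if name in ("", "."):
        name = "/"
    return namespace, name


print(uri_to_dataset_parts("gs://bucket/another_dir/file1.csv"))  # ('gs://bucket', 'another_dir/file1.csv')
print(uri_to_dataset_parts("s3://bucket/"))                       # ('s3://bucket', '/')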
@@ -17,10 +17,18 @@
# under the License.
from __future__ import annotations

from unittest import mock

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.common.compat.openlineage.facet import (
Dataset,
ExternalQueryRunFacet,
SQLJobFacet,
)
from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator
from airflow.providers.openlineage.extractors import OperatorLineage

DATE = "2017-04-20"
TASK_ID = "databricks-sql-operator"
@@ -253,3 +261,172 @@ def test_templating(create_task_instance_of_operator, session):
assert task.files == "files"
assert task.table_name == "table-name"
assert task.databricks_conn_id == "databricks-conn-id"


def test_hook_is_cached():
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format="JSON",
table_name="test",
task_id=TASK_ID,
)
hook = op._get_hook()
hook2 = op._get_hook()
assert hook is hook2
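
The identity this test asserts follows directly from functools.cached_property. A minimal standalone sketch of that behavior (the `_Operator` class here is illustrative, not the provider's code):

from functools import cached_property


class _Operator:
    @cached_property
    def _hook(self):
        # stands in for constructing DatabricksSqlHook(...); runs only on first access
        return object()

    def _get_hook(self):
        return self._hook


op = _Operator()
assert op._get_hook() is op._get_hook()  # every call returns the same cached hook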


@pytest.mark.parametrize(
("file_location", "expected_namespace", "expected_name"),
(
("gs://bucket/another_dir/file1.csv", "gs://bucket", "another_dir/file1.csv"),
("gs://bucket/another_dir/", "gs://bucket", "another_dir"),
("s3://bucket/another_dir", "s3://bucket", "another_dir"),
("s3://bucket/", "s3://bucket", "/"),
(
"abfss://container@account.dfs.core.windows.net/my-data/csv",
"abfss://container@account.dfs.core.windows.net",
"my-data/csv",
),
(
"abfss://container@account.dfs.core.windows.net",
"abfss://container@account.dfs.core.windows.net",
"/",
),
(
"wasbs://container@account.dfs.core.windows.net",
"wasbs://container@account.dfs.core.windows.net",
"/",
),
),
)
def test_build_input_openlineage_dataset_correct(file_location, expected_namespace, expected_name):
op = DatabricksCopyIntoOperator(
file_location=file_location,
file_format="JSON",
table_name="test",
task_id=TASK_ID,
)
ds, errors = op._build_input_openlineage_dataset()
assert ds == Dataset(namespace=expected_namespace, name=expected_name)
assert not errors


@pytest.mark.parametrize(
("file_location", "expected_error"),
(
("azure://bucket/another_dir/file1.csv", "Unsupported scheme: `azure`"),
("r2://bucket/another_dir/file1.csv", "Unsupported scheme: `r2`"),
("my_random_location", "Unsupported scheme: ``"),
),
)
def test_build_input_openlineage_dataset_silences_error(file_location, expected_error):
op = DatabricksCopyIntoOperator(
file_location=file_location,
file_format="JSON",
table_name="test",
task_id=TASK_ID,
)
ds, errors = op._build_input_openlineage_dataset()
assert ds is None
assert len(errors) == 1
assert errors[0].task == file_location
assert expected_error in errors[0].errorMessage


@pytest.mark.parametrize(
("table_name", "default_catalog", "default_schema", "expected_name"),
(
("c.s.t", None, None, "c.s.t"),
("s.t", None, None, "s.t"),
("s.t", "dfc", None, "dfc.s.t"),
("t", None, None, "ol_default_schema.t"),
("c.s.t", "dfc", "dfs", "c.s.t"),
("s.t", "dfc", "dfs", "dfc.s.t"),
("t", "dfc", "dfs", "dfc.dfs.t"),
),
)
def test_build_output_openlineage_dataset_correct(table_name, default_catalog, default_schema, expected_name):
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format="JSON",
table_name=table_name,
task_id=TASK_ID,
catalog=default_catalog,
schema=default_schema,
)
mock_hook = mock.MagicMock()
mock_hook.get_openlineage_default_schema.return_value = (
"ol_default_schema" if not default_schema else default_schema
)
mock_hook.catalog = default_catalog

mock_get_hook = mock.MagicMock()
mock_get_hook.return_value = mock_hook
op._get_hook = mock_get_hook

ds, errors = op._build_output_openlineage_dataset("ol_namespace")
assert ds == Dataset(namespace="ol_namespace", name=expected_name)
assert not errors


def test_build_output_openlineage_dataset_silences_error():
table_name = "test"
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format="JSON",
table_name=table_name,
task_id=TASK_ID,
)

def mock_raise():
raise ValueError("test")

op._get_hook = mock_raise

ds, errors = op._build_output_openlineage_dataset("ol_namespace")
assert ds is None
assert len(errors) == 1
assert errors[0].task == table_name
assert errors[0].errorMessage == "test"


def test_get_openlineage_facets_early_return_when_no_sql_found():
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format="JSON",
table_name="test",
task_id=TASK_ID,
)
op._sql = None
result = op.get_openlineage_facets_on_complete(None)
assert result == OperatorLineage()


def test_get_openlineage_facets():
op = DatabricksCopyIntoOperator(
file_location=COPY_FILE_LOCATION,
file_format="JSON",
table_name="test",
task_id=TASK_ID,
catalog="default_catalog",
schema="default_schema",
)
mock_hook = mock.MagicMock()
mock_hook.get_openlineage_default_schema.return_value = "default_schema"
mock_hook.catalog = "default_catalog"
mock_hook.query_ids = ["query_id"]
mock_hook.get_openlineage_database_info.return_value = mock.MagicMock(scheme="scheme", authority="host")

mock_get_hook = mock.MagicMock()
mock_get_hook.return_value = mock_hook
op._get_hook = mock_get_hook

op.execute(None)

result = op.get_openlineage_facets_on_complete(None)
assert result.inputs == [Dataset(namespace="s3://my-bucket", name="jsonData")]
assert result.outputs == [Dataset(namespace="scheme://host", name="default_catalog.default_schema.test")]
assert result.run_facets == {
"externalQuery": ExternalQueryRunFacet(externalQueryId="query_id", source="scheme://host")
}
assert result.job_facets == {"sql": SQLJobFacet(query=op._sql)}
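
For reference, the catalog/schema fallback exercised by the parametrized output-dataset cases above can be reproduced in a few standalone lines. This is a sketch in which the `default_catalog` and `default_schema` arguments stand in for `hook.catalog` and `hook.get_openlineage_default_schema()`; `resolve_output_name` is a hypothetical helper, not part of the provider:

def resolve_output_name(table_name: str, default_catalog=None, default_schema=None) -> str:
    parts = table_name.split(".")
    if len(parts) == 3:       # catalog.schema.table
        catalog, schema, table = parts
    elif len(parts) == 2:     # schema.table
        catalog = None
        schema, table = parts
    else:                     # bare table name
        catalog, schema, table = None, None, table_name

    schema = schema or default_schema     # fall back to the hook's default schema
    catalog = catalog or default_catalog  # fall back to the hook's catalog, if any

    name = table
    if schema:
        name = f"{schema}.{name}"
    if catalog:
        name = f"{catalog}.{name}"
    return name


assert resolve_output_name("s.t", default_catalog="dfc") == "dfc.s.t"
assert resolve_output_name("t", default_catalog="dfc", default_schema="dfs") == "dfc.dfs.t"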