diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
index bf8b007009068..5e0f9cc7d244c 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -277,8 +277,13 @@ def __init__(
         self._client_parameters = client_parameters or {}
         if force_copy is not None:
             self._copy_options["force"] = "true" if force_copy else "false"
+        self._sql: str | None = None
 
     def _get_hook(self) -> DatabricksSqlHook:
+        return self._hook
+
+    @cached_property
+    def _hook(self) -> DatabricksSqlHook:
         return DatabricksSqlHook(
             self.databricks_conn_id,
             http_path=self._http_path,
@@ -354,12 +359,116 @@ def _create_sql_query(self) -> str:
         return sql.strip()
 
     def execute(self, context: Context) -> Any:
-        sql = self._create_sql_query()
-        self.log.info("Executing: %s", sql)
+        self._sql = self._create_sql_query()
+        self.log.info("Executing: %s", self._sql)
         hook = self._get_hook()
-        hook.run(sql)
+        hook.run(self._sql)
 
     def on_kill(self) -> None:
         # NB: on_kill isn't required for this operator since query cancelling gets
         # handled in `DatabricksSqlHook.run()` method which is called in `execute()`
         ...
+
+    def _build_input_openlineage_dataset(self) -> tuple[Any, list[Any]]:
+        """Parse file_location to build the OpenLineage input dataset."""
+        from urllib.parse import urlparse
+
+        from airflow.providers.common.compat.openlineage.facet import Dataset, Error
+
+        try:
+            uri = urlparse(self.file_location)
+
+            # Only process schemes we know produce valid OL datasets with current implementation
+            if uri.scheme not in ("s3", "s3a", "s3n", "gs", "abfss", "wasbs"):
+                raise ValueError(f"Unsupported scheme: `{uri.scheme}` in `{self.file_location}`")
+
+            namespace = f"{uri.scheme}://{uri.netloc}"
+            name = uri.path.strip("/")
+            if name in ("", "."):
+                name = "/"
+            return Dataset(namespace=namespace, name=name), []
+        except Exception as e:
+            self.log.debug("Failed to parse file_location: `%s`, error: %s", self.file_location, str(e))
+            extraction_errors = [
+                Error(errorMessage=str(e), stackTrace=None, task=self.file_location, taskNumber=None)
+            ]
+            return None, extraction_errors
+
+    def _build_output_openlineage_dataset(self, namespace: str) -> tuple[Any, list[Any]]:
+        """Build the output OpenLineage dataset from table information."""
+        from airflow.providers.common.compat.openlineage.facet import Dataset, Error
+
+        try:
+            table_parts = self.table_name.split(".")
+            if len(table_parts) == 3:  # catalog.schema.table
+                catalog, schema, table = table_parts
+            elif len(table_parts) == 2:  # schema.table
+                catalog = None
+                schema, table = table_parts
+            else:
+                catalog = None
+                schema = None
+                table = self.table_name
+
+            hook = self._get_hook()
+            schema = schema or hook.get_openlineage_default_schema()  # Fall back to default schema
+            catalog = catalog or hook.catalog  # Fall back to default catalog, if provided
+
+            # Combine schema/table with optional catalog for final dataset name
+            fq_name = table
+            if schema:
+                fq_name = f"{schema}.{fq_name}"
+            if catalog:
+                fq_name = f"{catalog}.{fq_name}"
+
+            return Dataset(namespace=namespace, name=fq_name), []
+        except Exception as e:
+            self.log.debug("Failed to construct output dataset: `%s`, error: %s", self.table_name, str(e))
+            extraction_errors = [
+                Error(errorMessage=str(e), stackTrace=None, task=self.table_name, taskNumber=None)
+            ]
+            return None, extraction_errors
+
+    def get_openlineage_facets_on_complete(self, _):
+        """Implement on_complete instead of on_start, as the query id is only available after execution."""
+        from airflow.providers.common.compat.openlineage.facet import (
+            ExternalQueryRunFacet,
+            ExtractionErrorRunFacet,
+            SQLJobFacet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        if not self._sql:
+            self.log.warning("No SQL query found, returning empty OperatorLineage.")
+            return OperatorLineage()
+
+        hook = self._get_hook()
+        run_facets = {}
+
+        connection = hook.get_connection(self.databricks_conn_id)
+        database_info = hook.get_openlineage_database_info(connection)
+        dbx_namespace = SQLParser.create_namespace(database_info)
+
+        if hook.query_ids:
+            run_facets["externalQuery"] = ExternalQueryRunFacet(
+                externalQueryId=hook.query_ids[0], source=dbx_namespace
+            )
+
+        input_dataset, extraction_errors = self._build_input_openlineage_dataset()
+        output_dataset, output_errors = self._build_output_openlineage_dataset(dbx_namespace)
+        extraction_errors.extend(output_errors)
+
+        if extraction_errors:
+            run_facets["extractionError"] = ExtractionErrorRunFacet(
+                totalTasks=1,
+                failedTasks=len(extraction_errors),
+                errors=extraction_errors,
+            )
+
+        return OperatorLineage(
+            inputs=[input_dataset] if input_dataset else [],
+            outputs=[output_dataset] if output_dataset else [],
+            job_facets={"sql": SQLJobFacet(query=SQLParser.normalize_sql(self._sql))},
+            run_facets=run_facets,
+        )
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py
index 60b8cd82c14fb..28799755bf76a 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py
@@ -17,10 +17,18 @@
 # under the License.
 
 from __future__ import annotations
+from unittest import mock
+
 import pytest
 
 from airflow.exceptions import AirflowException
+from airflow.providers.common.compat.openlineage.facet import (
+    Dataset,
+    ExternalQueryRunFacet,
+    SQLJobFacet,
+)
 from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 DATE = "2017-04-20"
 TASK_ID = "databricks-sql-operator"
@@ -253,3 +261,172 @@ def test_templating(create_task_instance_of_operator, session):
     assert task.files == "files"
     assert task.table_name == "table-name"
     assert task.databricks_conn_id == "databricks-conn-id"
+
+
+def test_hook_is_cached():
+    op = DatabricksCopyIntoOperator(
+        file_location=COPY_FILE_LOCATION,
+        file_format="JSON",
+        table_name="test",
+        task_id=TASK_ID,
+    )
+    hook = op._get_hook()
+    hook2 = op._get_hook()
+    assert hook is hook2
+
+
+@pytest.mark.parametrize(
+    ("file_location", "expected_namespace", "expected_name"),
+    (
+        ("gs://bucket/another_dir/file1.csv", "gs://bucket", "another_dir/file1.csv"),
+        ("gs://bucket/another_dir/", "gs://bucket", "another_dir"),
+        ("s3://bucket/another_dir", "s3://bucket", "another_dir"),
+        ("s3://bucket/", "s3://bucket", "/"),
+        (
+            "abfss://container@account.dfs.core.windows.net/my-data/csv",
+            "abfss://container@account.dfs.core.windows.net",
+            "my-data/csv",
+        ),
+        (
+            "abfss://container@account.dfs.core.windows.net",
+            "abfss://container@account.dfs.core.windows.net",
+            "/",
+        ),
+        (
+            "wasbs://container@account.dfs.core.windows.net",
+            "wasbs://container@account.dfs.core.windows.net",
+            "/",
+        ),
+    ),
+)
+def test_build_input_openlineage_dataset_correct(file_location, expected_namespace, expected_name):
+    op = DatabricksCopyIntoOperator(
+        file_location=file_location,
+        file_format="JSON",
+        table_name="test",
+        task_id=TASK_ID,
+    )
+    ds, errors = op._build_input_openlineage_dataset()
+    assert ds == Dataset(namespace=expected_namespace, name=expected_name)
+    assert not errors
+
+
+@pytest.mark.parametrize(
+    ("file_location", "expected_error"),
+    (
+        ("azure://bucket/another_dir/file1.csv", "Unsupported scheme: `azure`"),
+        ("r2://bucket/another_dir/file1.csv", "Unsupported scheme: `r2`"),
+        ("my_random_location", "Unsupported scheme: ``"),
+    ),
+)
+def test_build_input_openlineage_dataset_silences_error(file_location, expected_error):
+    op = DatabricksCopyIntoOperator(
+        file_location=file_location,
+        file_format="JSON",
+        table_name="test",
+        task_id=TASK_ID,
+    )
+    ds, errors = op._build_input_openlineage_dataset()
+    assert ds is None
+    assert len(errors) == 1
+    assert errors[0].task == file_location
+    assert expected_error in errors[0].errorMessage
+
+
+@pytest.mark.parametrize(
+    ("table_name", "default_catalog", "default_schema", "expected_name"),
+    (
+        ("c.s.t", None, None, "c.s.t"),
+        ("s.t", None, None, "s.t"),
+        ("s.t", "dfc", None, "dfc.s.t"),
+        ("t", None, None, "ol_default_schema.t"),
+        ("c.s.t", "dfc", "dfs", "c.s.t"),
+        ("s.t", "dfc", "dfs", "dfc.s.t"),
+        ("t", "dfc", "dfs", "dfc.dfs.t"),
+    ),
+)
+def test_build_output_openlineage_dataset_correct(table_name, default_catalog, default_schema, expected_name):
+    op = DatabricksCopyIntoOperator(
+        file_location=COPY_FILE_LOCATION,
+        file_format="JSON",
+        table_name=table_name,
+        task_id=TASK_ID,
+        catalog=default_catalog,
+        schema=default_schema,
+    )
+    mock_hook = mock.MagicMock()
+    mock_hook.get_openlineage_default_schema.return_value = (
+        "ol_default_schema" if not default_schema else default_schema
+    )
+    mock_hook.catalog = default_catalog
+
+    mock_get_hook = mock.MagicMock()
+    mock_get_hook.return_value = mock_hook
+    op._get_hook = mock_get_hook
+
+    ds, errors = op._build_output_openlineage_dataset("ol_namespace")
+    assert ds == Dataset(namespace="ol_namespace", name=expected_name)
+    assert not errors
+
+
+def test_build_output_openlineage_dataset_silences_error():
+    table_name = "test"
+    op = DatabricksCopyIntoOperator(
+        file_location=COPY_FILE_LOCATION,
+        file_format="JSON",
+        table_name=table_name,
+        task_id=TASK_ID,
+    )
+
+    def mock_raise():
+        raise ValueError("test")
+
+    op._get_hook = mock_raise
+
+    ds, errors = op._build_output_openlineage_dataset("ol_namespace")
+    assert ds is None
+    assert len(errors) == 1
+    assert errors[0].task == table_name
+    assert errors[0].errorMessage == "test"
+
+
+def test_get_openlineage_facets_early_return_when_no_sql_found():
+    op = DatabricksCopyIntoOperator(
+        file_location=COPY_FILE_LOCATION,
+        file_format="JSON",
+        table_name="test",
+        task_id=TASK_ID,
+    )
+    op._sql = None
+    result = op.get_openlineage_facets_on_complete(None)
+    assert result == OperatorLineage()
+
+
+def test_get_openlineage_facets():
+    op = DatabricksCopyIntoOperator(
+        file_location=COPY_FILE_LOCATION,
+        file_format="JSON",
+        table_name="test",
+        task_id=TASK_ID,
+        catalog="default_catalog",
+        schema="default_schema",
+    )
+    mock_hook = mock.MagicMock()
+    mock_hook.get_openlineage_default_schema.return_value = "default_schema"
+    mock_hook.catalog = "default_catalog"
+    mock_hook.query_ids = ["query_id"]
+    mock_hook.get_openlineage_database_info.return_value = mock.MagicMock(scheme="scheme", authority="host")
+
+    mock_get_hook = mock.MagicMock()
+    mock_get_hook.return_value = mock_hook
+    op._get_hook = mock_get_hook
+
+    op.execute(None)
+
+    result = op.get_openlineage_facets_on_complete(None)
+    assert result.inputs == [Dataset(namespace="s3://my-bucket", name="jsonData")]
+    assert result.outputs == [Dataset(namespace="scheme://host", name="default_catalog.default_schema.test")]
+    assert result.run_facets == {
+        "externalQuery": ExternalQueryRunFacet(externalQueryId="query_id", source="scheme://host")
+    }
+    assert result.job_facets == {"sql": SQLJobFacet(query=op._sql)}