diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 681b18a0d2fef..6f804547db469 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -449,7 +449,9 @@
       }
     ],
     "cross-providers-deps": [
-      "common.sql"
+      "common.compat",
+      "common.sql",
+      "openlineage"
     ],
     "excluded-python-versions": [],
     "state": "ready"
diff --git a/providers/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/src/airflow/providers/databricks/operators/databricks_sql.py
index a4cb062c9b8c3..1b3784f58b635 100644
--- a/providers/src/airflow/providers/databricks/operators/databricks_sql.py
+++ b/providers/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -273,7 +273,12 @@ def __init__(
         if force_copy is not None:
             self._copy_options["force"] = "true" if force_copy else "false"
 
+        # These will be used by OpenLineage
+        self._sql: str | None = None
+        self._result: list[Any] = []
+
     def _get_hook(self) -> DatabricksSqlHook:
+        """Get a DatabricksSqlHook properly configured for this operator."""
         return DatabricksSqlHook(
             self.databricks_conn_id,
             http_path=self._http_path,
@@ -293,6 +298,11 @@ def _generate_options(
         opts: dict[str, str] | None = None,
         escape_key: bool = True,
     ) -> str:
+        """
+        Generate the bracketed options clause for the COPY INTO command.
+
+        Example: FORMAT_OPTIONS (header = 'true', inferSchema = 'true').
+        """
         formatted_opts = ""
         if opts:
             pairs = [
@@ -304,6 +314,7 @@ def _generate_options(
         return formatted_opts
 
     def _create_sql_query(self) -> str:
+        """Create the COPY INTO statement from the provided options."""
         escaper = ParamEscaper()
         maybe_with = ""
         if self._encryption is not None or self._credential is not None:
@@ -349,12 +360,166 @@ def _create_sql_query(self) -> str:
         return sql.strip()
 
     def execute(self, context: Context) -> Any:
-        sql = self._create_sql_query()
-        self.log.info("Executing: %s", sql)
+        """Execute the COPY INTO command and record the generated SQL for lineage reporting."""
+        self._sql = self._create_sql_query()
+        self.log.info("Executing SQL: %s", self._sql)
+
         hook = self._get_hook()
-        hook.run(sql)
+        hook.run(self._sql)
 
     def on_kill(self) -> None:
         # NB: on_kill isn't required for this operator since query cancelling gets
         # handled in `DatabricksSqlHook.run()` method which is called in `execute()`
         ...
+
+    def _parse_input_dataset(self) -> tuple[list[Any], list[Any]]:
+        """Parse file_location to build the input dataset."""
+        from airflow.providers.common.compat.openlineage.facet import Dataset, Error
+
+        input_datasets: list[Dataset] = []
+        extraction_errors: list[Error] = []
+
+        if not self.file_location:
+            return input_datasets, extraction_errors
+
+        try:
+            from urllib.parse import urlparse
+
+            parsed_uri = urlparse(self.file_location)
+            # Only process known schemes
+            if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs", "azure", "abfss", "wasbs"):
+                raise ValueError(f"Unsupported scheme: {parsed_uri.scheme}")
+
+            scheme = parsed_uri.scheme
+            namespace = f"{scheme}://{parsed_uri.netloc}"
+            path = parsed_uri.path.lstrip("/") or "/"
+            input_datasets.append(Dataset(namespace=namespace, name=path))
+        except Exception as e:
+            self.log.error("Failed to parse file_location: %s, error: %s", self.file_location, str(e))
+            extraction_errors.append(
+                Error(errorMessage=str(e), stackTrace=None, task=self.file_location, taskNumber=None)
+            )
+
+        return input_datasets, extraction_errors
+
+    def _create_sql_job_facet(self) -> tuple[dict, list[Any]]:
+        """Create SQL job facet from the SQL query."""
+        from airflow.providers.common.compat.openlineage.facet import Error, SQLJobFacet
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        job_facets = {}
+        extraction_errors: list[Error] = []
+
+        try:
+            import re
+
+            normalized_sql = SQLParser.normalize_sql(self._sql)
+            normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ", normalized_sql))
+            job_facets["sql"] = SQLJobFacet(query=normalized_sql)
+        except Exception as e:
+            self.log.error("Failed creating SQL job facet: %s", str(e))
+            extraction_errors.append(
+                Error(errorMessage=str(e), stackTrace=None, task="sql_facet_creation", taskNumber=None)
+            )
+
+        return job_facets, extraction_errors
+
+    def _build_output_dataset(self) -> tuple[Any, list[Any]]:
+        """Build output dataset from table information."""
+        from airflow.providers.common.compat.openlineage.facet import Dataset, Error
+
+        output_dataset = None
+        extraction_errors: list[Error] = []
+
+        if not self.table_name:
+            return output_dataset, extraction_errors
+
+        try:
+            table_parts = self.table_name.split(".")
+            if len(table_parts) == 3:  # catalog.schema.table
+                catalog, schema, table = table_parts
+            elif len(table_parts) == 2:  # schema.table
+                catalog = None
+                schema, table = table_parts
+            else:
+                catalog = None
+                schema = None
+                table = self.table_name
+
+            hook = self._get_hook()
+            conn = hook.get_connection(hook.databricks_conn_id)
+            output_namespace = f"databricks://{conn.host}"
+
+            # Combine schema/table with the optional catalog into the final dataset name
+            fq_name = table
+            if schema:
+                fq_name = f"{schema}.{fq_name}"
+            if catalog:
+                fq_name = f"{catalog}.{fq_name}"
+
+            output_dataset = Dataset(namespace=output_namespace, name=fq_name)
+        except Exception as e:
+            self.log.error("Failed to construct output dataset: %s", str(e))
+            extraction_errors.append(
+                Error(
+                    errorMessage=str(e),
+                    stackTrace=None,
+                    task="output_dataset_construction",
+                    taskNumber=None,
+                )
+            )
+
+        return output_dataset, extraction_errors
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """
+        Compute OpenLineage facets for the COPY INTO command.
+
+        Attempts to parse the file location (S3, GCS, Azure Blob, etc.) into an
+        input dataset and the target table into an output dataset (the Delta table).
+ """ + from airflow.providers.common.compat.openlineage.facet import ExtractionErrorRunFacet + from airflow.providers.openlineage.extractors import OperatorLineage + + if not self._sql: + self.log.warning("No SQL query found, returning empty OperatorLineage.") + return OperatorLineage() + + # Get input datasets and any parsing errors + input_datasets, extraction_errors = self._parse_input_dataset() + + # Create SQL job facet + job_facets, sql_errors = self._create_sql_job_facet() + extraction_errors.extend(sql_errors) + + run_facets = {} + if extraction_errors: + run_facets["extractionError"] = ExtractionErrorRunFacet( + totalTasks=1, + failedTasks=len(extraction_errors), + errors=extraction_errors, + ) + # Return only error facets for invalid URIs + return OperatorLineage( + inputs=[], + outputs=[], + job_facets=job_facets, + run_facets=run_facets, + ) + + # Build output dataset + output_dataset, output_errors = self._build_output_dataset() + if output_errors: + extraction_errors.extend(output_errors) + run_facets["extractionError"] = ExtractionErrorRunFacet( + totalTasks=1, + failedTasks=len(extraction_errors), + errors=extraction_errors, + ) + + return OperatorLineage( + inputs=input_datasets, + outputs=[output_dataset] if output_dataset else [], + job_facets=job_facets, + run_facets=run_facets, + ) diff --git a/providers/tests/databricks/operators/test_databricks_copy.py b/providers/tests/databricks/operators/test_databricks_copy.py index 60b8cd82c14fb..005d8b21e3270 100644 --- a/providers/tests/databricks/operators/test_databricks_copy.py +++ b/providers/tests/databricks/operators/test_databricks_copy.py @@ -17,10 +17,17 @@ # under the License. from __future__ import annotations +from unittest import mock + import pytest from airflow.exceptions import AirflowException +from airflow.providers.common.compat.openlineage.facet import ( + Dataset, + SQLJobFacet, +) from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator +from airflow.providers.openlineage.extractors import OperatorLineage DATE = "2017-04-20" TASK_ID = "databricks-sql-operator" @@ -140,10 +147,8 @@ def test_copy_with_encryption_and_credential(): assert ( op._create_sql_query() == f"""COPY INTO test -FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') """ - """ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc')) -FILEFORMAT = CSV -""".strip() +FROM '{COPY_FILE_LOCATION}' WITH (CREDENTIAL (AZURE_SAS_TOKEN = 'abc') ENCRYPTION (TYPE = 'AWS_SSE_C', MASTER_KEY = 'abc')) +FILEFORMAT = CSV""".strip() ) @@ -253,3 +258,150 @@ def test_templating(create_task_instance_of_operator, session): assert task.files == "files" assert task.table_name == "table-name" assert task.databricks_conn_id == "databricks-conn-id" + + +@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") +def test_get_openlineage_facets_on_complete_s3(mock_hook): + """Test OpenLineage facets generation for S3 source.""" + mock_hook().run.return_value = [ + {"file": "s3://bucket/dir1/file1.csv"}, + {"file": "s3://bucket/dir1/file2.csv"}, + ] + mock_hook().get_connection().host = "databricks.com" + + op = DatabricksCopyIntoOperator( + task_id="test", + table_name="schema.table", + file_location="s3://bucket/dir1", + file_format="CSV", + ) + op._sql = "COPY INTO schema.table FROM 's3://bucket/dir1'" + op._result = mock_hook().run.return_value + + lineage = op.get_openlineage_facets_on_complete(None) + + assert lineage == OperatorLineage( + inputs=[Dataset(namespace="s3://bucket", 
name="dir1")], + outputs=[Dataset(namespace="databricks://databricks.com", name="schema.table")], + job_facets={"sql": SQLJobFacet(query="COPY INTO schema.table FROM 's3://bucket/dir1'")}, + run_facets={}, + ) + + +@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") +def test_get_openlineage_facets_on_complete_with_errors(mock_hook): + """Test OpenLineage facets generation with extraction errors.""" + mock_hook().run.return_value = [ + {"file": "s3://bucket/dir1/file1.csv"}, + {"file": "invalid://location/file.csv"}, # Invalid URI + {"file": "azure://account.invalid.windows.net/container/file.csv"}, # Invalid Azure URI + ] + mock_hook().get_connection().host = "databricks.com" + + op = DatabricksCopyIntoOperator( + task_id="test", + table_name="schema.table", + file_location="s3://bucket/dir1", + file_format="CSV", + ) + op._sql = "COPY INTO schema.table FROM 's3://bucket/dir1'" + op._result = mock_hook().run.return_value + + lineage = op.get_openlineage_facets_on_complete(None) + + # Check inputs and outputs + assert len(lineage.inputs) == 1 + assert lineage.inputs[0].namespace == "s3://bucket" + assert lineage.inputs[0].name == "dir1" + + assert len(lineage.outputs) == 1 + assert lineage.outputs[0].namespace == "databricks://databricks.com" + assert lineage.outputs[0].name == "schema.table" + + # Check facets exist and have correct structure + assert "sql" in lineage.job_facets + assert lineage.job_facets["sql"].query == "COPY INTO schema.table FROM 's3://bucket/dir1'" + + assert "extractionError" not in lineage.run_facets + + +@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") +def test_get_openlineage_facets_on_complete_no_sql(mock_hook): + """Test OpenLineage facets generation when no SQL is available.""" + op = DatabricksCopyIntoOperator( + task_id="test", + table_name="schema.table", + file_location="s3://bucket/dir1", + file_format="CSV", + ) + + lineage = op.get_openlineage_facets_on_complete(None) + assert lineage == OperatorLineage() + + +@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") +def test_get_openlineage_facets_on_complete_gcs(mock_hook): + """Test OpenLineage facets generation specifically for GCS paths.""" + mock_hook().run.return_value = [ + {"file": "gs://bucket1/dir1/file1.csv"}, + {"file": "gs://bucket1/dir2/nested/file2.csv"}, + {"file": "gs://bucket2/file3.csv"}, + {"file": "gs://bucket2"}, # Edge case: root path + {"file": "gs://invalid-bucket/@#$%"}, # Invalid path + ] + mock_hook().get_connection.return_value.host = "databricks.com" + mock_hook().query_ids = ["query_123"] + + op = DatabricksCopyIntoOperator( + task_id="test", + table_name="catalog.schema.table", + file_location="gs://location", + file_format="CSV", + ) + op.execute(None) + result = op.get_openlineage_facets_on_complete(None) + + # Check inputs - only one input from file_location + assert len(result.inputs) == 1 + assert result.inputs[0].namespace == "gs://location" + assert result.inputs[0].name == "/" + + # Check outputs + assert len(result.outputs) == 1 + assert result.outputs[0].namespace == "databricks://databricks.com" + assert result.outputs[0].name == "catalog.schema.table" + + # Check SQL job facet + assert "sql" in result.job_facets + assert "COPY INTO catalog.schema.table" in result.job_facets["sql"].query + assert "FILEFORMAT = CSV" in result.job_facets["sql"].query + + +@mock.patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") +def 
test_get_openlineage_facets_on_complete_invalid_location(mock_hook): + """Test OpenLineage facets generation with invalid file_location.""" + mock_hook().get_connection().host = "databricks.com" + + op = DatabricksCopyIntoOperator( + task_id="test", + table_name="schema.table", + file_location="invalid://location", # Invalid location + file_format="CSV", + ) + op._sql = "COPY INTO schema.table FROM 'invalid://location'" + op._result = [{"file": "s3://bucket/file.csv"}] + + lineage = op.get_openlineage_facets_on_complete(None) + + # Should have no inputs due to invalid location + assert len(lineage.inputs) == 0 + + # Should not have output and SQL facets + assert len(lineage.outputs) == 0 + assert "sql" in lineage.job_facets + + # Should have extraction error facet + assert "extractionError" in lineage.run_facets + assert lineage.run_facets["extractionError"].totalTasks == 1 + assert lineage.run_facets["extractionError"].failedTasks == 1 + assert len(lineage.run_facets["extractionError"].errors) == 1
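
A minimal sketch of how the new lineage hook can be exercised directly, mirroring the unit tests above. It assumes the apache-airflow-providers-openlineage and apache-airflow-providers-common-compat packages are installed; the task, table, and bucket names below are illustrative only, and _sql is populated by hand exactly as the tests do rather than via a real execute() run.

from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator

# Illustrative values only; any supported scheme (s3, s3a, s3n, gs, azure, abfss, wasbs) works.
op = DatabricksCopyIntoOperator(
    task_id="copy_events",
    table_name="analytics.events",
    file_location="s3://example-bucket/events/",
    file_format="CSV",
)

# execute() would normally populate _sql; set it directly here, as the unit tests do.
op._sql = op._create_sql_query()

lineage = op.get_openlineage_facets_on_complete(None)

# Input dataset parsed from file_location, e.g. Dataset(namespace="s3://example-bucket", name="events/")
print(lineage.inputs)

# SQL job facet holding the normalized COPY INTO statement
print(lineage.job_facets["sql"].query)

# Without a resolvable Databricks connection the output dataset is skipped and the failure
# is reported via the "extractionError" run facet instead of raising.
print(lineage.outputs, list(lineage.run_facets))

In a real DAG run none of this is called manually: the OpenLineage listener invokes get_openlineage_facets_on_complete after execute() finishes, so the operator only has to keep the generated SQL around.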