diff --git a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py index 76aea1dde5434..0ace06abc998b 100644 --- a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py +++ b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py @@ -344,10 +344,9 @@ def get_openlineage_default_schema(self) -> str | None: def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None: """ - Generate OpenLineage metadata for a Databricks task instance based on executed query IDs. + Emit separate OpenLineage events for each Databricks query, based on executed query IDs. - If a single query ID is present, attach an `ExternalQueryRunFacet` to the lineage metadata. - If multiple query IDs are present, emits separate OpenLineage events for each query instead. + If a single query ID is present, also add an `ExternalQueryRunFacet` to the returned lineage metadata. Note that `get_openlineage_database_specific_lineage` is usually called after task's execution, so if multiple query IDs are present, both START and COMPLETE event for each query will be emitted @@ -368,13 +367,22 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLi from airflow.providers.openlineage.sqlparser import SQLParser if not self.query_ids: - self.log.debug("openlineage: no databricks query ids found.") + self.log.info("OpenLineage could not find Databricks query IDs.") return None self.log.debug("openlineage: getting connection to get database info") connection = self.get_connection(self.get_conn_id()) namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection)) + self.log.info("Separate OpenLineage events will be emitted for each Databricks query_id.") + emit_openlineage_events_for_databricks_queries( + task_instance=task_instance, + hook=self, + query_ids=self.query_ids, + query_for_extra_metadata=True, + query_source_namespace=namespace, + ) + if len(self.query_ids) == 1: self.log.debug("Attaching ExternalQueryRunFacet with single query_id to OpenLineage event.") return OperatorLineage( @@ -385,12 +393,4 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLi } ) - self.log.info("Multiple query_ids found. 
Separate OpenLineage event will be emitted for each query.") - emit_openlineage_events_for_databricks_queries( - query_ids=self.query_ids, - query_source_namespace=namespace, - task_instance=task_instance, - hook=self, - ) - return None diff --git a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py index 5d2fe3997039d..57c964d984258 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py @@ -31,6 +31,7 @@ from openlineage.client.event_v2 import RunEvent from openlineage.client.facet_v2 import JobFacet + from airflow.providers.databricks.hooks.databricks import DatabricksHook from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook log = logging.getLogger(__name__) @@ -121,20 +122,18 @@ def _get_parent_run_facet(task_instance): ) -def _run_api_call(hook: DatabricksSqlHook, query_ids: list[str]) -> list[dict]: +def _run_api_call(hook: DatabricksSqlHook | DatabricksHook, query_ids: list[str]) -> list[dict]: """Retrieve execution details for specific queries from Databricks's query history API.""" - if not hook._token: - # This has logic for token initialization - hook.get_conn() - - # https://docs.databricks.com/api/azure/workspace/queryhistory/list try: + token = hook._get_token(raise_error=True) + # https://docs.databricks.com/api/azure/workspace/queryhistory/list response = requests.get( url=f"https://{hook.host}/api/2.0/sql/history/queries", - headers={"Authorization": f"Bearer {hook._token}"}, + headers={"Authorization": f"Bearer {token}"}, data=json.dumps({"filter_by": {"statement_ids": query_ids}}), timeout=2, ) + response.raise_for_status() except Exception as e: log.warning( "OpenLineage could not retrieve Databricks queries details. Error received: `%s`.", @@ -142,48 +141,42 @@ def _run_api_call(hook: DatabricksSqlHook, query_ids: list[str]) -> list[dict]: ) return [] - if response.status_code != 200: - log.warning( - "OpenLineage could not retrieve Databricks queries details. 
API error received: `%s`: `%s`", - response.status_code, - response.text, - ) - return [] - return response.json()["res"] +def _process_data_from_api(data: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Convert timestamp fields to UTC datetime objects.""" + for row in data: + for key in ("query_start_time_ms", "query_end_time_ms"): + row[key] = datetime.datetime.fromtimestamp(row[key] / 1000, tz=datetime.timezone.utc) + + return data + + def _get_queries_details_from_databricks( - hook: DatabricksSqlHook, query_ids: list[str] + hook: DatabricksSqlHook | DatabricksHook, query_ids: list[str] ) -> dict[str, dict[str, Any]]: if not query_ids: return {} - queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids) - query_details = {} - for query_info in queries_info_from_api: - if not query_info.get("query_id"): - log.debug("Databricks query ID not found in API response.") - continue - - q_start_time = None - q_end_time = None - if query_info.get("query_start_time_ms") and query_info.get("query_end_time_ms"): - q_start_time = datetime.datetime.fromtimestamp( - query_info["query_start_time_ms"] / 1000, tz=datetime.timezone.utc - ) - q_end_time = datetime.datetime.fromtimestamp( - query_info["query_end_time_ms"] / 1000, tz=datetime.timezone.utc - ) - - query_details[query_info["query_id"]] = { - "status": query_info.get("status"), - "start_time": q_start_time, - "end_time": q_end_time, - "query_text": query_info.get("query_text"), - "error_message": query_info.get("error_message"), + query_details: dict[str, dict[str, Any]] = {}  # Ensure defined even if the try block fails + try: + queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids) + queries_info_from_api = _process_data_from_api(queries_info_from_api) + + query_details = { + query_info["query_id"]: { + "status": query_info.get("status"), + "start_time": query_info.get("query_start_time_ms"), + "end_time": query_info.get("query_end_time_ms"), + "query_text": query_info.get("query_text"), + "error_message": query_info.get("error_message"), + } + for query_info in queries_info_from_api + if query_info["query_id"] } + except Exception as e: + log.warning("OpenLineage could not retrieve extra metadata from Databricks. Error encountered: %s", e) return query_details @@ -221,17 +214,18 @@ def _create_ol_event_pair( @require_openlineage_version(provider_min_version="2.3.0") def emit_openlineage_events_for_databricks_queries( - query_ids: list[str], - query_source_namespace: str, task_instance, - hook: DatabricksSqlHook | None = None, + hook: DatabricksSqlHook | DatabricksHook | None = None, + query_ids: list[str] | None = None, + query_source_namespace: str | None = None, + query_for_extra_metadata: bool = False, additional_run_facets: dict | None = None, additional_job_facets: dict | None = None, ) -> None: """ Emit OpenLineage events for executed Databricks queries. - Metadata retrieval from Databricks is attempted only if a `DatabricksSqlHook` is provided. + Metadata retrieval from Databricks is attempted only if `query_for_extra_metadata` is True and a hook is provided. If metadata is available, execution details such as start time, end time, execution status, error messages, and SQL text are included in the events. If no metadata is found, the function defaults to using the Airflow task instance's state and the current timestamp. @@ -241,10 +235,16 @@ def emit_openlineage_events_for_databricks_queries( will correspond to actual query execution times. Args: - query_ids: A list of Databricks query IDs to emit events for. - query_source_namespace: The namespace to be included in ExternalQueryRunFacet. 
task_instance: The Airflow task instance that run these queries. - hook: A hook instance used to retrieve query metadata if available. + hook: A supported Databricks hook instance used to retrieve query metadata if available. + If omitted, `query_ids` and `query_source_namespace` must be provided explicitly and + `query_for_extra_metadata` must be `False`. + query_ids: A list of Databricks query IDs to emit events for; can be None only if `hook` is provided + and `hook.query_ids` are present (DatabricksHook does not store query_ids). + query_source_namespace: The namespace to be included in ExternalQueryRunFacet; + can be `None` only if `hook` is provided. + query_for_extra_metadata: Whether to query Databricks for additional metadata about queries. + Must be `False` if `hook` is not provided. additional_run_facets: Additional run facets to include in OpenLineage events. additional_job_facets: Additional job facets to include in OpenLineage events. """ @@ -259,25 +259,52 @@ def emit_openlineage_events_for_databricks_queries( from airflow.providers.openlineage.conf import namespace from airflow.providers.openlineage.plugins.listener import get_openlineage_listener - if not query_ids: - log.debug("No Databricks query IDs provided; skipping OpenLineage event emission.") - return - - query_ids = [q for q in query_ids] # Make a copy to make sure it does not change + log.info("OpenLineage will emit events for Databricks queries.") if hook: + if not query_ids: + log.debug("No Databricks query IDs provided; checking `hook.query_ids` property.") + query_ids = getattr(hook, "query_ids", []) + if not query_ids: + raise ValueError("No Databricks query IDs provided and `hook.query_ids` are not present.") + + if not query_source_namespace: + log.debug("No Databricks query namespace provided; creating one from scratch.") + + if hasattr(hook, "get_openlineage_database_info") and hasattr(hook, "get_conn_id"): + from airflow.providers.openlineage.sqlparser import SQLParser + + query_source_namespace = SQLParser.create_namespace( + hook.get_openlineage_database_info(hook.get_connection(hook.get_conn_id())) + ) + else: + query_source_namespace = f"databricks://{hook.host}" if hook.host else "databricks" + else: + if not query_ids: + raise ValueError("If 'hook' is not provided, 'query_ids' must be set.") + if not query_source_namespace: + raise ValueError("If 'hook' is not provided, 'query_source_namespace' must be set.") + if query_for_extra_metadata: + raise ValueError("If 'hook' is not provided, 'query_for_extra_metadata' must be False.") + + query_ids = [q for q in query_ids] # Make a copy to make sure we do not change hook's attribute + + if query_for_extra_metadata and hook: log.debug("Retrieving metadata for %s queries from Databricks.", len(query_ids)) databricks_metadata = _get_queries_details_from_databricks(hook, query_ids) else: - log.debug("DatabricksSqlHook not provided. No extra metadata fill be fetched from Databricks.") + log.debug("`query_for_extra_metadata` is False. No extra metadata will be fetched from Databricks.") databricks_metadata = {} # If real metadata is unavailable, we send events with eventTime=now default_event_time = timezone.utcnow() - # If no query metadata is provided, we use task_instance's state when checking for success + # ti.state has no `value` attr (AF2) while the task is still running; in AF3 we get 'running'. + # In that case we assume it is a user call and the query succeeded, so we treat it as success. 
# Adjust state for DBX logic, where "finished" means "success" - default_state = task_instance.state.value if hasattr(task_instance, "state") else "" - default_state = "finished" if default_state == "success" else default_state + default_state = ( + getattr(task_instance.state, "value", "running") if hasattr(task_instance, "state") else "" + ) + default_state = "finished" if default_state in ("running", "success") else default_state log.debug("Generating OpenLineage facets") common_run_facets = {"parent": _get_parent_run_facet(task_instance)} @@ -318,10 +345,10 @@ def emit_openlineage_events_for_databricks_queries( event_batch = _create_ol_event_pair( job_namespace=namespace(), job_name=f"{task_instance.dag_id}.{task_instance.task_id}.query.{counter}", - start_time=query_metadata.get("start_time", default_event_time), # type: ignore[arg-type] - end_time=query_metadata.get("end_time", default_event_time), # type: ignore[arg-type] + start_time=query_metadata.get("start_time") or default_event_time, # type: ignore[arg-type] + end_time=query_metadata.get("end_time") or default_event_time, # type: ignore[arg-type] # Only finished status means it completed without failures - is_successful=query_metadata.get("status", default_state).lower() == "finished", + is_successful=(query_metadata.get("status") or default_state).lower() == "finished", run_facets={**query_specific_run_facets, **common_run_facets, **additional_run_facets}, job_facets={**query_specific_job_facets, **common_job_facets, **additional_job_facets}, ) diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py index 7449489c77d7b..20e2c081f740d 100644 --- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py +++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py @@ -457,7 +457,8 @@ def test_get_openlineage_database_specific_lineage_with_no_query_id(): assert result is None -def test_get_openlineage_database_specific_lineage_with_single_query_id(): +@mock.patch("airflow.providers.databricks.utils.openlineage.emit_openlineage_events_for_databricks_queries") +def test_get_openlineage_database_specific_lineage_with_single_query_id(mock_emit): from airflow.providers.common.compat.openlineage.facet import ExternalQueryRunFacet from airflow.providers.openlineage.extractors import OperatorLineage @@ -466,7 +467,18 @@ def test_get_openlineage_database_specific_lineage_with_single_query_id(): hook.get_connection = mock.MagicMock() hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme") - result = hook.get_openlineage_database_specific_lineage(None) + ti = mock.MagicMock() + + result = hook.get_openlineage_database_specific_lineage(ti) + mock_emit.assert_called_once_with( + **{ + "hook": hook, + "query_ids": ["query1"], + "query_source_namespace": "scheme://auth", + "task_instance": ti, + "query_for_extra_metadata": True, + } + ) assert result == OperatorLineage( run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="query1", source="scheme://auth")} ) @@ -488,6 +500,7 @@ def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(mock_ "query_ids": ["query1", "query2"], "query_source_namespace": "scheme://auth", "task_instance": ti, + "query_for_extra_metadata": True, } ) assert result is None diff --git a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py 
b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py index 699e018785ee9..6d427e0ba7775 100644 --- a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py +++ b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py @@ -30,11 +30,14 @@ ExternalQueryRunFacet, SQLJobFacet, ) +from airflow.providers.databricks.hooks.databricks import DatabricksHook +from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook from airflow.providers.databricks.utils.openlineage import ( _create_ol_event_pair, _get_ol_run_id, _get_parent_run_facet, _get_queries_details_from_databricks, + _process_data_from_api, _run_api_call, emit_openlineage_events_for_databricks_queries, ) @@ -96,7 +99,7 @@ def test_get_parent_run_facet(): def test_run_api_call_success(): mock_hook = mock.MagicMock() - mock_hook._token = "mock_token" + mock_hook._get_token.return_value = "mock_token" mock_hook.host = "mock_host" mock_response = mock.MagicMock() @@ -109,14 +112,27 @@ def test_run_api_call_success(): assert result == [{"query_id": "123", "status": "success"}] -def test_run_api_call_error(): +def test_run_api_call_request_error(): mock_hook = mock.MagicMock() - mock_hook._token = "mock_token" + mock_hook._get_token.return_value = "mock_token" mock_hook.host = "mock_host" mock_response = mock.MagicMock() - mock_response.status_code = 500 - mock_response.text = "Internal Server Error" + mock_response.status_code = 200 + + with mock.patch("requests.get", side_effect=RuntimeError("request error")): + result = _run_api_call(mock_hook, ["123"]) + + assert result == [] + + +def test_run_api_call_token_error(): + mock_hook = mock.MagicMock() + mock_hook._get_token.side_effect = RuntimeError("Token error") + mock_hook.host = "mock_host" + + mock_response = mock.MagicMock() + mock_response.status_code = 200 with mock.patch("requests.get", return_value=mock_response): result = _run_api_call(mock_hook, ["123"]) @@ -124,6 +140,55 @@ def test_run_api_call_error(): assert result == [] +def test_process_data_from_api(): + data = [ + { + "query_id": "ABC", + "status": "FINISHED", + "query_start_time_ms": 1595357086200, + "query_end_time_ms": 1595357087200, + "query_text": "SELECT * FROM table1;", + "error_message": "Error occurred", + }, + { + "query_id": "DEF", + "query_start_time_ms": 1595357086200, + "query_end_time_ms": 1595357087200, + }, + ] + expected_details = [ + { + "query_id": "ABC", + "status": "FINISHED", + "query_start_time_ms": datetime.datetime( + 2020, 7, 21, 18, 44, 46, 200000, tzinfo=datetime.timezone.utc + ), + "query_end_time_ms": datetime.datetime( + 2020, 7, 21, 18, 44, 47, 200000, tzinfo=datetime.timezone.utc + ), + "query_text": "SELECT * FROM table1;", + "error_message": "Error occurred", + }, + { + "query_id": "DEF", + "query_start_time_ms": datetime.datetime( + 2020, 7, 21, 18, 44, 46, 200000, tzinfo=datetime.timezone.utc + ), + "query_end_time_ms": datetime.datetime( + 2020, 7, 21, 18, 44, 47, 200000, tzinfo=datetime.timezone.utc + ), + }, + ] + result = _process_data_from_api(data=data) + assert len(result) == 2 + assert result == expected_details + + +def test_process_data_from_api_error(): + with pytest.raises(KeyError): + _process_data_from_api(data=[{"query_start_time_ms": 1595357086200}]) + + def test_get_queries_details_from_databricks_empty_query_ids(): details = _get_queries_details_from_databricks(None, []) assert details == {} @@ -131,7 +196,7 @@ def test_get_queries_details_from_databricks_empty_query_ids(): 
@mock.patch("airflow.providers.databricks.utils.openlineage._run_api_call") def test_get_queries_details_from_databricks(mock_api_call): - hook = mock.MagicMock() + hook = DatabricksSqlHook() query_ids = ["ABC"] fake_result = [ { @@ -160,7 +225,7 @@ def test_get_queries_details_from_databricks(mock_api_call): @mock.patch("airflow.providers.databricks.utils.openlineage._run_api_call") def test_get_queries_details_from_databricks_no_data_found(mock_api_call): - hook = mock.MagicMock() + hook = DatabricksSqlHook() query_ids = ["ABC", "DEF"] mock_api_call.return_value = [] @@ -274,6 +339,7 @@ def test_emit_openlineage_events_for_databricks_queries(mock_now, mock_generate_ query_source_namespace="databricks_ns", task_instance=mock_ti, hook=mock.MagicMock(), + query_for_extra_metadata=True, additional_run_facets=additional_run_facets, additional_job_facets=additional_job_facets, ) @@ -448,7 +514,7 @@ def test_emit_openlineage_events_for_databricks_queries(mock_now, mock_generate_ @mock.patch("importlib.metadata.version", return_value="2.3.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") @mock.patch("airflow.utils.timezone.utcnow") -def test_emit_openlineage_events_for_databricks_queries_without_metadata_found( +def test_emit_openlineage_events_for_databricks_queries_without_metadata( mock_now, mock_generate_uuid, mock_version ): fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" @@ -489,7 +555,8 @@ def test_emit_openlineage_events_for_databricks_queries_without_metadata_found( query_ids=query_ids, query_source_namespace="databricks_ns", task_instance=mock_ti, - hook=None, # None so metadata retrieval is not triggered + hook=mock.MagicMock(), + # query_for_extra_metadata=False, # False by default additional_run_facets=additional_run_facets, additional_job_facets=additional_job_facets, ) @@ -564,9 +631,37 @@ def test_emit_openlineage_events_for_databricks_queries_without_metadata_found( @mock.patch("importlib.metadata.version", return_value="2.3.0") -def test_emit_openlineage_events_for_databricks_queries_without_query_ids(mock_version): - query_ids = [] +@mock.patch("openlineage.client.uuid.generate_new_uuid") +@mock.patch("airflow.utils.timezone.utcnow") +def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids( + mock_now, mock_generate_uuid, mock_version +): + fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" + mock_generate_uuid.return_value = fake_uuid + + default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0) + mock_now.return_value = default_event_time + + query_ids = ["query1"] + hook = mock.MagicMock() + hook.query_ids = query_ids original_query_ids = copy.deepcopy(query_ids) + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + map_index=1, + try_number=1, + logical_date=logical_date, + state=TaskInstanceState.RUNNING, # This will be query default state if no metadata found + dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0), + ) + mock_ti.get_template_context.return_value = { + "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0) + } + + additional_run_facets = {"custom_run": "value_run"} + additional_job_facets = {"custom_job": "value_job"} fake_adapter = mock.MagicMock() fake_adapter.emit = mock.MagicMock() @@ -578,16 +673,527 @@ def test_emit_openlineage_events_for_databricks_queries_without_query_ids(mock_v return_value=fake_listener, ): emit_openlineage_events_for_databricks_queries( - query_ids=query_ids, 
query_source_namespace="databricks_ns", - task_instance=None, + task_instance=mock_ti, + hook=hook, + # query_for_extra_metadata=False, # False by default + additional_run_facets=additional_run_facets, + additional_job_facets=additional_job_facets, + ) + + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + assert fake_adapter.emit.call_count == 2 # Expect two events per query. + + expected_common_job_facets = { + "jobType": job_type_job.JobTypeJobFacet( + jobType="QUERY", + processingType="BATCH", + integration="DATABRICKS", + ), + "custom_job": "value_job", + } + expected_common_run_facets = { + "parent": parent_run.ParentRunFacet( + run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"), + job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"), + root=parent_run.Root( + run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"), + job=parent_run.RootJob(namespace=namespace(), name="dag_id"), + ), + ), + "custom_run": "value_run", + } + + expected_calls = [ + mock.call( # Query1: START event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.START, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="databricks_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + mock.call( # Query1: COMPLETE event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.COMPLETE, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="databricks_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + ] + + assert fake_adapter.emit.call_args_list == expected_calls + + +@mock.patch( + "airflow.providers.openlineage.sqlparser.SQLParser.create_namespace", return_value="databricks_ns" +) +@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("openlineage.client.uuid.generate_new_uuid") +@mock.patch("airflow.utils.timezone.utcnow") +def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace( + mock_now, mock_generate_uuid, mock_version, mock_parser +): + fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" + mock_generate_uuid.return_value = fake_uuid + + default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0) + mock_now.return_value = default_event_time + + query_ids = ["query1"] + hook = mock.MagicMock() + hook.query_ids = query_ids + original_query_ids = copy.deepcopy(query_ids) + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + map_index=1, + try_number=1, + logical_date=logical_date, + state=TaskInstanceState.RUNNING, # This will be query default state if no metadata found + dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0), + ) + mock_ti.get_template_context.return_value = { + "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0) + } + + additional_run_facets = {"custom_run": "value_run"} + additional_job_facets = {"custom_job": "value_job"} + + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter = fake_adapter + + with mock.patch( + 
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + emit_openlineage_events_for_databricks_queries( + task_instance=mock_ti, + hook=hook, + # query_for_extra_metadata=False, # False by default + additional_run_facets=additional_run_facets, + additional_job_facets=additional_job_facets, ) + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + assert fake_adapter.emit.call_count == 2 # Expect two events per query. + + expected_common_job_facets = { + "jobType": job_type_job.JobTypeJobFacet( + jobType="QUERY", + processingType="BATCH", + integration="DATABRICKS", + ), + "custom_job": "value_job", + } + expected_common_run_facets = { + "parent": parent_run.ParentRunFacet( + run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"), + job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"), + root=parent_run.Root( + run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"), + job=parent_run.RootJob(namespace=namespace(), name="dag_id"), + ), + ), + "custom_run": "value_run", + } + + expected_calls = [ + mock.call( # Query1: START event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.START, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="databricks_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + mock.call( # Query1: COMPLETE event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.COMPLETE, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="databricks_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + ] + + assert fake_adapter.emit.call_args_list == expected_calls + + +@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("openlineage.client.uuid.generate_new_uuid") +@mock.patch("airflow.utils.timezone.utcnow") +def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace_raw_ns( + mock_now, mock_generate_uuid, mock_version +): + fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" + mock_generate_uuid.return_value = fake_uuid + + default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0) + mock_now.return_value = default_event_time + + query_ids = ["query1"] + hook = DatabricksHook() + hook.query_ids = query_ids + hook.host = "some_host" + original_query_ids = copy.deepcopy(query_ids) + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + map_index=1, + try_number=1, + logical_date=logical_date, + state=TaskInstanceState.RUNNING, # This will be query default state if no metadata found + dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0), + ) + mock_ti.get_template_context.return_value = { + "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0) + } + + additional_run_facets = {"custom_run": "value_run"} + additional_job_facets = {"custom_job": "value_job"} + + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter = fake_adapter + + with mock.patch( + 
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + emit_openlineage_events_for_databricks_queries( + task_instance=mock_ti, + hook=hook, + # query_for_extra_metadata=False, # False by default + additional_run_facets=additional_run_facets, + additional_job_facets=additional_job_facets, + ) + + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + assert fake_adapter.emit.call_count == 2 # Expect two events per query. + + expected_common_job_facets = { + "jobType": job_type_job.JobTypeJobFacet( + jobType="QUERY", + processingType="BATCH", + integration="DATABRICKS", + ), + "custom_job": "value_job", + } + expected_common_run_facets = { + "parent": parent_run.ParentRunFacet( + run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"), + job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"), + root=parent_run.Root( + run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"), + job=parent_run.RootJob(namespace=namespace(), name="dag_id"), + ), + ), + "custom_run": "value_run", + } + + expected_calls = [ + mock.call( # Query1: START event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.START, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="databricks://some_host" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + mock.call( # Query1: COMPLETE event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.COMPLETE, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="databricks://some_host" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + ] + + assert fake_adapter.emit.call_args_list == expected_calls + + +@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("openlineage.client.uuid.generate_new_uuid") +@mock.patch("airflow.utils.timezone.utcnow") +def test_emit_openlineage_events_for_databricks_queries_ith_query_ids_and_hook_query_ids( + mock_now, mock_generate_uuid, mock_version +): + fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" + mock_generate_uuid.return_value = fake_uuid + + default_event_time = timezone.datetime(2025, 1, 5, 0, 0, 0) + mock_now.return_value = default_event_time + + hook = DatabricksSqlHook() + hook.query_ids = ["query2", "query3"] + query_ids = ["query1"] + original_query_ids = copy.deepcopy(query_ids) + logical_date = timezone.datetime(2025, 1, 1) + mock_ti = mock.MagicMock( + dag_id="dag_id", + task_id="task_id", + map_index=1, + try_number=1, + logical_date=logical_date, + state=TaskInstanceState.SUCCESS, # This will be query default state if no metadata found + dag_run=mock.MagicMock(logical_date=logical_date, clear_number=0), + ) + mock_ti.get_template_context.return_value = { + "dag_run": mock.MagicMock(logical_date=logical_date, clear_number=0) + } + + additional_run_facets = {"custom_run": "value_run"} + additional_job_facets = {"custom_job": "value_job"} + + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter = fake_adapter + + with mock.patch( + 
"airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + emit_openlineage_events_for_databricks_queries( + query_ids=query_ids, + query_source_namespace="databricks_ns", + task_instance=mock_ti, + hook=hook, + # query_for_extra_metadata=False, # False by default + additional_run_facets=additional_run_facets, + additional_job_facets=additional_job_facets, + ) + + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + assert fake_adapter.emit.call_count == 2 # Expect two events per query. + + expected_common_job_facets = { + "jobType": job_type_job.JobTypeJobFacet( + jobType="QUERY", + processingType="BATCH", + integration="DATABRICKS", + ), + "custom_job": "value_job", + } + expected_common_run_facets = { + "parent": parent_run.ParentRunFacet( + run=parent_run.Run(runId="01941f29-7c00-7087-8906-40e512c257bd"), + job=parent_run.Job(namespace=namespace(), name="dag_id.task_id"), + root=parent_run.Root( + run=parent_run.RootRun(runId="01941f29-7c00-743e-b109-28b18d0a19c5"), + job=parent_run.RootJob(namespace=namespace(), name="dag_id"), + ), + ), + "custom_run": "value_run", + } + + expected_calls = [ + mock.call( # Query1: START event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.START, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="databricks_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + mock.call( # Query1: COMPLETE event (no metadata) + RunEvent( + eventTime=default_event_time.isoformat(), + eventType=RunState.COMPLETE, + run=Run( + runId=fake_uuid, + facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query1", source="databricks_ns" + ), + **expected_common_run_facets, + }, + ), + job=Job( + namespace=namespace(), + name="dag_id.task_id.query.1", + facets=expected_common_job_facets, + ), + ) + ), + ] + + assert fake_adapter.emit.call_args_list == expected_calls + + +@mock.patch("importlib.metadata.version", return_value="2.3.0") +def test_emit_openlineage_events_for_databricks_queries_missing_query_ids_and_hook(mock_version): + query_ids = [] + original_query_ids = copy.deepcopy(query_ids) + + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter = fake_adapter + + with mock.patch( + "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + with pytest.raises(ValueError, match="If 'hook' is not provided, 'query_ids' must be set."): + emit_openlineage_events_for_databricks_queries( + query_ids=query_ids, + query_source_namespace="databricks_ns", + task_instance=None, + ) + + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. 
+ fake_adapter.emit.assert_not_called() # No events should be emitted + + +@mock.patch("importlib.metadata.version", return_value="2.3.0") +def test_emit_openlineage_events_for_databricks_queries_missing_query_namespace_and_hook(mock_version): + query_ids = ["1", "2"] + original_query_ids = copy.deepcopy(query_ids) + + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter = fake_adapter + + with mock.patch( + "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + with pytest.raises( + ValueError, match="If 'hook' is not provided, 'query_source_namespace' must be set." + ): + emit_openlineage_events_for_databricks_queries( + query_ids=query_ids, + task_instance=None, + ) + + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. + fake_adapter.emit.assert_not_called() # No events should be emitted + + +@mock.patch("importlib.metadata.version", return_value="2.3.0") +def test_emit_openlineage_events_for_databricks_queries_missing_hook_and_query_for_extra_metadata_true( + mock_version, +): + query_ids = ["1", "2"] + original_query_ids = copy.deepcopy(query_ids) + + fake_adapter = mock.MagicMock() + fake_adapter.emit = mock.MagicMock() + fake_listener = mock.MagicMock() + fake_listener.adapter = fake_adapter + + with mock.patch( + "airflow.providers.openlineage.plugins.listener.get_openlineage_listener", + return_value=fake_listener, + ): + with pytest.raises( + ValueError, match="If 'hook' is not provided, 'query_for_extra_metadata' must be False." + ): + emit_openlineage_events_for_databricks_queries( + query_ids=query_ids, + query_source_namespace="databricks_ns", + task_instance=None, + query_for_extra_metadata=True, + ) + assert query_ids == original_query_ids # Verify that the input query_ids list is unchanged. fake_adapter.emit.assert_not_called() # No events should be emitted -# emit_openlineage_events_for_databricks_queries requires OL provider 2.3.0 @mock.patch("importlib.metadata.version", return_value="1.99.0") def test_emit_openlineage_events_with_old_openlineage_provider(mock_version): query_ids = ["q1", "q2"]
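Reviewer note: a minimal usage sketch (not part of the patch) of the reworked `emit_openlineage_events_for_databricks_queries` signature introduced above. The connection id, workspace host, and statement id below are illustrative placeholders; the two call patterns and the ValueError behaviour mirror the validation added in this diff.

# Hedged usage sketch, assuming a configured "databricks_default" connection.
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.providers.databricks.utils.openlineage import (
    emit_openlineage_events_for_databricks_queries,
)


def emit_lineage(task_instance) -> None:
    hook = DatabricksSqlHook(databricks_conn_id="databricks_default")

    # Hook-driven call: query_ids falls back to hook.query_ids and the namespace
    # is derived from the hook, so both may be omitted; extra metadata is then
    # fetched from the Databricks query history API.
    emit_openlineage_events_for_databricks_queries(
        task_instance=task_instance,
        hook=hook,
        query_for_extra_metadata=True,
    )

    # Hook-less call: query_ids and query_source_namespace become mandatory and
    # query_for_extra_metadata must stay False, otherwise ValueError is raised.
    emit_openlineage_events_for_databricks_queries(
        task_instance=task_instance,
        query_ids=["statement-id-placeholder"],
        query_source_namespace="databricks://my-workspace.cloud.databricks.com",
    )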
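A second self-contained sketch of the per-query fallback this diff implements: when the query history API returned no metadata for a query, event times default to "now" and success is judged from the adjusted task state (Databricks reports a successfully completed statement as FINISHED). `resolve_event_fields` is an illustrative helper name, not part of the provider code.

import datetime


def resolve_event_fields(
    query_metadata: dict, default_state: str
) -> tuple[datetime.datetime, datetime.datetime, bool]:
    # Missing start/end times fall back to "now", so event times only reflect
    # real execution when Databricks returned metadata for the query.
    default_event_time = datetime.datetime.now(tz=datetime.timezone.utc)
    start_time = query_metadata.get("start_time") or default_event_time
    end_time = query_metadata.get("end_time") or default_event_time
    # Only the "finished" status means the statement completed without failures;
    # a running or successful task state was already mapped to "finished" upstream.
    is_successful = (query_metadata.get("status") or default_state).lower() == "finished"
    return start_time, end_time, is_successful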