@@ -344,10 +344,9 @@ def get_openlineage_default_schema(self) -> str | None:

def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
"""
Generate OpenLineage metadata for a Databricks task instance based on executed query IDs.
Emit separate OpenLineage events for each Databricks query, based on executed query IDs.

If a single query ID is present, attach an `ExternalQueryRunFacet` to the lineage metadata.
If multiple query IDs are present, emits separate OpenLineage events for each query instead.
If a single query ID is present, also add an `ExternalQueryRunFacet` to the returned lineage metadata.

Note that `get_openlineage_database_specific_lineage` is usually called after the task's execution,
so if multiple query IDs are present, both START and COMPLETE events for each query will be emitted
@@ -368,13 +367,22 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
from airflow.providers.openlineage.sqlparser import SQLParser

if not self.query_ids:
self.log.debug("openlineage: no databricks query ids found.")
self.log.info("OpenLineage could not find databricks query ids.")
return None

self.log.debug("openlineage: getting connection to get database info")
connection = self.get_connection(self.get_conn_id())
namespace = SQLParser.create_namespace(self.get_openlineage_database_info(connection))

self.log.info("Separate OpenLineage events will be emitted for each Databricks query_id.")
emit_openlineage_events_for_databricks_queries(
task_instance=task_instance,
hook=self,
query_ids=self.query_ids,
query_for_extra_metadata=True,
query_source_namespace=namespace,
)

if len(self.query_ids) == 1:
self.log.debug("Attaching ExternalQueryRunFacet with single query_id to OpenLineage event.")
return OperatorLineage(
@@ -385,12 +393,4 @@ def get_openlineage_database_specific_lineage(self, task_instance) -> OperatorLineage | None:
}
)

self.log.info("Multiple query_ids found. Separate OpenLineage event will be emitted for each query.")
emit_openlineage_events_for_databricks_queries(
query_ids=self.query_ids,
query_source_namespace=namespace,
task_instance=task_instance,
hook=self,
)

return None
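
For context, a minimal sketch of how the updated method might be exercised, mirroring the tests further down in this diff; the hook setup and the mocked task instance are illustrative, not part of the change:

from unittest import mock

from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

hook = DatabricksSqlHook()
hook.query_ids = ["query1"]  # normally populated while the hook executes queries
ti = mock.MagicMock()  # stand-in for a real Airflow task instance

# Emits a START/COMPLETE event pair for every query id; with a single id it
# also returns an OperatorLineage carrying an ExternalQueryRunFacet,
# otherwise it returns None.
lineage = hook.get_openlineage_database_specific_lineage(ti)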
@@ -31,6 +31,7 @@
from openlineage.client.event_v2 import RunEvent
from openlineage.client.facet_v2 import JobFacet

from airflow.providers.databricks.hooks.databricks import DatabricksHook
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook

log = logging.getLogger(__name__)
@@ -121,69 +122,61 @@ def _get_parent_run_facet(task_instance):
)


def _run_api_call(hook: DatabricksSqlHook, query_ids: list[str]) -> list[dict]:
def _run_api_call(hook: DatabricksSqlHook | DatabricksHook, query_ids: list[str]) -> list[dict]:
"""Retrieve execution details for specific queries from Databricks's query history API."""
if not hook._token:
# This has logic for token initialization
hook.get_conn()

# https://docs.databricks.com/api/azure/workspace/queryhistory/list
try:
token = hook._get_token(raise_error=True)
# https://docs.databricks.com/api/azure/workspace/queryhistory/list
response = requests.get(
url=f"https://{hook.host}/api/2.0/sql/history/queries",
headers={"Authorization": f"Bearer {hook._token}"},
headers={"Authorization": f"Bearer {token}"},
data=json.dumps({"filter_by": {"statement_ids": query_ids}}),
timeout=2,
)
response.raise_for_status()
except Exception as e:
log.warning(
"OpenLineage could not retrieve Databricks queries details. Error received: `%s`.",
e,
)
return []

if response.status_code != 200:
log.warning(
"OpenLineage could not retrieve Databricks queries details. API error received: `%s`: `%s`",
response.status_code,
response.text,
)
return []

return response.json()["res"]


def _process_data_from_api(data: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Convert timestamp fields to UTC datetime objects."""
for row in data:
for key in ("query_start_time_ms", "query_end_time_ms"):
row[key] = datetime.datetime.fromtimestamp(row[key] / 1000, tz=datetime.timezone.utc)

return data
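
A quick, hypothetical example of the helper above; the payload shape is an assumption based only on the fields this module reads from the query history API:

sample = [
    {
        "query_id": "q1",
        "query_start_time_ms": 1_700_000_000_000,  # milliseconds since epoch
        "query_end_time_ms": 1_700_000_005_000,
    }
]
processed = _process_data_from_api(sample)
# Both *_ms fields are replaced in place with timezone-aware datetimes, e.g.
# datetime.datetime(2023, 11, 14, 22, 13, 20, tzinfo=datetime.timezone.utc).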


def _get_queries_details_from_databricks(
hook: DatabricksSqlHook, query_ids: list[str]
hook: DatabricksSqlHook | DatabricksHook, query_ids: list[str]
) -> dict[str, dict[str, Any]]:
if not query_ids:
return {}

queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids)

query_details = {}
for query_info in queries_info_from_api:
if not query_info.get("query_id"):
log.debug("Databricks query ID not found in API response.")
continue

q_start_time = None
q_end_time = None
if query_info.get("query_start_time_ms") and query_info.get("query_end_time_ms"):
q_start_time = datetime.datetime.fromtimestamp(
query_info["query_start_time_ms"] / 1000, tz=datetime.timezone.utc
)
q_end_time = datetime.datetime.fromtimestamp(
query_info["query_end_time_ms"] / 1000, tz=datetime.timezone.utc
)

query_details[query_info["query_id"]] = {
"status": query_info.get("status"),
"start_time": q_start_time,
"end_time": q_end_time,
"query_text": query_info.get("query_text"),
"error_message": query_info.get("error_message"),
try:
queries_info_from_api = _run_api_call(hook=hook, query_ids=query_ids)
queries_info_from_api = _process_data_from_api(queries_info_from_api)

query_details = {
query_info["query_id"]: {
"status": query_info.get("status"),
"start_time": query_info.get("query_start_time_ms"),
"end_time": query_info.get("query_end_time_ms"),
"query_text": query_info.get("query_text"),
"error_message": query_info.get("error_message"),
}
for query_info in queries_info_from_api
if query_info["query_id"]
}
except Exception as e:
log.warning("OpenLineage could not retrieve extra metadata from Databricks. Error encountered: %s", e)

return query_details

@@ -221,17 +214,18 @@ def _create_ol_event_pair(

@require_openlineage_version(provider_min_version="2.3.0")
def emit_openlineage_events_for_databricks_queries(
query_ids: list[str],
query_source_namespace: str,
task_instance,
hook: DatabricksSqlHook | None = None,
hook: DatabricksSqlHook | DatabricksHook | None = None,
query_ids: list[str] | None = None,
query_source_namespace: str | None = None,
query_for_extra_metadata: bool = False,
additional_run_facets: dict | None = None,
additional_job_facets: dict | None = None,
) -> None:
"""
Emit OpenLineage events for executed Databricks queries.

Metadata retrieval from Databricks is attempted only if a `DatabricksSqlHook` is provided.
Metadata retrieval from Databricks is attempted only if `query_for_extra_metadata` is True and a `hook` is provided.
If metadata is available, execution details such as start time, end time, execution status,
error messages, and SQL text are included in the events. If no metadata is found, the function
defaults to using the Airflow task instance's state and the current timestamp.
@@ -241,10 +235,16 @@ def emit_openlineage_events_for_databricks_queries(
will correspond to actual query execution times.

Args:
query_ids: A list of Databricks query IDs to emit events for.
query_source_namespace: The namespace to be included in ExternalQueryRunFacet.
task_instance: The Airflow task instance that ran these queries.
hook: A hook instance used to retrieve query metadata if available.
hook: A supported Databricks hook instance used to retrieve query metadata if available.
If omitted, `query_ids` and `query_source_namespace` must be provided explicitly and
`query_for_extra_metadata` must be `False`.
query_ids: A list of Databricks query IDs to emit events for; can be None only if `hook` is provided
and `hook.query_ids` are present (DatabricksHook does not store query_ids).
query_source_namespace: The namespace to be included in the ExternalQueryRunFacet;
can be `None` only if `hook` is provided.
query_for_extra_metadata: Whether to query Databricks for additional metadata about queries.
Must be `False` if `hook` is not provided.
additional_run_facets: Additional run facets to include in OpenLineage events.
additional_job_facets: Additional job facets to include in OpenLineage events.
"""
@@ -259,25 +259,52 @@ def emit_openlineage_events_for_databricks_queries(
from airflow.providers.openlineage.conf import namespace
from airflow.providers.openlineage.plugins.listener import get_openlineage_listener

if not query_ids:
log.debug("No Databricks query IDs provided; skipping OpenLineage event emission.")
return

query_ids = [q for q in query_ids] # Make a copy to make sure it does not change
log.info("OpenLineage will emit events for Databricks queries.")

if hook:
if not query_ids:
log.debug("No Databricks query IDs provided; Checking `hook.query_ids` property.")
query_ids = getattr(hook, "query_ids", [])
if not query_ids:
raise ValueError("No Databricks query IDs provided and `hook.query_ids` are not present.")

if not query_source_namespace:
log.debug("No Databricks query namespace provided; Creating one from scratch.")

if hasattr(hook, "get_openlineage_database_info") and hasattr(hook, "get_conn_id"):
from airflow.providers.openlineage.sqlparser import SQLParser

query_source_namespace = SQLParser.create_namespace(
hook.get_openlineage_database_info(hook.get_connection(hook.get_conn_id()))
)
else:
query_source_namespace = f"databricks://{hook.host}" if hook.host else "databricks"
else:
if not query_ids:
raise ValueError("If 'hook' is not provided, 'query_ids' must be set.")
if not query_source_namespace:
raise ValueError("If 'hook' is not provided, 'query_source_namespace' must be set.")
if query_for_extra_metadata:
raise ValueError("If 'hook' is not provided, 'query_for_extra_metadata' must be False.")

query_ids = [q for q in query_ids] # Make a copy to make sure we do not change hook's attribute

if query_for_extra_metadata and hook:
log.debug("Retrieving metadata for %s queries from Databricks.", len(query_ids))
databricks_metadata = _get_queries_details_from_databricks(hook, query_ids)
else:
log.debug("DatabricksSqlHook not provided. No extra metadata fill be fetched from Databricks.")
log.debug("`query_for_extra_metadata` is False. No extra metadata fill be fetched from Databricks.")
databricks_metadata = {}

# If real metadata is unavailable, we send events with eventTime=now
default_event_time = timezone.utcnow()
# If no query metadata is provided, we use task_instance's state when checking for success
# ti.state has no `value` attr (AF2) when the task is still running; in AF3 we get 'running'. In that case
# we assume it's a user call and the query succeeded, so we replace it with success.
# Adjust state for DBX logic, where "finished" means "success"
default_state = task_instance.state.value if hasattr(task_instance, "state") else ""
default_state = "finished" if default_state == "success" else default_state
default_state = (
getattr(task_instance.state, "value", "running") if hasattr(task_instance, "state") else ""
)
default_state = "finished" if default_state in ("running", "success") else default_state

log.debug("Generating OpenLineage facets")
common_run_facets = {"parent": _get_parent_run_facet(task_instance)}
@@ -318,10 +345,10 @@ def emit_openlineage_events_for_databricks_queries(
event_batch = _create_ol_event_pair(
job_namespace=namespace(),
job_name=f"{task_instance.dag_id}.{task_instance.task_id}.query.{counter}",
start_time=query_metadata.get("start_time", default_event_time), # type: ignore[arg-type]
end_time=query_metadata.get("end_time", default_event_time), # type: ignore[arg-type]
start_time=query_metadata.get("start_time") or default_event_time, # type: ignore[arg-type]
end_time=query_metadata.get("end_time") or default_event_time, # type: ignore[arg-type]
# Only finished status means it completed without failures
is_successful=query_metadata.get("status", default_state).lower() == "finished",
is_successful=(query_metadata.get("status") or default_state).lower() == "finished",
run_facets={**query_specific_run_facets, **common_run_facets, **additional_run_facets},
job_facets={**query_specific_job_facets, **common_job_facets, **additional_job_facets},
)
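
To make the new argument contract concrete, a short sketch of the two supported call modes; the task instance, hook, and namespace values are illustrative:

# Mode 1: hook provided - query_ids and query_source_namespace may be omitted
# (they are derived from the hook) and extra metadata can be requested.
emit_openlineage_events_for_databricks_queries(
    task_instance=ti,
    hook=databricks_sql_hook,  # DatabricksSqlHook or DatabricksHook
    query_for_extra_metadata=True,
)

# Mode 2: no hook - both identifiers must be passed explicitly and
# query_for_extra_metadata must remain False (its default).
emit_openlineage_events_for_databricks_queries(
    task_instance=ti,
    query_ids=["query1", "query2"],
    query_source_namespace="databricks://my-workspace.cloud.databricks.com",
)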
@@ -457,7 +457,8 @@ def test_get_openlineage_database_specific_lineage_with_no_query_id():
assert result is None


def test_get_openlineage_database_specific_lineage_with_single_query_id():
@mock.patch("airflow.providers.databricks.utils.openlineage.emit_openlineage_events_for_databricks_queries")
def test_get_openlineage_database_specific_lineage_with_single_query_id(mock_emit):
from airflow.providers.common.compat.openlineage.facet import ExternalQueryRunFacet
from airflow.providers.openlineage.extractors import OperatorLineage

@@ -466,7 +467,18 @@ def test_get_openlineage_database_specific_lineage_with_single_query_id():
hook.get_connection = mock.MagicMock()
hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme")

result = hook.get_openlineage_database_specific_lineage(None)
ti = mock.MagicMock()

result = hook.get_openlineage_database_specific_lineage(ti)
mock_emit.assert_called_once_with(
**{
"hook": hook,
"query_ids": ["query1"],
"query_source_namespace": "scheme://auth",
"task_instance": ti,
"query_for_extra_metadata": True,
}
)
assert result == OperatorLineage(
run_facets={"externalQuery": ExternalQueryRunFacet(externalQueryId="query1", source="scheme://auth")}
)
@@ -488,6 +500,7 @@ def test_get_openlineage_database_specific_lineage_with_multiple_query_ids(mock_emit):
"query_ids": ["query1", "query2"],
"query_source_namespace": "scheme://auth",
"task_instance": ti,
"query_for_extra_metadata": True,
}
)
assert result is None