Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import requests

from airflow.providers.common.compat.openlineage.check import require_openlineage_version
from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import timezone

if TYPE_CHECKING:
Expand All @@ -37,60 +36,6 @@
log = logging.getLogger(__name__)


def _get_logical_date(task_instance):
    """
    Return the date that identifies the run this task instance belongs to.

    On Airflow 3.0+ the DagRun is read from the task instance's template context
    and its ``logical_date`` is returned, falling back to ``run_after`` when the
    logical date is falsy. On older versions the task instance's own
    ``logical_date`` attribute is used when present, otherwise ``execution_date``.
    """
    # todo: remove when min airflow version >= 3.0
    if not AIRFLOW_V_3_0_PLUS:
        # Keep the hasattr check: execution_date must only be read when
        # logical_date is absent from the task instance.
        if hasattr(task_instance, "logical_date"):
            return task_instance.logical_date
        return task_instance.execution_date

    dag_run = task_instance.get_template_context()["dag_run"]
    return dag_run.logical_date or dag_run.run_after


def _get_dag_run_clear_number(task_instance):
    """
    Return the ``clear_number`` of the DagRun owning the given task instance.

    The DagRun comes from the template context on Airflow 3.0+, and from the
    ``dag_run`` attribute of the task instance on older versions.
    """
    # todo: remove when min airflow version >= 3.0
    dag_run = (
        task_instance.get_template_context()["dag_run"]
        if AIRFLOW_V_3_0_PLUS
        else task_instance.dag_run
    )
    return dag_run.clear_number


# todo: move this run_id logic into OpenLineage's listener to avoid differences
def _get_ol_run_id(task_instance) -> str:
    """
    Build the OpenLineage run_id for the given TaskInstance.

    The run_id must be produced exactly the way OpenLineage's listener produces it;
    otherwise the id generated here will not line up with the Airflow task and the
    emitted events cannot be connected to it.
    """
    from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter

    # Same inputs the listener uses for the current task instance, gathered in
    # the same order as they are passed to the adapter.
    run_id_kwargs = {
        "dag_id": task_instance.dag_id,
        "task_id": task_instance.task_id,
        "logical_date": _get_logical_date(task_instance),
        "try_number": task_instance.try_number,
        "map_index": task_instance.map_index,
    }
    return OpenLineageAdapter.build_task_instance_run_id(**run_id_kwargs)


# todo: move this run_id logic into OpenLineage's listener to avoid differences
def _get_ol_dag_run_id(task_instance) -> str:
    """
    Build the OpenLineage run_id of the DAG run that owns the given TaskInstance.

    The id is derived from the task instance's ``dag_id``, the logical date
    resolved by ``_get_logical_date`` and the clear number resolved by
    ``_get_dag_run_clear_number``.
    """
    from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter

    dag_id = task_instance.dag_id
    logical_date = _get_logical_date(task_instance)
    clear_number = _get_dag_run_clear_number(task_instance)
    return OpenLineageAdapter.build_dag_run_id(
        dag_id=dag_id,
        logical_date=logical_date,
        clear_number=clear_number,
    )


def _get_parent_run_facet(task_instance):
"""
Retrieve the ParentRunFacet associated with a specific Airflow task instance.
Expand All @@ -101,22 +46,39 @@ def _get_parent_run_facet(task_instance):
"""
from openlineage.client.facet_v2 import parent_run

from airflow.providers.openlineage.conf import namespace
from airflow.providers.openlineage.plugins.macros import (
lineage_job_name,
lineage_job_namespace,
lineage_root_job_name,
lineage_root_run_id,
lineage_run_id,
)

parent_run_id = lineage_run_id(task_instance)
parent_job_name = lineage_job_name(task_instance)
parent_job_namespace = lineage_job_namespace()

root_parent_run_id = lineage_root_run_id(task_instance)
rot_parent_job_name = lineage_root_job_name(task_instance)

try: # Added in OL provider 2.9.0, try to use it if possible
from airflow.providers.openlineage.plugins.macros import lineage_root_job_namespace

parent_run_id = _get_ol_run_id(task_instance)
root_parent_run_id = _get_ol_dag_run_id(task_instance)
root_parent_job_namespace = lineage_root_job_namespace(task_instance)
except ImportError:
root_parent_job_namespace = lineage_job_namespace()

return parent_run.ParentRunFacet(
run=parent_run.Run(runId=parent_run_id),
job=parent_run.Job(
namespace=namespace(),
name=f"{task_instance.dag_id}.{task_instance.task_id}",
namespace=parent_job_namespace,
name=parent_job_name,
),
root=parent_run.Root(
run=parent_run.RootRun(runId=root_parent_run_id),
job=parent_run.RootJob(
name=task_instance.dag_id,
namespace=namespace(),
name=rot_parent_job_name,
namespace=root_parent_job_namespace,
),
),
)
Expand Down Expand Up @@ -209,7 +171,7 @@ def _create_ol_event_pair(
return start, end


@require_openlineage_version(provider_min_version="2.3.0")
@require_openlineage_version(provider_min_version="2.5.0")
def emit_openlineage_events_for_databricks_queries(
task_instance,
hook: DatabricksSqlHook | DatabricksHook | None = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider
hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme")

expected_err = (
"OpenLineage provider version `1.99.0` is lower than required `2.3.0`, "
"OpenLineage provider version `1.99.0` is lower than required `2.5.0`, "
"skipping function `emit_openlineage_events_for_databricks_queries` execution"
)
with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
from airflow.providers.databricks.utils.openlineage import (
_create_ol_event_pair,
_get_ol_run_id,
_get_parent_run_facet,
_get_queries_details_from_databricks,
_process_data_from_api,
Expand All @@ -46,55 +45,28 @@
from airflow.utils.state import TaskInstanceState


def test_get_ol_run_id_ti_success():
    """Check the run id produced for a task instance in the SUCCESS state."""
    run_date = timezone.datetime(2025, 1, 1)
    ti = mock.MagicMock(
        dag_id="dag_id",
        task_id="task_id",
        map_index=1,
        try_number=1,
        logical_date=run_date,
        state=TaskInstanceState.SUCCESS,
    )
    # The template context exposes a DagRun carrying the same logical date.
    dag_run = mock.MagicMock(logical_date=run_date)
    ti.get_template_context.return_value = {"dag_run": dag_run}

    assert _get_ol_run_id(ti) == "01941f29-7c00-7087-8906-40e512c257bd"


def test_get_ol_run_id_ti_failed():
    """Check the run id produced for a task instance in the FAILED state."""
    run_date = timezone.datetime(2025, 1, 1)
    ti = mock.MagicMock(
        dag_id="dag_id",
        task_id="task_id",
        map_index=1,
        try_number=1,
        logical_date=run_date,
        state=TaskInstanceState.FAILED,
    )
    # The template context exposes a DagRun carrying the same logical date.
    dag_run = mock.MagicMock(logical_date=run_date)
    ti.get_template_context.return_value = {"dag_run": dag_run}

    assert _get_ol_run_id(ti) == "01941f29-7c00-7087-8906-40e512c257bd"


def test_get_parent_run_facet():
    """Check the run, job and root fields of the parent run facet built for a task instance."""
    logical_date = timezone.datetime(2025, 1, 1)
    dr = mock.MagicMock(logical_date=logical_date, clear_number=0)
    mock_ti = mock.MagicMock(
        dag_id="dag_id",
        task_id="task_id",
        map_index=1,
        try_number=1,
        logical_date=logical_date,
        state=TaskInstanceState.SUCCESS,
        dag_run=dr,
    )
    # Expose the very same DagRun mock through the template context, so the
    # context lookup and the ``dag_run`` attribute agree on logical_date and
    # clear_number. A separate throwaway mock here would be overwritten anyway.
    mock_ti.get_template_context.return_value = {"dag_run": dr}

    result = _get_parent_run_facet(mock_ti)

    assert result.run.runId == "01941f29-7c00-7087-8906-40e512c257bd"
    assert result.job.namespace == namespace()
    assert result.job.name == "dag_id.task_id"
    assert result.root.run.runId == "01941f29-7c00-743e-b109-28b18d0a19c5"
    assert result.root.job.namespace == namespace()
    assert result.root.job.name == "dag_id"


def test_run_api_call_success():
Expand Down Expand Up @@ -283,7 +255,7 @@ def test_create_ol_event_pair_success(mock_generate_uuid, is_successful):
assert start_event.job == end_event.job


@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_databricks_queries(mock_generate_uuid, mock_version, time_machine):
fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f"
Expand Down Expand Up @@ -520,7 +492,7 @@ def test_emit_openlineage_events_for_databricks_queries(mock_generate_uuid, mock
assert fake_adapter.emit.call_args_list == expected_calls


@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_databricks_queries_without_metadata(
mock_generate_uuid, mock_version, time_machine
Expand Down Expand Up @@ -638,7 +610,7 @@ def test_emit_openlineage_events_for_databricks_queries_without_metadata(
assert fake_adapter.emit.call_args_list == expected_calls


@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids(
mock_generate_uuid, mock_version, time_machine
Expand Down Expand Up @@ -760,7 +732,7 @@ def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i
@mock.patch(
"airflow.providers.openlineage.sqlparser.SQLParser.create_namespace", return_value="databricks_ns"
)
@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace(
mock_generate_uuid, mock_version, mock_parser, time_machine
Expand Down Expand Up @@ -878,7 +850,7 @@ def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i
assert fake_adapter.emit.call_args_list == expected_calls


@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace_raw_ns(
mock_generate_uuid, mock_version, time_machine
Expand Down Expand Up @@ -997,7 +969,7 @@ def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i
assert fake_adapter.emit.call_args_list == expected_calls


@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("importlib.metadata.version", return_value="3.0.0")
@mock.patch("openlineage.client.uuid.generate_new_uuid")
def test_emit_openlineage_events_for_databricks_queries_ith_query_ids_and_hook_query_ids(
mock_generate_uuid, mock_version, time_machine
Expand Down Expand Up @@ -1117,7 +1089,7 @@ def test_emit_openlineage_events_for_databricks_queries_ith_query_ids_and_hook_q
assert fake_adapter.emit.call_args_list == expected_calls


@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("importlib.metadata.version", return_value="3.0.0")
def test_emit_openlineage_events_for_databricks_queries_missing_query_ids_and_hook(mock_version):
query_ids = []
original_query_ids = copy.deepcopy(query_ids)
Expand All @@ -1142,7 +1114,7 @@ def test_emit_openlineage_events_for_databricks_queries_missing_query_ids_and_ho
fake_adapter.emit.assert_not_called() # No events should be emitted


@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("importlib.metadata.version", return_value="3.0.0")
def test_emit_openlineage_events_for_databricks_queries_missing_query_namespace_and_hook(mock_version):
query_ids = ["1", "2"]
original_query_ids = copy.deepcopy(query_ids)
Expand All @@ -1168,7 +1140,7 @@ def test_emit_openlineage_events_for_databricks_queries_missing_query_namespace_
fake_adapter.emit.assert_not_called() # No events should be emitted


@mock.patch("importlib.metadata.version", return_value="2.3.0")
@mock.patch("importlib.metadata.version", return_value="3.0.0")
def test_emit_openlineage_events_for_databricks_queries_missing_hook_and_query_for_extra_metadata_true(
mock_version,
):
Expand Down Expand Up @@ -1213,7 +1185,7 @@ def test_emit_openlineage_events_with_old_openlineage_provider(mock_version):
return_value=fake_listener,
):
expected_err = (
"OpenLineage provider version `1.99.0` is lower than required `2.3.0`, "
"OpenLineage provider version `1.99.0` is lower than required `2.5.0`, "
"skipping function `emit_openlineage_events_for_databricks_queries` execution"
)

Expand Down
Loading