From 55812e9849a6ce8dee6a5edaacb7b46ce47e1c84 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Mon, 8 Dec 2025 14:13:45 +0100 Subject: [PATCH] chore: use OL macros instead of building OL ids from scratch --- .../providers/databricks/utils/openlineage.py | 88 ++++++------------- .../databricks/hooks/test_databricks_sql.py | 2 +- .../unit/databricks/utils/test_openlineage.py | 60 ++++--------- .../providers/dbt/cloud/utils/openlineage.py | 79 ++++++++++------- .../unit/dbt/cloud/utils/test_openlineage.py | 41 +++++++-- .../providers/snowflake/utils/openlineage.py | 88 ++++++------------- .../unit/snowflake/hooks/test_snowflake.py | 2 +- .../unit/snowflake/utils/test_openlineage.py | 59 ++++--------- 8 files changed, 166 insertions(+), 253 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py index 971e59f291549..56f4400df61df 100644 --- a/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py +++ b/providers/databricks/src/airflow/providers/databricks/utils/openlineage.py @@ -24,7 +24,6 @@ import requests from airflow.providers.common.compat.openlineage.check import require_openlineage_version -from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import timezone if TYPE_CHECKING: @@ -37,60 +36,6 @@ log = logging.getLogger(__name__) -def _get_logical_date(task_instance): - # todo: remove when min airflow version >= 3.0 - if AIRFLOW_V_3_0_PLUS: - dagrun = task_instance.get_template_context()["dag_run"] - return dagrun.logical_date or dagrun.run_after - - if hasattr(task_instance, "logical_date"): - date = task_instance.logical_date - else: - date = task_instance.execution_date - - return date - - -def _get_dag_run_clear_number(task_instance): - # todo: remove when min airflow version >= 3.0 - if AIRFLOW_V_3_0_PLUS: - dagrun = task_instance.get_template_context()["dag_run"] - return dagrun.clear_number - return task_instance.dag_run.clear_number - - -# todo: move this run_id logic into OpenLineage's listener to avoid differences -def _get_ol_run_id(task_instance) -> str: - """ - Get OpenLineage run_id from TaskInstance. - - It's crucial that the task_instance's run_id creation logic matches OpenLineage's listener implementation. - Only then can we ensure that the generated run_id aligns with the Airflow task, - enabling a proper connection between events. - """ - from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter - - # Generate same OL run id as is generated for current task instance - return OpenLineageAdapter.build_task_instance_run_id( - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - logical_date=_get_logical_date(task_instance), - try_number=task_instance.try_number, - map_index=task_instance.map_index, - ) - - -# todo: move this run_id logic into OpenLineage's listener to avoid differences -def _get_ol_dag_run_id(task_instance) -> str: - from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter - - return OpenLineageAdapter.build_dag_run_id( - dag_id=task_instance.dag_id, - logical_date=_get_logical_date(task_instance), - clear_number=_get_dag_run_clear_number(task_instance), - ) - - def _get_parent_run_facet(task_instance): """ Retrieve the ParentRunFacet associated with a specific Airflow task instance. 
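For context on the hunks that follow: the deleted helpers above are replaced by the ready-made macros from airflow.providers.openlineage.plugins.macros. A rough sketch of how those macros are used to collect the same identifiers is shown below; the wrapper function name build_parent_identifiers is illustrative only and is not part of the provider code.

# Sketch only: gathering the parent/root lineage identifiers via the OpenLineage
# provider macros instead of calling OpenLineageAdapter.build_* by hand.
from airflow.providers.openlineage.plugins.macros import (
    lineage_job_name,
    lineage_job_namespace,
    lineage_root_job_name,
    lineage_root_run_id,
    lineage_run_id,
)


def build_parent_identifiers(task_instance) -> dict:  # illustrative helper, not in the provider
    ids = {
        "run_id": lineage_run_id(task_instance),
        "job_name": lineage_job_name(task_instance),
        "job_namespace": lineage_job_namespace(),
        "root_run_id": lineage_root_run_id(task_instance),
        "root_job_name": lineage_root_job_name(task_instance),
    }
    try:  # lineage_root_job_namespace only exists in OL provider 2.9.0+
        from airflow.providers.openlineage.plugins.macros import lineage_root_job_namespace

        ids["root_job_namespace"] = lineage_root_job_namespace(task_instance)
    except ImportError:  # fall back to the task-level namespace on older providers
        ids["root_job_namespace"] = lineage_job_namespace()
    return ids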
@@ -101,22 +46,39 @@ def _get_parent_run_facet(task_instance):
     """
     from openlineage.client.facet_v2 import parent_run
 
-    from airflow.providers.openlineage.conf import namespace
+    from airflow.providers.openlineage.plugins.macros import (
+        lineage_job_name,
+        lineage_job_namespace,
+        lineage_root_job_name,
+        lineage_root_run_id,
+        lineage_run_id,
+    )
+
+    parent_run_id = lineage_run_id(task_instance)
+    parent_job_name = lineage_job_name(task_instance)
+    parent_job_namespace = lineage_job_namespace()
+
+    root_parent_run_id = lineage_root_run_id(task_instance)
+    root_parent_job_name = lineage_root_job_name(task_instance)
+
+    try:  # Added in OL provider 2.9.0, try to use it if possible
+        from airflow.providers.openlineage.plugins.macros import lineage_root_job_namespace
 
-    parent_run_id = _get_ol_run_id(task_instance)
-    root_parent_run_id = _get_ol_dag_run_id(task_instance)
+        root_parent_job_namespace = lineage_root_job_namespace(task_instance)
+    except ImportError:
+        root_parent_job_namespace = lineage_job_namespace()
 
     return parent_run.ParentRunFacet(
         run=parent_run.Run(runId=parent_run_id),
         job=parent_run.Job(
-            namespace=namespace(),
-            name=f"{task_instance.dag_id}.{task_instance.task_id}",
+            namespace=parent_job_namespace,
+            name=parent_job_name,
         ),
         root=parent_run.Root(
            run=parent_run.RootRun(runId=root_parent_run_id),
            job=parent_run.RootJob(
-                name=task_instance.dag_id,
-                namespace=namespace(),
+                name=root_parent_job_name,
+                namespace=root_parent_job_namespace,
            ),
        ),
    )
@@ -209,7 +171,7 @@ def _create_ol_event_pair(
     return start, end
 
 
-@require_openlineage_version(provider_min_version="2.3.0")
+@require_openlineage_version(provider_min_version="2.5.0")
 def emit_openlineage_events_for_databricks_queries(
     task_instance,
     hook: DatabricksSqlHook | DatabricksHook | None = None,
diff --git a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index 6de20f4547788..96044f2b62031 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -569,7 +569,7 @@ def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider
     hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme")
 
     expected_err = (
-        "OpenLineage provider version `1.99.0` is lower than required `2.3.0`, "
+        "OpenLineage provider version `1.99.0` is lower than required `2.5.0`, "
         "skipping function `emit_openlineage_events_for_databricks_queries` execution"
     )
     with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err):
diff --git a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
index 20305a9af4e9a..2d127040d32aa 100644
--- a/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
+++ b/providers/databricks/tests/unit/databricks/utils/test_openlineage.py
@@ -34,7 +34,6 @@
 from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
 from airflow.providers.databricks.utils.openlineage import (
     _create_ol_event_pair,
-    _get_ol_run_id,
     _get_parent_run_facet,
     _get_queries_details_from_databricks,
     _process_data_from_api,
@@ -46,40 +45,9 @@
 from airflow.utils.state import TaskInstanceState
 
 
-def test_get_ol_run_id_ti_success():
-    logical_date = timezone.datetime(2025, 1, 1)
-    mock_ti = mock.MagicMock(
-        dag_id="dag_id",
-        task_id="task_id",
-        map_index=1,
-
try_number=1, - logical_date=logical_date, - state=TaskInstanceState.SUCCESS, - ) - mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} - - result = _get_ol_run_id(mock_ti) - assert result == "01941f29-7c00-7087-8906-40e512c257bd" - - -def test_get_ol_run_id_ti_failed(): - logical_date = timezone.datetime(2025, 1, 1) - mock_ti = mock.MagicMock( - dag_id="dag_id", - task_id="task_id", - map_index=1, - try_number=1, - logical_date=logical_date, - state=TaskInstanceState.FAILED, - ) - mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} - - result = _get_ol_run_id(mock_ti) - assert result == "01941f29-7c00-7087-8906-40e512c257bd" - - def test_get_parent_run_facet(): logical_date = timezone.datetime(2025, 1, 1) + dr = mock.MagicMock(logical_date=logical_date, clear_number=0) mock_ti = mock.MagicMock( dag_id="dag_id", task_id="task_id", @@ -87,14 +55,18 @@ def test_get_parent_run_facet(): try_number=1, logical_date=logical_date, state=TaskInstanceState.SUCCESS, + dag_run=dr, ) - mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} + mock_ti.get_template_context.return_value = {"dag_run": dr} result = _get_parent_run_facet(mock_ti) assert result.run.runId == "01941f29-7c00-7087-8906-40e512c257bd" assert result.job.namespace == namespace() assert result.job.name == "dag_id.task_id" + assert result.root.run.runId == "01941f29-7c00-743e-b109-28b18d0a19c5" + assert result.root.job.namespace == namespace() + assert result.root.job.name == "dag_id" def test_run_api_call_success(): @@ -283,7 +255,7 @@ def test_create_ol_event_pair_success(mock_generate_uuid, is_successful): assert start_event.job == end_event.job -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_databricks_queries(mock_generate_uuid, mock_version, time_machine): fake_uuid = "01958e68-03a2-79e3-9ae9-26865cc40e2f" @@ -520,7 +492,7 @@ def test_emit_openlineage_events_for_databricks_queries(mock_generate_uuid, mock assert fake_adapter.emit.call_args_list == expected_calls -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_databricks_queries_without_metadata( mock_generate_uuid, mock_version, time_machine @@ -638,7 +610,7 @@ def test_emit_openlineage_events_for_databricks_queries_without_metadata( assert fake_adapter.emit.call_args_list == expected_calls -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids( mock_generate_uuid, mock_version, time_machine @@ -760,7 +732,7 @@ def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i @mock.patch( "airflow.providers.openlineage.sqlparser.SQLParser.create_namespace", return_value="databricks_ns" ) -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace( 
mock_generate_uuid, mock_version, mock_parser, time_machine @@ -878,7 +850,7 @@ def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i assert fake_adapter.emit.call_args_list == expected_calls -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_ids_and_namespace_raw_ns( mock_generate_uuid, mock_version, time_machine @@ -997,7 +969,7 @@ def test_emit_openlineage_events_for_databricks_queries_without_explicit_query_i assert fake_adapter.emit.call_args_list == expected_calls -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_databricks_queries_ith_query_ids_and_hook_query_ids( mock_generate_uuid, mock_version, time_machine @@ -1117,7 +1089,7 @@ def test_emit_openlineage_events_for_databricks_queries_ith_query_ids_and_hook_q assert fake_adapter.emit.call_args_list == expected_calls -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") def test_emit_openlineage_events_for_databricks_queries_missing_query_ids_and_hook(mock_version): query_ids = [] original_query_ids = copy.deepcopy(query_ids) @@ -1142,7 +1114,7 @@ def test_emit_openlineage_events_for_databricks_queries_missing_query_ids_and_ho fake_adapter.emit.assert_not_called() # No events should be emitted -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") def test_emit_openlineage_events_for_databricks_queries_missing_query_namespace_and_hook(mock_version): query_ids = ["1", "2"] original_query_ids = copy.deepcopy(query_ids) @@ -1168,7 +1140,7 @@ def test_emit_openlineage_events_for_databricks_queries_missing_query_namespace_ fake_adapter.emit.assert_not_called() # No events should be emitted -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") def test_emit_openlineage_events_for_databricks_queries_missing_hook_and_query_for_extra_metadata_true( mock_version, ): @@ -1213,7 +1185,7 @@ def test_emit_openlineage_events_with_old_openlineage_provider(mock_version): return_value=fake_listener, ): expected_err = ( - "OpenLineage provider version `1.99.0` is lower than required `2.3.0`, " + "OpenLineage provider version `1.99.0` is lower than required `2.5.0`, " "skipping function `emit_openlineage_events_for_databricks_queries` execution" ) diff --git a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py index 241e72d987429..188ca1255f601 100644 --- a/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py +++ b/providers/dbt/cloud/src/airflow/providers/dbt/cloud/utils/openlineage.py @@ -56,7 +56,48 @@ def _get_dag_run_clear_number(task_instance): return task_instance.dag_run.clear_number -@require_openlineage_version(provider_min_version="2.3.0") +def _get_parent_run_metadata(task_instance): + """ + Retrieve the ParentRunMetadata associated with a specific Airflow task instance. + + This metadata helps link OpenLineage events of child jobs to the original Airflow task execution. 
+    Establishing this connection enables better lineage tracking and observability.
+    """
+    from openlineage.common.provider.dbt import ParentRunMetadata
+
+    from airflow.providers.openlineage.plugins.macros import (
+        lineage_job_name,
+        lineage_job_namespace,
+        lineage_root_job_name,
+        lineage_root_run_id,
+        lineage_run_id,
+    )
+
+    parent_run_id = lineage_run_id(task_instance)
+    parent_job_name = lineage_job_name(task_instance)
+    parent_job_namespace = lineage_job_namespace()
+
+    root_parent_run_id = lineage_root_run_id(task_instance)
+    root_parent_job_name = lineage_root_job_name(task_instance)
+
+    try:  # Added in OL provider 2.9.0, try to use it if possible
+        from airflow.providers.openlineage.plugins.macros import lineage_root_job_namespace
+
+        root_parent_job_namespace = lineage_root_job_namespace(task_instance)
+    except ImportError:
+        root_parent_job_namespace = lineage_job_namespace()
+
+    return ParentRunMetadata(
+        run_id=parent_run_id,
+        job_name=parent_job_name,
+        job_namespace=parent_job_namespace,
+        root_parent_run_id=root_parent_run_id,
+        root_parent_job_name=root_parent_job_name,
+        root_parent_job_namespace=root_parent_job_namespace,
+    )
+
+
+@require_openlineage_version(provider_min_version="2.5.0")
 def generate_openlineage_events_from_dbt_cloud_run(
     operator: DbtCloudRunJobOperator | DbtCloudJobRunSensor, task_instance: TaskInstance
 ) -> OperatorLineage:
@@ -74,14 +115,10 @@
     :return: An empty OperatorLineage object indicating the completion of events generation.
     """
-    from openlineage.common.provider.dbt import DbtCloudArtifactProcessor, ParentRunMetadata
+    from openlineage.common.provider.dbt import DbtCloudArtifactProcessor
 
-    from airflow.providers.openlineage.conf import namespace
     from airflow.providers.openlineage.extractors import OperatorLineage
-    from airflow.providers.openlineage.plugins.adapter import (
-        _PRODUCER,
-        OpenLineageAdapter,
-    )
+    from airflow.providers.openlineage.plugins.adapter import _PRODUCER
     from airflow.providers.openlineage.plugins.listener import get_openlineage_listener
 
     # if no account_id set this will fallback
@@ -140,29 +177,7 @@ async def get_artifacts_for_steps(steps, artifacts):
     )
 
     log.debug("Preparing OpenLineage parent job information to be included in DBT events.")
-    # generate same run id of current task instance
-    parent_run_id = OpenLineageAdapter.build_task_instance_run_id(
-        dag_id=task_instance.dag_id,
-        task_id=operator.task_id,
-        logical_date=_get_logical_date(task_instance),
-        try_number=task_instance.try_number,
-        map_index=task_instance.map_index,
-    )
-
-    root_parent_run_id = OpenLineageAdapter.build_dag_run_id(
-        dag_id=task_instance.dag_id,
-        logical_date=_get_logical_date(task_instance),
-        clear_number=_get_dag_run_clear_number(task_instance),
-    )
-
-    parent_job = ParentRunMetadata(
-        run_id=parent_run_id,
-        job_name=f"{task_instance.dag_id}.{task_instance.task_id}",
-        job_namespace=namespace(),
-        root_parent_run_id=root_parent_run_id,
-        root_parent_job_name=task_instance.dag_id,
-        root_parent_job_namespace=namespace(),
-    )
+    parent_metadata = _get_parent_run_metadata(task_instance)
 
     adapter = get_openlineage_listener().adapter
     # process each step in loop, sending generated events in the same order as steps
@@ -178,7 +193,7 @@ async def get_artifacts_for_steps(steps, artifacts):
 
         processor = DbtCloudArtifactProcessor(
             producer=_PRODUCER,
-            job_namespace=namespace(),
+            job_namespace=parent_metadata.job_namespace,
             skip_errors=False,
             logger=operator.log,
             manifest=manifest,
@@
-187,7 +202,7 @@ async def get_artifacts_for_steps(steps, artifacts):
             catalog=catalog,
         )
 
-        processor.dbt_run_metadata = parent_job
+        processor.dbt_run_metadata = parent_metadata
 
         events = processor.parse().events()
         log.debug("Found %s OpenLineage events for artifact no. %s.", len(events), counter)
diff --git a/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py b/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
index e5d9431f8b311..9455a669f8962 100644
--- a/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
+++ b/providers/dbt/cloud/tests/unit/dbt/cloud/utils/test_openlineage.py
@@ -27,8 +27,14 @@
 from airflow.exceptions import AirflowOptionalProviderFeatureException
 from airflow.providers.dbt.cloud.hooks.dbt import DbtCloudHook
 from airflow.providers.dbt.cloud.operators.dbt import DbtCloudRunJobOperator
-from airflow.providers.dbt.cloud.utils.openlineage import generate_openlineage_events_from_dbt_cloud_run
+from airflow.providers.dbt.cloud.utils.openlineage import (
+    _get_parent_run_metadata,
+    generate_openlineage_events_from_dbt_cloud_run,
+)
+from airflow.providers.openlineage.conf import namespace
 from airflow.providers.openlineage.extractors import OperatorLineage
+from airflow.utils import timezone
+from airflow.utils.state import TaskInstanceState
 
 TASK_ID = "dbt_test"
 DAG_ID = "dbt_dag"
@@ -94,12 +100,13 @@ def get_dbt_artifact(*args, **kwargs):
     [
         ("1.99.0", True),
         ("2.0.0", True),
-        ("2.3.0", False),
+        ("2.3.0", True),
+        ("2.5.0", False),
         ("2.99.0", False),
     ],
 )
 def test_previous_version_openlineage_provider(value, is_error):
-    """When using OpenLineage, the dbt-cloud provider now depends on openlineage provider >= 2.3"""
+    """When using OpenLineage, the dbt-cloud provider now depends on openlineage provider >= 2.5"""
 
     def _mock_version(package):
         if package == "apache-airflow-providers-openlineage":
@@ -110,7 +117,7 @@ def _mock_version(package):
     mock_task_instance = MagicMock()
 
     expected_err = (
-        f"OpenLineage provider version `{value}` is lower than required `2.3.0`, "
+        f"OpenLineage provider version `{value}` is lower than required `2.5.0`, "
         "skipping function `generate_openlineage_events_from_dbt_cloud_run` execution"
     )
 
@@ -126,8 +133,32 @@ def _mock_version(package):
             generate_openlineage_events_from_dbt_cloud_run(mock_operator, mock_task_instance)
 
 
+def test_get_parent_run_metadata():
+    logical_date = timezone.datetime(2025, 1, 1)
+    dr = MagicMock(logical_date=logical_date, clear_number=0)
+    mock_ti = MagicMock(
+        dag_id="dag_id",
+        task_id="task_id",
+        map_index=1,
+        try_number=1,
+        logical_date=logical_date,
+        state=TaskInstanceState.SUCCESS,
+        dag_run=dr,
+    )
+    mock_ti.get_template_context.return_value = {"dag_run": dr}
+
+    result = _get_parent_run_metadata(mock_ti)
+
+    assert result.run_id == "01941f29-7c00-7087-8906-40e512c257bd"
+    assert result.job_namespace == namespace()
+    assert result.job_name == "dag_id.task_id"
+    assert result.root_parent_run_id == "01941f29-7c00-743e-b109-28b18d0a19c5"
+    assert result.root_parent_job_namespace == namespace()
+    assert result.root_parent_job_name == "dag_id"
+
+
 class TestGenerateOpenLineageEventsFromDbtCloudRun:
-    @patch("importlib.metadata.version", return_value="2.3.0")
+    @patch("importlib.metadata.version", return_value="3.0.0")
     @patch("airflow.providers.openlineage.plugins.listener.get_openlineage_listener")
     @patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.build_task_instance_run_id")
@patch("airflow.providers.openlineage.plugins.adapter.OpenLineageAdapter.build_dag_run_id") diff --git a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py index d06a6c463e406..483f176cf59e1 100644 --- a/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py +++ b/providers/snowflake/src/airflow/providers/snowflake/utils/openlineage.py @@ -23,7 +23,6 @@ from urllib.parse import quote, urlparse, urlunparse from airflow.providers.common.compat.openlineage.check import require_openlineage_version -from airflow.providers.snowflake.version_compat import AIRFLOW_V_3_0_PLUS from airflow.utils import timezone if TYPE_CHECKING: @@ -109,60 +108,6 @@ def fix_snowflake_sqlalchemy_uri(uri: str) -> str: return urlunparse((parts.scheme, hostname, parts.path, parts.params, parts.query, parts.fragment)) -def _get_logical_date(task_instance): - # todo: remove when min airflow version >= 3.0 - if AIRFLOW_V_3_0_PLUS: - dagrun = task_instance.get_template_context()["dag_run"] - return dagrun.logical_date or dagrun.run_after - - if hasattr(task_instance, "logical_date"): - date = task_instance.logical_date - else: - date = task_instance.execution_date - - return date - - -def _get_dag_run_clear_number(task_instance): - # todo: remove when min airflow version >= 3.0 - if AIRFLOW_V_3_0_PLUS: - dagrun = task_instance.get_template_context()["dag_run"] - return dagrun.clear_number - return task_instance.dag_run.clear_number - - -# todo: move this run_id logic into OpenLineage's listener to avoid differences -def _get_ol_run_id(task_instance) -> str: - """ - Get OpenLineage run_id from TaskInstance. - - It's crucial that the task_instance's run_id creation logic matches OpenLineage's listener implementation. - Only then can we ensure that the generated run_id aligns with the Airflow task, - enabling a proper connection between events. - """ - from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter - - # Generate same OL run id as is generated for current task instance - return OpenLineageAdapter.build_task_instance_run_id( - dag_id=task_instance.dag_id, - task_id=task_instance.task_id, - logical_date=_get_logical_date(task_instance), - try_number=task_instance.try_number, - map_index=task_instance.map_index, - ) - - -# todo: move this run_id logic into OpenLineage's listener to avoid differences -def _get_ol_dag_run_id(task_instance) -> str: - from airflow.providers.openlineage.plugins.adapter import OpenLineageAdapter - - return OpenLineageAdapter.build_dag_run_id( - dag_id=task_instance.dag_id, - logical_date=_get_logical_date(task_instance), - clear_number=_get_dag_run_clear_number(task_instance), - ) - - def _get_parent_run_facet(task_instance): """ Retrieve the ParentRunFacet associated with a specific Airflow task instance. 
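The updated unit tests across all three providers stub the task instance the same way, so that both the Airflow 2 path (ti.dag_run) and the Airflow 3 path (template context) used by the macros resolve to the same DAG run. A condensed sketch of that fixture is shown below, assuming unittest.mock and a fixed logical date; the helper name make_mock_ti is illustrative only.

# Sketch of the mock TaskInstance the updated tests feed to _get_parent_run_facet /
# _get_parent_run_metadata; dag_run and get_template_context point at the same DAG run.
from unittest import mock

from airflow.utils import timezone
from airflow.utils.state import TaskInstanceState


def make_mock_ti():  # illustrative helper, mirrors the fixtures used in these tests
    logical_date = timezone.datetime(2025, 1, 1)
    dag_run = mock.MagicMock(logical_date=logical_date, clear_number=0)
    ti = mock.MagicMock(
        dag_id="dag_id",
        task_id="task_id",
        map_index=1,
        try_number=1,
        logical_date=logical_date,
        state=TaskInstanceState.SUCCESS,
        dag_run=dag_run,
    )
    ti.get_template_context.return_value = {"dag_run": dag_run}
    return ti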
@@ -173,22 +118,39 @@ def _get_parent_run_facet(task_instance):
     """
     from openlineage.client.facet_v2 import parent_run
 
-    from airflow.providers.openlineage.conf import namespace
+    from airflow.providers.openlineage.plugins.macros import (
+        lineage_job_name,
+        lineage_job_namespace,
+        lineage_root_job_name,
+        lineage_root_run_id,
+        lineage_run_id,
+    )
+
+    parent_run_id = lineage_run_id(task_instance)
+    parent_job_name = lineage_job_name(task_instance)
+    parent_job_namespace = lineage_job_namespace()
+
+    root_parent_run_id = lineage_root_run_id(task_instance)
+    root_parent_job_name = lineage_root_job_name(task_instance)
+
+    try:  # Added in OL provider 2.9.0, try to use it if possible
+        from airflow.providers.openlineage.plugins.macros import lineage_root_job_namespace
 
-    parent_run_id = _get_ol_run_id(task_instance)
-    root_parent_run_id = _get_ol_dag_run_id(task_instance)
+        root_parent_job_namespace = lineage_root_job_namespace(task_instance)
+    except ImportError:
+        root_parent_job_namespace = lineage_job_namespace()
 
     return parent_run.ParentRunFacet(
         run=parent_run.Run(runId=parent_run_id),
         job=parent_run.Job(
-            namespace=namespace(),
-            name=f"{task_instance.dag_id}.{task_instance.task_id}",
+            namespace=parent_job_namespace,
+            name=parent_job_name,
         ),
         root=parent_run.Root(
            run=parent_run.RootRun(runId=root_parent_run_id),
            job=parent_run.RootJob(
-                name=task_instance.dag_id,
-                namespace=namespace(),
+                name=root_parent_job_name,
+                namespace=root_parent_job_namespace,
            ),
        ),
    )
@@ -299,7 +261,7 @@ def _create_snowflake_event_pair(
     return start, end
 
 
-@require_openlineage_version(provider_min_version="2.3.0")
+@require_openlineage_version(provider_min_version="2.5.0")
 def emit_openlineage_events_for_snowflake_queries(
     task_instance,
     hook: SnowflakeHook | SnowflakeSqlApiHook | None = None,
diff --git a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
index 4485c87c7a2b2..d92ca12752d91 100644
--- a/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
+++ b/providers/snowflake/tests/unit/snowflake/hooks/test_snowflake.py
@@ -1046,7 +1046,7 @@ def test_get_openlineage_database_specific_lineage_with_old_openlineage_provider
     hook.get_openlineage_database_info = lambda x: mock.MagicMock(authority="auth", scheme="scheme")
 
     expected_err = (
-        "OpenLineage provider version `1.99.0` is lower than required `2.3.0`, "
+        "OpenLineage provider version `1.99.0` is lower than required `2.5.0`, "
         "skipping function `emit_openlineage_events_for_snowflake_queries` execution"
     )
     with pytest.raises(AirflowOptionalProviderFeatureException, match=expected_err):
diff --git a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py
index 7f70ef51c50aa..0dcee3adc9531 100644
--- a/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py
+++ b/providers/snowflake/tests/unit/snowflake/utils/test_openlineage.py
@@ -35,7 +35,6 @@
 from airflow.providers.snowflake.hooks.snowflake_sql_api import SnowflakeSqlApiHook
 from airflow.providers.snowflake.utils.openlineage import (
     _create_snowflake_event_pair,
-    _get_ol_run_id,
     _get_parent_run_facet,
     _get_queries_details_from_snowflake,
     _process_data_from_api,
@@ -117,40 +116,9 @@ def test_fix_account_name(name, expected):
     )
 
 
-def test_get_ol_run_id_ti_success():
-    logical_date = timezone.datetime(2025, 1, 1)
-    mock_ti = mock.MagicMock(
-        dag_id="dag_id",
-        task_id="task_id",
-        map_index=1,
-        try_number=1,
-
logical_date=logical_date, - state=TaskInstanceState.SUCCESS, - ) - mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} - - result = _get_ol_run_id(mock_ti) - assert result == "01941f29-7c00-7087-8906-40e512c257bd" - - -def test_get_ol_run_id_ti_failed(): - logical_date = timezone.datetime(2025, 1, 1) - mock_ti = mock.MagicMock( - dag_id="dag_id", - task_id="task_id", - map_index=1, - try_number=1, - logical_date=logical_date, - state=TaskInstanceState.FAILED, - ) - mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} - - result = _get_ol_run_id(mock_ti) - assert result == "01941f29-7c00-7087-8906-40e512c257bd" - - def test_get_parent_run_facet(): logical_date = timezone.datetime(2025, 1, 1) + dr = mock.MagicMock(logical_date=logical_date, clear_number=0) mock_ti = mock.MagicMock( dag_id="dag_id", task_id="task_id", @@ -158,14 +126,18 @@ def test_get_parent_run_facet(): try_number=1, logical_date=logical_date, state=TaskInstanceState.SUCCESS, + dag_run=dr, ) - mock_ti.get_template_context.return_value = {"dag_run": mock.MagicMock(logical_date=logical_date)} + mock_ti.get_template_context.return_value = {"dag_run": dr} result = _get_parent_run_facet(mock_ti) assert result.run.runId == "01941f29-7c00-7087-8906-40e512c257bd" assert result.job.namespace == namespace() assert result.job.name == "dag_id.task_id" + assert result.root.run.runId == "01941f29-7c00-743e-b109-28b18d0a19c5" + assert result.root.job.namespace == namespace() + assert result.root.job.name == "dag_id" def test_process_data_from_api(): @@ -578,7 +550,7 @@ def test_create_snowflake_event_pair_success(mock_generate_uuid, is_successful): assert start_event.job == end_event.job -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_snowflake_queries_with_extra_metadata( mock_generate_uuid, mock_version, time_machine @@ -818,7 +790,7 @@ def test_emit_openlineage_events_for_snowflake_queries_with_extra_metadata( assert fake_adapter.emit.call_args_list == expected_calls -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_snowflake_queries_without_extra_metadata( mock_generate_uuid, mock_version, time_machine @@ -936,7 +908,7 @@ def test_emit_openlineage_events_for_snowflake_queries_without_extra_metadata( assert fake_adapter.emit.call_args_list == expected_calls -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_snowflake_queries_without_query_ids( mock_generate_uuid, mock_version, time_machine @@ -1056,7 +1028,7 @@ def test_emit_openlineage_events_for_snowflake_queries_without_query_ids( @mock.patch("airflow.providers.openlineage.sqlparser.SQLParser.create_namespace", return_value="snowflake_ns") -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_snowflake_queries_without_query_ids_and_namespace( mock_generate_uuid, mock_version, mock_parser, time_machine @@ 
-1175,7 +1147,7 @@ def test_emit_openlineage_events_for_snowflake_queries_without_query_ids_and_nam assert fake_adapter.emit.call_args_list == expected_calls -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") @mock.patch("openlineage.client.uuid.generate_new_uuid") def test_emit_openlineage_events_for_snowflake_queries_with_query_ids_and_hook_query_ids( mock_generate_uuid, mock_version, time_machine @@ -1294,7 +1266,7 @@ def test_emit_openlineage_events_for_snowflake_queries_with_query_ids_and_hook_q assert fake_adapter.emit.call_args_list == expected_calls -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") def test_emit_openlineage_events_for_snowflake_queries_missing_query_ids_and_hook(mock_version): fake_adapter = mock.MagicMock() fake_adapter.emit = mock.MagicMock() @@ -1313,7 +1285,7 @@ def test_emit_openlineage_events_for_snowflake_queries_missing_query_ids_and_hoo fake_adapter.emit.assert_not_called() # No events should be emitted -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") def test_emit_openlineage_events_for_snowflake_queries_missing_query_namespace_and_hook(mock_version): query_ids = ["1", "2"] original_query_ids = copy.deepcopy(query_ids) @@ -1338,7 +1310,7 @@ def test_emit_openlineage_events_for_snowflake_queries_missing_query_namespace_a fake_adapter.emit.assert_not_called() # No events should be emitted -@mock.patch("importlib.metadata.version", return_value="2.3.0") +@mock.patch("importlib.metadata.version", return_value="3.0.0") def test_emit_openlineage_events_for_snowflake_queries_missing_hook_and_query_for_extra_metadata_true( mock_version, ): @@ -1368,7 +1340,6 @@ def test_emit_openlineage_events_for_snowflake_queries_missing_hook_and_query_fo fake_adapter.emit.assert_not_called() # No events should be emitted -# emit_openlineage_events_for_snowflake_queries requires OL provider 2.3.0 @mock.patch("importlib.metadata.version", return_value="1.99.0") def test_emit_openlineage_events_with_old_openlineage_provider(mock_version): query_ids = ["q1", "q2"] @@ -1384,7 +1355,7 @@ def test_emit_openlineage_events_with_old_openlineage_provider(mock_version): return_value=fake_listener, ): expected_err = ( - "OpenLineage provider version `1.99.0` is lower than required `2.3.0`, " + "OpenLineage provider version `1.99.0` is lower than required `2.5.0`, " "skipping function `emit_openlineage_events_for_snowflake_queries` execution" )
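For context on the repeated `2.3.0` → `2.5.0` bumps above: `require_openlineage_version` from the common compat provider guards each emit helper, and with an older apache-airflow-providers-openlineage installed it raises AirflowOptionalProviderFeatureException (the "is lower than required `2.5.0` ... skipping function" message asserted in the tests) instead of running the function. A minimal sketch of that guard follows; the decorated function name is an illustrative stand-in.

# Sketch only: the version floor declared in this patch. On installations with
# apache-airflow-providers-openlineage older than 2.5.0, calling the decorated
# function raises AirflowOptionalProviderFeatureException and emits nothing.
from airflow.providers.common.compat.openlineage.check import require_openlineage_version


@require_openlineage_version(provider_min_version="2.5.0")
def emit_events(task_instance, **kwargs):  # illustrative stand-in for the emit_* helpers
    """Body runs only when a new-enough OpenLineage provider is installed."""
    ...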