Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,13 @@ class DagRunInfo(InfoJsonEncodable):
"end_date",
]

casts = {"duration": lambda dagrun: DagRunInfo.duration(dagrun)}
casts = {
"duration": lambda dagrun: DagRunInfo.duration(dagrun),
"dag_bundle_name": lambda dagrun: DagRunInfo.dag_version_info(dagrun, "bundle_name"),
"dag_bundle_version": lambda dagrun: DagRunInfo.dag_version_info(dagrun, "bundle_version"),
"dag_version_id": lambda dagrun: DagRunInfo.dag_version_info(dagrun, "version_id"),
"dag_version_number": lambda dagrun: DagRunInfo.dag_version_info(dagrun, "version_number"),
}

@classmethod
def duration(cls, dagrun: DagRun) -> float | None:
Expand All @@ -394,15 +400,33 @@ def duration(cls, dagrun: DagRun) -> float | None:
return None
return (dagrun.end_date - dagrun.start_date).total_seconds()

@classmethod
def dag_version_info(cls, dagrun: DagRun, key: str) -> str | int | None:
# AF2 DagRun and AF3 DagRun SDK model (on worker) do not have this information
if not getattr(dagrun, "dag_versions", []):
return None
current_version = dagrun.dag_versions[-1]
if key == "bundle_name":
return current_version.bundle_name
if key == "bundle_version":
return current_version.bundle_version
if key == "version_id":
return str(current_version.id)
if key == "version_number":
return current_version.version_number
raise ValueError(f"Unsupported key: {key}`")


class TaskInstanceInfo(InfoJsonEncodable):
"""Defines encoding TaskInstance object to JSON."""

includes = ["duration", "try_number", "pool", "queued_dttm", "log_url"]
casts = {
"map_index": lambda ti: (
ti.map_index if hasattr(ti, "map_index") and getattr(ti, "map_index") != -1 else None
)
"map_index": lambda ti: ti.map_index if getattr(ti, "map_index", -1) != -1 else None,
"dag_bundle_version": lambda ti: (
ti.bundle_instance.version if hasattr(ti, "bundle_instance") else None
),
"dag_bundle_name": lambda ti: ti.bundle_instance.name if hasattr(ti, "bundle_instance") else None,
}


Expand Down
131 changes: 115 additions & 16 deletions providers/openlineage/tests/unit/openlineage/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

import datetime
import pathlib
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, PropertyMock, patch

import pendulum
import pytest
from uuid6 import uuid7

from airflow import DAG
from airflow.providers.openlineage.version_compat import AIRFLOW_V_3_0_PLUS
from airflow.utils import timezone

if AIRFLOW_V_3_0_PLUS:
from airflow.sdk import task
Expand Down Expand Up @@ -156,6 +158,14 @@ def test_get_airflow_dag_run_facet():
dagrun_mock.run_after = datetime.datetime(2024, 6, 1, 1, 2, 4, tzinfo=datetime.timezone.utc)
dagrun_mock.start_date = datetime.datetime(2024, 6, 1, 1, 2, 4, tzinfo=datetime.timezone.utc)
dagrun_mock.end_date = datetime.datetime(2024, 6, 1, 1, 2, 14, 34172, tzinfo=datetime.timezone.utc)
dagrun_mock.dag_versions = [
MagicMock(
bundle_name="bundle_name",
bundle_version="bundle_version",
id="version_id",
version_number="version_number",
)
]

result = get_airflow_dag_run_facet(dagrun_mock)

Expand Down Expand Up @@ -189,6 +199,10 @@ def test_get_airflow_dag_run_facet():
"duration": 10.034172,
"logical_date": "2024-06-01T01:02:04+00:00",
"run_after": "2024-06-01T01:02:04+00:00",
"dag_bundle_name": "bundle_name",
"dag_bundle_version": "bundle_version",
"dag_version_id": "version_id",
"dag_version_number": "version_number",
},
)
}
Expand Down Expand Up @@ -216,6 +230,28 @@ def test_dag_run_duration(dag_run_attrs, expected_duration):
assert result == expected_duration


def test_dag_run_version_no_versions():
dag_run = MagicMock()
del dag_run.dag_versions
result = DagRunInfo.dag_version_info(dag_run, "somekey")
assert result is None


@pytest.mark.parametrize("key", ["bundle_name", "bundle_version", "version_id", "version_number"])
def test_dag_run_version(key):
dagrun_mock = MagicMock(DagRun)
dagrun_mock.dag_versions = [
MagicMock(
bundle_name="bundle_name",
bundle_version="bundle_version",
id="version_id",
version_number="version_number",
)
]
result = DagRunInfo.dag_version_info(dagrun_mock, key)
assert result == key


def test_get_fully_qualified_class_name_serialized_operator():
op_module_path = BASH_OPERATOR_PATH
op_name = "BashOperator"
Expand Down Expand Up @@ -1334,27 +1370,37 @@ def test_dag_info_schedule_asset_or_time_schedule(self):


@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 3 test")
@pytest.mark.db_test
def test_dagrun_info_af3():
@patch.object(DagRun, "dag_versions", new_callable=PropertyMock)
def test_dagrun_info_af3(mocked_dag_versions):
from airflow.models.dag_version import DagVersion
from airflow.utils.types import DagRunTriggeredByType

date = datetime.datetime(2024, 6, 1, tzinfo=datetime.timezone.utc)
dag = DAG(
"dag_id",
schedule=None,
start_date=date,
)
dagrun = dag.create_dagrun(
dv1 = DagVersion()
dv2 = DagVersion()
dv2.id = "version_id"
dv2.version_number = "version_number"
dv2.bundle_name = "bundle_name"
dv2.bundle_version = "bundle_version"

mocked_dag_versions.return_value = [dv1, dv2]
dagrun = DagRun(
dag_id="dag_id",
run_id="dag_run__run_id",
data_interval=(date, date),
run_type=DagRunType.MANUAL,
state=DagRunState.RUNNING,
triggered_by=DagRunTriggeredByType.UI,
run_after=date,
queued_at=date,
logical_date=date,
run_after=date,
start_date=date,
conf={"a": 1},
state=DagRunState.RUNNING,
run_type=DagRunType.MANUAL,
creating_job_id=123,
data_interval=(date, date),
triggered_by=DagRunTriggeredByType.UI,
backfill_id=999,
bundle_version="bundle_version",
)
dagrun.start_date = date
assert dagrun.dag_versions == [dv1, dv2]
dagrun.end_date = date + datetime.timedelta(seconds=74, microseconds=546)

result = DagRunInfo(dagrun)
Expand All @@ -1370,6 +1416,10 @@ def test_dagrun_info_af3():
"start_date": "2024-06-01T00:00:00+00:00",
"logical_date": "2024-06-01T00:00:00+00:00",
"run_after": "2024-06-01T00:00:00+00:00",
"dag_bundle_name": "bundle_name",
"dag_bundle_version": "bundle_version",
"dag_version_id": "version_id",
"dag_version_number": "version_number",
}


Expand Down Expand Up @@ -1407,11 +1457,58 @@ def test_dagrun_info_af2():
"external_trigger": False,
"start_date": "2024-06-01T00:00:00+00:00",
"logical_date": "2024-06-01T00:00:00+00:00",
"dag_bundle_name": None,
"dag_bundle_version": None,
"dag_version_id": None,
"dag_version_number": None,
}


@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 3 test")
def test_taskinstance_info_af3():
from airflow.sdk.api.datamodels._generated import TaskInstance
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance

task = BaseOperator(task_id="hello")
task._is_mapped = True
dag_id = "basic_task"

dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
task.dag = dag

ti_id = uuid7()
ti = TaskInstance(
id=ti_id,
task_id=task.task_id,
dag_id=dag_id,
run_id="test_run",
try_number=1,
map_index=2,
)
start_date = timezone.datetime(2025, 1, 1)

runtime_ti = RuntimeTaskInstance.model_construct(
**ti.model_dump(exclude_unset=True),
task=task,
_ti_context_from_server=None,
start_date=start_date,
)
runtime_ti.end_date = start_date + datetime.timedelta(seconds=12, milliseconds=345)
bundle_instance = MagicMock(version="bundle_version")
bundle_instance.name = "bundle_name"
runtime_ti.bundle_instance = bundle_instance

assert dict(TaskInstanceInfo(runtime_ti)) == {
"map_index": 2,
"try_number": 1,
"dag_bundle_version": "bundle_version",
"dag_bundle_name": "bundle_name",
}


@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Airflow 2 test")
@patch.object(TaskInstance, "log_url", "some_log_url") # Depends on the host, hard to test exact value
def test_taskinstance_info():
def test_taskinstance_info_af2():
some_date = datetime.datetime(2024, 6, 1, tzinfo=datetime.timezone.utc)
task_obj = PythonOperator(task_id="task_id", python_callable=lambda x: x)
ti = TaskInstance(
Expand All @@ -1427,6 +1524,8 @@ def test_taskinstance_info():
"try_number": 0,
"queued_dttm": "2024-06-01T00:00:00+00:00",
"log_url": "some_log_url",
"dag_bundle_name": None,
"dag_bundle_version": None,
}


Expand Down