Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -751,16 +751,20 @@ def get_airflow_state_run_facet(
dag_id: str, run_id: str, task_ids: list[str], dag_run_state: DagRunState
) -> dict[str, AirflowStateRunFacet]:
tis = DagRun.fetch_task_instances(dag_id=dag_id, run_id=run_id, task_ids=task_ids)

def get_task_duration(ti):
if ti.duration is not None:
return ti.duration
if ti.end_date is not None and ti.start_date is not None:
return (ti.end_date - ti.start_date).total_seconds()
# Fallback to 0.0 for tasks with missing timestamps (e.g., skipped/terminated tasks)
return 0.0

return {
"airflowState": AirflowStateRunFacet(
dagRunState=dag_run_state,
tasksState={ti.task_id: ti.state for ti in tis},
tasksDuration={
ti.task_id: ti.duration
if ti.duration is not None
else (ti.end_date - ti.start_date).total_seconds()
for ti in tis
},
tasksDuration={ti.task_id: get_task_duration(ti) for ti in tis},
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
_truncate_string_to_byte_size,
get_airflow_dag_run_facet,
get_airflow_job_facet,
get_airflow_state_run_facet,
get_dag_documentation,
get_fully_qualified_class_name,
get_job_name,
Expand All @@ -57,6 +58,7 @@
from airflow.timetables.events import EventsTimetable
from airflow.timetables.trigger import CronTriggerTimetable
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType

Expand Down Expand Up @@ -2054,3 +2056,65 @@ def test_get_operator_provider_version_for_mapped_operator(mock_providers_manage
mapped_operator = BashOperator.partial(task_id="test_task").expand(bash_command=["echo 1", "echo 2"])
result = get_operator_provider_version(mapped_operator)
assert result == "1.2.0"


class TestGetAirflowStateRunFacet:
@pytest.mark.db_test
def test_task_with_timestamps_defined(self, dag_maker):
"""Test task instance with defined start_date and end_date."""
with dag_maker(dag_id="test_dag"):
BaseOperator(task_id="test_task")

dag_run = dag_maker.create_dagrun()
ti = dag_run.get_task_instance(task_id="test_task")

# Set valid timestamps
start_time = pendulum.parse("2024-01-01T10:00:00Z")
end_time = pendulum.parse("2024-01-01T10:02:30Z") # 150 seconds difference
ti.start_date = start_time
ti.end_date = end_time
ti.state = TaskInstanceState.SUCCESS
ti.duration = None

# Persist changes to database
with create_session() as session:
session.merge(ti)
session.commit()

result = get_airflow_state_run_facet(
dag_id="test_dag",
run_id=dag_run.run_id,
task_ids=["test_task"],
dag_run_state=DagRunState.SUCCESS,
)

assert result["airflowState"].tasksDuration["test_task"] == 150.0

@pytest.mark.db_test
def test_task_with_none_timestamps_fallback_to_zero(self, dag_maker):
"""Test task with None timestamps falls back to 0.0."""
with dag_maker(dag_id="test_dag"):
BaseOperator(task_id="terminated_task")

dag_run = dag_maker.create_dagrun()
ti = dag_run.get_task_instance(task_id="terminated_task")

# Set None timestamps (signal-terminated case)
ti.start_date = None
ti.end_date = None
ti.state = TaskInstanceState.SKIPPED
ti.duration = None

# Persist changes to database
with create_session() as session:
session.merge(ti)
session.commit()

result = get_airflow_state_run_facet(
dag_id="test_dag",
run_id=dag_run.run_id,
task_ids=["terminated_task"],
dag_run_state=DagRunState.FAILED,
)

assert result["airflowState"].tasksDuration["terminated_task"] == 0.0
Loading