Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,9 +1395,17 @@ def finalize(
task = ti.task
# Pushing xcom for each operator extra links defined on the operator only.
for oe in task.operator_extra_links:
link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type]
log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
_xcom_push_to_db(ti, key=xcom_key, value=link)
try:
link, xcom_key = oe.get_link(operator=task, ti_key=ti), oe.xcom_key # type: ignore[arg-type]
log.debug("Setting xcom for operator extra link", link=link, xcom_key=xcom_key)
_xcom_push_to_db(ti, key=xcom_key, value=link)
except Exception:
log.exception(
"Failed to push an xcom for task operator extra link",
link_name=oe.name,
xcom_key=oe.xcom_key,
ti=ti,
)

if getattr(ti.task, "overwrite_rtif_after_execution", False):
log.debug("Overwriting Rendered template fields.")
Expand Down
90 changes: 89 additions & 1 deletion task-sdk/tests/task_sdk/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pathlib import Path
from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import patch
from unittest.mock import call, patch

import pandas as pd
import pytest
Expand All @@ -48,6 +48,7 @@
from airflow.sdk import (
DAG,
BaseOperator,
BaseOperatorLink,
Connection,
dag as dag_decorator,
get_current_context,
Expand Down Expand Up @@ -1723,6 +1724,93 @@ def execute(self, context):
map_index=runtime_ti.map_index,
)

def test_task_failed_with_operator_extra_links(
self, create_runtime_ti, mock_supervisor_comms, time_machine
):
"""Test that operator extra links are pushed to xcoms even when task fails."""
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

class DummyTestOperator(BaseOperator):
operator_extra_links = (AirflowLink(),)

def execute(self, context):
raise ValueError("Task failed intentionally")

task = DummyTestOperator(task_id="task_with_operator_extra_links")
runtime_ti = create_runtime_ti(task=task)
context = runtime_ti.get_template_context()
runtime_ti.start_date = instant
runtime_ti.end_date = instant

state, _, error = run(runtime_ti, context=context, log=mock.MagicMock())
assert state == TaskInstanceState.FAILED
assert error is not None

with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set:
finalize(
runtime_ti,
log=mock.MagicMock(),
state=TaskInstanceState.FAILED,
context=context,
error=error,
)
assert mock_xcom_set.mock_calls == [
call(
key="_link_AirflowLink",
value="https://airflow.apache.org",
dag_id=runtime_ti.dag_id,
task_id=runtime_ti.task_id,
run_id=runtime_ti.run_id,
map_index=runtime_ti.map_index,
)
]

def test_operator_extra_links_exception_handling(
self, create_runtime_ti, mock_supervisor_comms, time_machine
):
"""Test that exceptions in get_link() don't prevent other links from being pushed."""
instant = timezone.datetime(2024, 12, 3, 10, 0)
time_machine.move_to(instant, tick=False)

class FailingLink(BaseOperatorLink):
"""A link that raises an exception when get_link is called."""

name = "failing_link"

def get_link(self, operator, *, ti_key):
raise ValueError("Link generation failed")

class DummyTestOperator(BaseOperator):
operator_extra_links = (FailingLink(), AirflowLink())

def execute(self, context):
pass

task = DummyTestOperator(task_id="task_with_multiple_links")
runtime_ti = create_runtime_ti(task=task)
context = runtime_ti.get_template_context()
runtime_ti.start_date = instant
runtime_ti.end_date = instant

with mock.patch.object(XCom, "_set_xcom_in_db") as mock_xcom_set:
finalize(
runtime_ti,
log=mock.MagicMock(),
state=TaskInstanceState.SUCCESS,
context=context,
)
assert mock_xcom_set.mock_calls == [
call(
key="_link_AirflowLink",
value="https://airflow.apache.org",
dag_id=runtime_ti.dag_id,
task_id=runtime_ti.task_id,
run_id=runtime_ti.run_id,
map_index=runtime_ti.map_index,
)
]

@pytest.mark.parametrize(
["cmd", "rendered_cmd"],
[
Expand Down