Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: Move rendering of map_index_template so it renders for failed tasks as long as it was defined before the point of failure #38902

Merged
merged 20 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 22 additions & 13 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,9 @@ def _creator_note(val):
return TaskInstanceNote(*val)


def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: Context, task_orig: Operator):
def _execute_task(
task_instance: TaskInstance | TaskInstancePydantic, context: Context, task_orig: Operator, jinja_env=None
):
"""
Execute Task (optionally with a Timeout) and push Xcom results.

Expand Down Expand Up @@ -433,7 +435,7 @@ def _execute_task(task_instance: TaskInstance | TaskInstancePydantic, context: C
if execute_callable.__name__ == "execute":
execute_callable_kwargs[f"{task_to_execute.__class__.__name__}__sentinel"] = _sentinel

def _execute_callable(context: Context, **execute_callable_kwargs):
def _execute_callable(context: Context, jinja_env=None, **execute_callable_kwargs):
try:
# Print a marker for log grouping of details before task execution
log.info("::endgroup::")
Expand All @@ -453,6 +455,15 @@ def _execute_callable(context: Context, **execute_callable_kwargs):
# Print a marker post execution for internals of post task processing
log.info("::group::Post task execution logs")

# DAG authors define map_index_template at the task level
if jinja_env is not None and (template := context.get("map_index_template")) is not None:
rendered_map_index = jinja_env.from_string(template).render(context)
log.info("Map index rendered as %s", rendered_map_index)
else:
rendered_map_index = None

task_instance.rendered_map_index = rendered_map_index

# If a timeout is specified for the task, make it fail
# if it goes beyond
if task_to_execute.execution_timeout:
Expand All @@ -470,12 +481,12 @@ def _execute_callable(context: Context, **execute_callable_kwargs):
raise AirflowTaskTimeout()
# Run task in timeout wrapper
with timeout(timeout_seconds):
result = _execute_callable(context=context, **execute_callable_kwargs)
result = _execute_callable(context=context, **execute_callable_kwargs, jinja_env=jinja_env)
except AirflowTaskTimeout:
task_to_execute.on_kill()
raise
else:
result = _execute_callable(context=context, **execute_callable_kwargs)
result = _execute_callable(context=context, **execute_callable_kwargs, jinja_env=jinja_env)
cm = nullcontext() if InternalApiConfig.get_use_internal_api() else create_session()
with cm as session_or_null:
if task_to_execute.do_xcom_push:
Expand All @@ -501,7 +512,8 @@ def _execute_callable(context: Context, **execute_callable_kwargs):
_record_task_map_for_downstreams(
task_instance=task_instance, task=task_orig, value=xcom_value, session=session_or_null
)
return result

return result, task_instance.rendered_map_index


def _refresh_from_db(
Expand Down Expand Up @@ -2715,29 +2727,26 @@ def signal_handler(signum, frame):

# Execute the task
with set_current_context(context):
result = self._execute_task(context, task_orig)
result, rendered_map_index = self._execute_task(context, task_orig, jinja_env=jinja_env)
uranusjr marked this conversation as resolved.
Show resolved Hide resolved

self.rendered_map_index = rendered_map_index

# Run post_execute callback
self.task.post_execute(context=context, result=result)

# DAG authors define map_index_template at the task level
if jinja_env is not None and (template := context.get("map_index_template")) is not None:
rendered_map_index = self.rendered_map_index = jinja_env.from_string(template).render(context)
self.log.info("Map index rendered as %s", rendered_map_index)

Stats.incr(f"operator_successes_{self.task.task_type}", tags=self.stats_tags)
# Same metric with tagging
Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
Stats.incr("ti_successes", tags=self.stats_tags)

def _execute_task(self, context: Context, task_orig: Operator):
def _execute_task(self, context: Context, task_orig: Operator, jinja_env=None):
"""
Execute Task (optionally with a Timeout) and push Xcom results.

:param context: Jinja2 context
:param task_orig: origin task
"""
return _execute_task(self, context, task_orig)
return _execute_task(self, context, task_orig, jinja_env)

@provide_session
def defer_task(self, session: Session, defer: TaskDeferred) -> None:
Expand Down
26 changes: 26 additions & 0 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,30 @@ def task1(map_name):
return task1.expand(map_name=map_names)


def _create_named_map_index_renders_on_failure_classic(*, task_id, map_names, template):
    """
    Build a mapped classic operator whose expanded tasks all raise while executing.

    Used to check that ``map_index_template`` is still rendered when the task
    does not finish normally, provided the values the template needs were put
    into the context before the exception was raised.

    :param task_id: task id given to the mapped operator
    :param map_names: values to expand ``map_name`` over (one mapped task each)
    :param template: ``map_index_template`` passed to ``partial()``
    :return: the result of ``HasMapName.partial(...).expand(...)``
    """

    class HasMapName(BaseOperator):
        def __init__(self, *, map_name: str, **kwargs):
            super().__init__(**kwargs)
            self.map_name = map_name

        def execute(self, context):
            # Mirror the taskflow variant: publish the name into the context
            # first, then raise, so the exception happens at execution time
            # (not while the operator is being constructed) and the template
            # has the value it needs at the point of failure.
            context["map_name"] = self.map_name
            raise AirflowSkipException("Imagine this task failed!")

    return HasMapName.partial(task_id=task_id, map_index_template=template).expand(
        map_name=map_names,
    )


def _create_named_map_index_renders_on_failure_taskflow(*, task_id, map_names, template):
    """
    Build a mapped taskflow task whose expanded tasks all raise while executing.

    Each mapped task stores its ``map_name`` in the current context and then
    raises, so the caller can check how ``map_index_template`` is rendered for
    a task that did not finish normally.

    :param task_id: task id given to the decorated task
    :param map_names: values to expand ``map_name`` over (one mapped task each)
    :param template: ``map_index_template`` passed to the ``@task`` decorator
    :return: the result of ``task1.expand(map_name=map_names)``
    """
    from airflow.operators.python import get_current_context

    @task(task_id=task_id, map_index_template=template)
    def task1(map_name):
        # Publish the name into the running context before raising, so the
        # template has its value available at the point of failure.
        get_current_context()["map_name"] = map_name
        raise AirflowSkipException("Imagine this task failed!")

    expanded = task1.expand(map_name=map_names)
    return expanded


@pytest.mark.parametrize(
"template, expected_rendered_names",
[
Expand All @@ -649,6 +673,8 @@ def task1(map_name):
[
pytest.param(_create_mapped_with_name_template_classic, id="classic"),
pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"),
pytest.param(_create_named_map_index_renders_on_failure_classic, id="classic-failure"),
pytest.param(_create_named_map_index_renders_on_failure_taskflow, id="taskflow-failure"),
],
)
def test_expand_mapped_task_instance_with_named_index(
Expand Down
Loading