Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move render map index method and apply to dry run #39087

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 27 additions & 11 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -2671,25 +2671,17 @@ def signal_handler(signum, frame):
previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session
)

def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
    """
    Render the named map index from the ``map_index_template`` context entry.

    :param context: Template context; its ``"map_index_template"`` entry (if any)
        is used as the template source and the context itself as render variables.
    :param jinja_env: Jinja environment used to compile the template.
    :return: The rendered string, or ``None`` when no environment was supplied or
        the context holds no ``"map_index_template"`` value.
    """
    # Without an environment there is nothing to render with; the context is
    # deliberately not consulted in that case.
    if jinja_env is None:
        return None

    map_index_template = context.get("map_index_template")
    if map_index_template is None:
        return None

    compiled = jinja_env.from_string(map_index_template)
    rendered = compiled.render(context)
    log.debug("Map index rendered as %s", rendered)
    return rendered

# Execute the task.
with set_current_context(context):
try:
result = self._execute_task(context, task_orig)
except Exception:
# If the task failed, swallow rendering error so it doesn't mask the main error.
with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError):
self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
self.rendered_map_index = self._render_map_index(context, jinja_env=jinja_env)
raise
else: # If the task succeeded, render normally to let rendering error bubble up.
self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env)
self.rendered_map_index = self._render_map_index(context, jinja_env=jinja_env)

# Run post_execute callback
self.task.post_execute(context=context, result=result)
Expand All @@ -2699,6 +2691,18 @@ def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None)
Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type})
Stats.incr("ti_successes", tags=self.stats_tags)

def _render_map_index(self, context: Context, *, jinja_env: jinja2.Environment | None) -> str | None:
"""Render named map index if the DAG author defined map_index_template at the task level."""
if TYPE_CHECKING:
assert self.task

if jinja_env is None or (template := context.get("map_index_template")) is None:
return None

rendered_map_index = self.task.render_template(template, context, jinja_env)
self.log.debug("Map index rendered as %s", rendered_map_index)
return rendered_map_index

def _execute_task(self, context: Context, task_orig: Operator):
"""
Execute Task (optionally with a Timeout) and push Xcom results.
Expand Down Expand Up @@ -2805,11 +2809,23 @@ def dry_run(self) -> None:
assert self.task

self.task = self.task.prepare_for_execution()
self.render_templates()
context = self.get_template_context(ignore_param_exceptions=False)

jinja_env = None
dag = self.task.get_dag()
if dag is not None:
jinja_env = dag.get_template_env()

if TYPE_CHECKING:
assert isinstance(self.task, BaseOperator)
self.render_templates(context=context, jinja_env=jinja_env)
self.task.dry_run()

self.rendered_map_index = self._render_map_index(
context,
jinja_env=jinja_env,
)

@provide_session
def _handle_reschedule(
self,
Expand Down
70 changes: 69 additions & 1 deletion tests/models/test_taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3608,6 +3608,73 @@ def raise_skip_exception():
assert State.SKIPPED == ti.state
assert callback_function.called

def test_rendered_map_index_on_dry_run(self, dag_maker, session):
    """Check that ``dry_run()`` populates ``rendered_map_index`` for each mapped task instance."""
    from airflow.utils.task_instance_session import set_current_task_instance_session

    with dag_maker(dag_id="example", session=session) as dag:

        @dag.task
        def emit():
            return ["a", "b"]

        @dag.task(map_index_template="instance-map-index-{{ task.op_kwargs['value'] * 3 }}")
        def example(value: str):
            # The body is never reached: the test only calls `dry_run()`.
            assert value

        example.expand(value=emit())

    dag_run = dag_maker.create_dagrun()

    # Run the upstream task for real so the mapped task has values to expand over.
    upstream_ti = dag_run.get_task_instance("emit", session=session)
    upstream_ti.refresh_from_task(dag.get_task("emit"))
    upstream_ti.run()

    mapped_task = dag.get_task("example")
    expanded_tis, _ = mapped_task.expand_mapped_task(dag_run.run_id, session=session)

    collected = []
    with set_current_task_instance_session(session):
        for mapped_ti in expanded_tis:
            mapped_ti.task = mapped_task
            mapped_ti.refresh_from_task(mapped_task)
            mapped_ti.dry_run()
            collected.append(mapped_ti.rendered_map_index)

    assert sorted(collected) == ["instance-map-index-aaa", "instance-map-index-bbb"]

def test_rendered_map_index_example_from_doc(self, dag_maker, session):
    """Check that a context variable set inside the task body is usable by ``map_index_template``."""
    from airflow.operators.python import get_current_context

    with dag_maker(dag_id="example", session=session) as dag:

        @dag.task
        def emit():
            return ["a", "b"]

        @dag.task(map_index_template="{{ my_variable }}")
        def my_task(my_value: str):
            # Inject the value the template refers to while the task is running.
            current_context = get_current_context()
            current_context["my_variable"] = my_value * 3  # type: ignore[typeddict-unknown-key]

        my_task.expand(my_value=emit())

    dag_run = dag_maker.create_dagrun()

    # Run the upstream task so the mapped task has values to expand over.
    upstream_ti = dag_run.get_task_instance("emit", session=session)
    upstream_ti.refresh_from_task(dag.get_task("emit"))
    upstream_ti.run()

    mapped_task = dag.get_task("my_task")
    expanded_tis, _ = mapped_task.expand_mapped_task(dag_run.run_id, session=session)

    collected = []
    for mapped_ti in expanded_tis:
        mapped_ti.task = mapped_task
        mapped_ti.refresh_from_task(mapped_task)
        mapped_ti.run()
        collected.append(mapped_ti.rendered_map_index)

    assert sorted(collected) == ["aaa", "bbb"]


@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
@pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"])
Expand Down Expand Up @@ -4247,7 +4314,8 @@ def show(value):

for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")):
ti.refresh_from_task(show_task)
ti.run()
ti.dry_run()

assert outputs == expected_outputs

def test_map_product(self, dag_maker, session):
Expand Down
Loading