From da516044d37f61307e1f693f4488d29cac207592 Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Wed, 17 Apr 2024 14:14:36 +0200 Subject: [PATCH 01/10] Move render map index method and apply to dry run --- airflow/models/taskinstance.py | 21 +++++++++++---------- tests/models/test_taskinstance.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 10 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 8f9d71cfe7f43..9ed5002c863d9 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2731,14 +2731,6 @@ def signal_handler(signum, frame): previous_state=TaskInstanceState.QUEUED, task_instance=self, session=session ) - def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: - """Render named map index if the DAG author defined map_index_template at the task level.""" - if jinja_env is None or (template := context.get("map_index_template")) is None: - return None - rendered_map_index = jinja_env.from_string(template).render(context) - log.debug("Map index rendered as %s", rendered_map_index) - return rendered_map_index - # Execute the task. with set_current_context(context): try: @@ -2746,10 +2738,10 @@ def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) except Exception: # If the task failed, swallow rendering error so it doesn't mask the main error. with contextlib.suppress(jinja2.TemplateSyntaxError, jinja2.UndefinedError): - self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) + self.rendered_map_index = self._render_map_index(context, jinja_env=jinja_env) raise else: # If the task succeeded, render normally to let rendering error bubble up. - self.rendered_map_index = _render_map_index(context, jinja_env=jinja_env) + self.rendered_map_index = self._render_map_index(context, jinja_env=jinja_env) # Run post_execute callback self.task.post_execute(context=context, result=result) @@ -2759,6 +2751,14 @@ def _render_map_index(context: Context, *, jinja_env: jinja2.Environment | None) Stats.incr("operator_successes", tags={**self.stats_tags, "task_type": self.task.task_type}) Stats.incr("ti_successes", tags=self.stats_tags) + def _render_map_index(self, context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: + """Render named map index if the DAG author defined map_index_template at the task level.""" + if jinja_env is None or (template := context.get("map_index_template")) is None: + return None + rendered_map_index = jinja_env.from_string(template).render(context) + self.log.debug("Map index rendered as %s", rendered_map_index) + return rendered_map_index + def _execute_task(self, context: Context, task_orig: Operator): """ Execute Task (optionally with a Timeout) and push Xcom results. @@ -2872,6 +2872,7 @@ def dry_run(self) -> None: if TYPE_CHECKING: assert isinstance(self.task, BaseOperator) self.task.dry_run() + self.rendered_map_index = self._render_map_index(self.get_template_context(), jinja_env=None) @provide_session def _handle_reschedule( diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index e4c9e17b21541..34daefcb01100 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3480,6 +3480,34 @@ def raise_skip_exception(): assert State.SKIPPED == ti.state assert callback_function.called + def test_rendered_map_index_on_dry_run(self, dag_maker, session): + from airflow.operators.python import get_current_context + from airflow.utils.task_instance_session import set_current_task_instance_session + + with dag_maker(): + @task(map_index_template="instance-map-index-{{ variable }}") + def example(param: str): + context = get_current_context() + context["variable"] = param * 3 + + example.expand(param=["a", "b"]) + + dr = dag_maker.create_dagrun() + tis, _ = example.expand_mapped_task( + dr.run_id, session=session + ) + + rendered_map_index = [] + with set_current_task_instance_session(session): + for ti in tis: + ti.task = task + ti.refresh_from_task(example) + ti.dry_run() + + rendered_map_index.append(ti.rendered_map_index) + + assert sorted(rendered_map_index) == ["instance-map-index-aaa", "instance-map-index-bbb"] + @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) @pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"]) From 70e978bec729285a962e70a018055d058bac3181 Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Wed, 17 Apr 2024 14:53:22 +0200 Subject: [PATCH 02/10] fix format --- tests/models/test_taskinstance.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 34daefcb01100..9a464e5b405f3 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3493,9 +3493,7 @@ def example(param: str): example.expand(param=["a", "b"]) dr = dag_maker.create_dagrun() - tis, _ = example.expand_mapped_task( - dr.run_id, session=session - ) + tis, _ = example.expand_mapped_task(dr.run_id, session=session) rendered_map_index = [] with set_current_task_instance_session(session): From 9303eb79cc3e4be55fff2b261869057de297fefc Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Wed, 17 Apr 2024 16:28:22 +0200 Subject: [PATCH 03/10] Update test_taskinstance.py --- tests/models/test_taskinstance.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 9a464e5b405f3..4a060ca2b2143 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3485,6 +3485,7 @@ def test_rendered_map_index_on_dry_run(self, dag_maker, session): from airflow.utils.task_instance_session import set_current_task_instance_session with dag_maker(): + @task(map_index_template="instance-map-index-{{ variable }}") def example(param: str): context = get_current_context() From 8a717c043056b2cbb4572721e639922c6941f12f Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Wed, 17 Apr 2024 17:01:26 +0200 Subject: [PATCH 04/10] Update test_taskinstance.py --- tests/models/test_taskinstance.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 4a060ca2b2143..ef4e30c788c7e 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3494,13 +3494,14 @@ def example(param: str): example.expand(param=["a", "b"]) dr = dag_maker.create_dagrun() - tis, _ = example.expand_mapped_task(dr.run_id, session=session) + example_task = dr.dag.get_task("example") + tis, _ = example_task.expand_mapped_task(dr.run_id, session=session) rendered_map_index = [] with set_current_task_instance_session(session): for ti in tis: - ti.task = task - ti.refresh_from_task(example) + ti.task = example_task + ti.refresh_from_task(example_task) ti.dry_run() rendered_map_index.append(ti.rendered_map_index) From f5ebd5f5924290ae6fd3136af118c858602317a0 Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Fri, 26 Apr 2024 19:00:15 +0200 Subject: [PATCH 05/10] WIP --- airflow/models/taskinstance.py | 25 +++++++++++-- tests/models/test_taskinstance.py | 58 +++++++++++++++++++++++-------- 2 files changed, 66 insertions(+), 17 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 7aa615279e042..feb05d1a58062 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -31,6 +31,7 @@ from contextlib import nullcontext from datetime import timedelta from enum import Enum +from pprint import pprint from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple from urllib.parse import quote @@ -2752,9 +2753,15 @@ def signal_handler(signum, frame): def _render_map_index(self, context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: """Render named map index if the DAG author defined map_index_template at the task level.""" + if jinja_env is None or (template := context.get("map_index_template")) is None: return None - rendered_map_index = jinja_env.from_string(template).render(context) + + pprint(self.task) + print() + pprint(context) + + rendered_map_index = self.task.render_template(template, context, jinja_env) self.log.debug("Map index rendered as %s", rendered_map_index) return rendered_map_index @@ -2867,11 +2874,23 @@ def dry_run(self) -> None: assert self.task self.task = self.task.prepare_for_execution() - self.render_templates() + context = self.get_template_context(ignore_param_exceptions=False) + + jinja_env = None + dag = self.task.get_dag() + if dag is not None: + jinja_env = dag.get_template_env() + if TYPE_CHECKING: assert isinstance(self.task, BaseOperator) + self.render_templates(context=context, jinja_env=jinja_env) self.task.dry_run() - self.rendered_map_index = self._render_map_index(self.get_template_context(), jinja_env=None) + + with set_current_context(context): + self.rendered_map_index = self._render_map_index( + context, + jinja_env=jinja_env, + ) @provide_session def _handle_reschedule( diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 50b96cb56196b..30500fc5004ba 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -93,6 +93,7 @@ from airflow.utils.session import create_session, provide_session from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.task_group import TaskGroup + from airflow.utils.types import DagRunType from airflow.utils.xcom import XCOM_RETURN_KEY from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER @@ -3477,26 +3478,47 @@ def raise_skip_exception(): assert callback_function.called def test_rendered_map_index_on_dry_run(self, dag_maker, session): - from airflow.operators.python import get_current_context from airflow.utils.task_instance_session import set_current_task_instance_session + from airflow.operators.python import get_current_context - with dag_maker(): + with dag_maker(dag_id="example", session=session) as dag: + @dag.task + def emit(): + return ["a", "b"] - @task(map_index_template="instance-map-index-{{ variable }}") - def example(param: str): + @dag.task(map_index_template="instance-map-index-{{ task.my_variable }}") + def example(value: str): context = get_current_context() - context["variable"] = param * 3 + context["my_variable"] = value * 3 - example.expand(param=["a", "b"]) + example.expand(value=emit()) - dr = dag_maker.create_dagrun() - example_task = dr.dag.get_task("example") - tis, _ = example_task.expand_mapped_task(dr.run_id, session=session) + dag_run = dag_maker.create_dagrun() + + # kwargs = { + # "state": State.RUNNING, + # "start_date": pendulum.yesterday("UTC"), + # "session": session, + # "run_type": DagRunType.MANUAL, + # "data_interval": ( + # pendulum.yesterday("UTC"), + # pendulum.now("UTC") + # ), + # "execution_date": pendulum.now("UTC"), + # } + # dag_run = dag.create_dagrun(**kwargs) + emit_ti = dag_run.get_task_instance("emit", session=session) + emit_ti.refresh_from_task(dag.get_task("emit")) + emit_ti.run() + + example_task = dag.get_task("example") + + mapped_tis, _ = example_task.expand_mapped_task(dag_run.run_id, session=session) rendered_map_index = [] + with set_current_task_instance_session(session): - for ti in tis: - ti.task = example_task + for ti in mapped_tis: ti.refresh_from_task(example_task) ti.dry_run() @@ -4141,9 +4163,17 @@ def show(value): mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) assert max_map_index + 1 == len(mapped_tis) == len(upstream_return) - for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): - ti.refresh_from_task(show_task) - ti.run() + # for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): + # ti.refresh_from_task(show_task) + # ti.dry_run() + + from airflow.utils.task_instance_session import set_current_task_instance_session + + with set_current_task_instance_session(session): + for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): + ti.refresh_from_task(show_task) + ti.dry_run() + assert outputs == expected_outputs def test_map_product(self, dag_maker, session): From 773d193d6699ffaeffe6d024058a34347b564f4b Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Fri, 26 Apr 2024 19:28:55 +0200 Subject: [PATCH 06/10] Edited test --- airflow/models/taskinstance.py | 14 ++++---------- tests/models/test_taskinstance.py | 21 ++++----------------- 2 files changed, 8 insertions(+), 27 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index d5533e20b46ad..9e2ab53f99a91 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -31,7 +31,6 @@ from contextlib import nullcontext from datetime import timedelta from enum import Enum -from pprint import pprint from typing import TYPE_CHECKING, Any, Callable, Collection, Generator, Iterable, Mapping, Tuple from urllib.parse import quote @@ -2759,10 +2758,6 @@ def _render_map_index(self, context: Context, *, jinja_env: jinja2.Environment | if jinja_env is None or (template := context.get("map_index_template")) is None: return None - pprint(self.task) - print() - pprint(context) - rendered_map_index = self.task.render_template(template, context, jinja_env) self.log.debug("Map index rendered as %s", rendered_map_index) return rendered_map_index @@ -2888,11 +2883,10 @@ def dry_run(self) -> None: self.render_templates(context=context, jinja_env=jinja_env) self.task.dry_run() - with set_current_context(context): - self.rendered_map_index = self._render_map_index( - context, - jinja_env=jinja_env, - ) + self.rendered_map_index = self._render_map_index( + context, + jinja_env=jinja_env, + ) @provide_session def _handle_reschedule( diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index b0950abaa58b1..a0af9682241ed 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3480,34 +3480,20 @@ def raise_skip_exception(): def test_rendered_map_index_on_dry_run(self, dag_maker, session): from airflow.utils.task_instance_session import set_current_task_instance_session - from airflow.operators.python import get_current_context with dag_maker(dag_id="example", session=session) as dag: @dag.task def emit(): return ["a", "b"] - @dag.task(map_index_template="instance-map-index-{{ task.my_variable }}") + @dag.task(map_index_template="instance-map-index-{{ task.op_kwargs['value'] * 3 }}") def example(value: str): - context = get_current_context() - context["my_variable"] = value * 3 + # Whatever lies here won't be executed as we trigger a `dry_run()` only. + assert value example.expand(value=emit()) dag_run = dag_maker.create_dagrun() - - # kwargs = { - # "state": State.RUNNING, - # "start_date": pendulum.yesterday("UTC"), - # "session": session, - # "run_type": DagRunType.MANUAL, - # "data_interval": ( - # pendulum.yesterday("UTC"), - # pendulum.now("UTC") - # ), - # "execution_date": pendulum.now("UTC"), - # } - # dag_run = dag.create_dagrun(**kwargs) emit_ti = dag_run.get_task_instance("emit", session=session) emit_ti.refresh_from_task(dag.get_task("emit")) emit_ti.run() @@ -3520,6 +3506,7 @@ def example(value: str): with set_current_task_instance_session(session): for ti in mapped_tis: + ti.task = example_task ti.refresh_from_task(example_task) ti.dry_run() From d3b453d9275ed2ff7e9f0dca7e7958bf31730ee8 Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Wed, 15 May 2024 16:05:15 +0200 Subject: [PATCH 07/10] Update deprecations_ignore.yml --- tests/deprecations_ignore.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml index d54ba980f8045..3189507e17b61 100644 --- a/tests/deprecations_ignore.yml +++ b/tests/deprecations_ignore.yml @@ -167,6 +167,7 @@ - tests/models/test_taskinstance.py::TestTaskInstance::test_handle_failure - tests/models/test_taskinstance.py::TestTaskInstance::test_handle_failure_fail_stop - tests/models/test_taskinstance.py::TestTaskInstance::test_outlet_datasets +- tests/models/test_taskinstance.py::TestTaskInstance::test_rendered_map_index_on_dry_run - tests/models/test_taskinstance.py::TestTaskInstance::test_template_with_custom_timetable_deprecated_context - tests/models/test_taskinstance.py::TestTaskInstance::test_xcom_pull - tests/models/test_taskinstance.py::TestTaskInstance::test_xcom_pull_different_execution_date From 486193c21c9a02b6307904cb09810876a40f7717 Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Sun, 19 May 2024 15:29:49 +0200 Subject: [PATCH 08/10] Update test_taskinstance.py --- tests/models/test_taskinstance.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 8ea8cb0445012..fbe742358a224 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -4281,16 +4281,9 @@ def show(value): mapped_tis, max_map_index = show_task.expand_mapped_task(dag_run.run_id, session=session) assert max_map_index + 1 == len(mapped_tis) == len(upstream_return) - # for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): - # ti.refresh_from_task(show_task) - # ti.dry_run() - - from airflow.utils.task_instance_session import set_current_task_instance_session - - with set_current_task_instance_session(session): - for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): - ti.refresh_from_task(show_task) - ti.dry_run() + for ti in sorted(mapped_tis, key=operator.attrgetter("map_index")): + ti.refresh_from_task(show_task) + ti.dry_run() assert outputs == expected_outputs From 4da0c122d61a53dcd89692f588207dcd26fdb95a Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Sun, 19 May 2024 16:11:09 +0200 Subject: [PATCH 09/10] add test for doc --- tests/deprecations_ignore.yml | 1 - tests/models/test_taskinstance.py | 36 ++++++++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml index 5edce292e2d88..d6400074e6567 100644 --- a/tests/deprecations_ignore.yml +++ b/tests/deprecations_ignore.yml @@ -121,7 +121,6 @@ - tests/models/test_taskinstance.py::TestTaskInstance::test_get_previous_start_date_none - tests/models/test_taskinstance.py::TestTaskInstance::test_handle_failure - tests/models/test_taskinstance.py::TestTaskInstance::test_handle_failure_fail_stop -- tests/models/test_taskinstance.py::TestTaskInstance::test_rendered_map_index_on_dry_run - tests/models/test_taskinstance.py::TestTaskInstance::test_template_with_custom_timetable_deprecated_context - tests/models/test_taskinstance.py::TestTaskInstance::test_xcom_pull - tests/models/test_taskinstance.py::TestTaskInstance::test_xcom_pull_different_execution_date diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index fbe742358a224..eac60f0522ab4 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3629,21 +3629,51 @@ def example(value: str): emit_ti.run() example_task = dag.get_task("example") - mapped_tis, _ = example_task.expand_mapped_task(dag_run.run_id, session=session) rendered_map_index = [] - with set_current_task_instance_session(session): for ti in mapped_tis: ti.task = example_task ti.refresh_from_task(example_task) ti.dry_run() - rendered_map_index.append(ti.rendered_map_index) assert sorted(rendered_map_index) == ["instance-map-index-aaa", "instance-map-index-bbb"] + def test_rendered_map_index_example_from_doc(self, dag_maker, session): + + from airflow.operators.python import get_current_context + + with dag_maker(dag_id="example", session=session) as dag: + @dag.task + def emit(): + return ["a", "b"] + + @dag.task(map_index_template="{{ my_variable }}") + def my_task(my_value: str): + context = get_current_context() + context["my_variable"] = my_value * 3 + return context["my_variable"] + + my_task.expand(my_value=emit()) + + dag_run = dag_maker.create_dagrun() + emit_ti = dag_run.get_task_instance("emit", session=session) + emit_ti.refresh_from_task(dag.get_task("emit")) + emit_ti.run() + + my_task = dag.get_task("my_task") + mapped_tis, _ = my_task.expand_mapped_task(dag_run.run_id, session=session) + + rendered_map_index = [] + for ti in mapped_tis: + ti.task = my_task + ti.refresh_from_task(my_task) + ti.run() + rendered_map_index.append(ti.rendered_map_index) + + assert sorted(rendered_map_index) == ["aaa", "bbb"] @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) @pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"]) From 1b84ae5e3283b453ce678b085e07001adc8ded25 Mon Sep 17 00:00:00 2001 From: Aurelien Didier Date: Tue, 21 May 2024 17:23:00 +0200 Subject: [PATCH 10/10] fix static checks --- airflow/models/taskinstance.py | 2 ++ tests/models/test_taskinstance.py | 7 ++++--- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 39615a615a165..0745155cea9a6 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2693,6 +2693,8 @@ def signal_handler(signum, frame): def _render_map_index(self, context: Context, *, jinja_env: jinja2.Environment | None) -> str | None: """Render named map index if the DAG author defined map_index_template at the task level.""" + if TYPE_CHECKING: + assert self.task if jinja_env is None or (template := context.get("map_index_template")) is None: return None diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index eac60f0522ab4..87373ca2864f5 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -3612,6 +3612,7 @@ def test_rendered_map_index_on_dry_run(self, dag_maker, session): from airflow.utils.task_instance_session import set_current_task_instance_session with dag_maker(dag_id="example", session=session) as dag: + @dag.task def emit(): return ["a", "b"] @@ -3642,10 +3643,10 @@ def example(value: str): assert sorted(rendered_map_index) == ["instance-map-index-aaa", "instance-map-index-bbb"] def test_rendered_map_index_example_from_doc(self, dag_maker, session): - from airflow.operators.python import get_current_context with dag_maker(dag_id="example", session=session) as dag: + @dag.task def emit(): return ["a", "b"] @@ -3653,8 +3654,7 @@ def emit(): @dag.task(map_index_template="{{ my_variable }}") def my_task(my_value: str): context = get_current_context() - context["my_variable"] = my_value * 3 - return context["my_variable"] + context["my_variable"] = my_value * 3 # type: ignore[typeddict-unknown-key] my_task.expand(my_value=emit()) @@ -3675,6 +3675,7 @@ def my_task(my_value: str): assert sorted(rendered_map_index) == ["aaa", "bbb"] + @pytest.mark.parametrize("pool_override", [None, "test_pool2"]) @pytest.mark.parametrize("queue_by_policy", [None, "forced_queue"]) def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):