diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py
index 7b6f3595a54ce..da086e88d7505 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -54,7 +54,6 @@
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.state import DagRunState, State
-from airflow.utils.task_instance_session import set_current_task_instance_session
 from airflow.utils.types import DagRunTriggeredByType, DagRunType

 if TYPE_CHECKING:
@@ -441,7 +440,7 @@ def task_render(args, dag: DAG | None = None) -> None:
         create_if_necessary="memory",
     )

-    with create_session() as session, set_current_task_instance_session(session=session):
+    with create_session() as session:
         context = ti.get_template_context(session=session)
         task = dag.get_task(args.task_id)
         # TODO (GH-52141): After sdk separation, ti.get_template_context() would
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index 748b44acb865e..231a74966c868 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -60,7 +60,6 @@
 from airflow.utils.db_manager import RunDBManager
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import get_dialect_name
-from airflow.utils.task_instance_session import get_current_task_instance_session

 USE_PSYCOPG3: bool
 try:
@@ -1546,7 +1545,7 @@ class LazySelectSequence(Sequence[T]):

     _select_asc: Select
     _select_desc: Select
-    _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session)
+    _session: Session
     _len: int | None = attrs.field(init=False, default=None)

     @classmethod
@@ -1555,7 +1554,7 @@ def from_select(
         select: Select,
         *,
         order_by: Sequence[ColumnElement],
-        session: Session | None = None,
+        session: Session,
     ) -> Self:
         s1 = select
         for col in order_by:
@@ -1563,7 +1562,7 @@
         s2 = select
         for col in order_by:
             s2 = s2.order_by(col.desc())
-        return cls(s1, s2, session=session or get_current_task_instance_session())
+        return cls(s1, s2, session=session)

     @staticmethod
     def _rebuild_select(stmt: TextClause) -> Select:
@@ -1603,7 +1602,6 @@ def __setstate__(self, state: Any) -> None:
         s1, s2, self._len = state
         self._select_asc = self._rebuild_select(text(s1))
         self._select_desc = self._rebuild_select(text(s2))
-        self._session = get_current_task_instance_session()

     def __bool__(self) -> bool:
         return check_query_exists(self._select_asc, session=self._session)
diff --git a/airflow-core/src/airflow/utils/task_instance_session.py b/airflow-core/src/airflow/utils/task_instance_session.py
deleted file mode 100644
index 019a752c773c1..0000000000000
--- a/airflow-core/src/airflow/utils/task_instance_session.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-from __future__ import annotations
-
-import contextlib
-import logging
-import traceback
-from typing import TYPE_CHECKING
-
-from airflow import settings
-
-if TYPE_CHECKING:
-    from sqlalchemy.orm import Session
-
-__current_task_instance_session: Session | None = None
-
-log = logging.getLogger(__name__)
-
-
-def get_current_task_instance_session() -> Session:
-    global __current_task_instance_session
-    if not __current_task_instance_session:
-        log.warning("No task session set for this task. Continuing but this likely causes a resource leak.")
-        log.warning("Please report this and stacktrace below to https://github.com/apache/airflow/issues")
-        for filename, line_number, name, line in traceback.extract_stack():
-            log.warning('File: "%s", %s , in %s', filename, line_number, name)
-            if line:
-                log.warning("  %s", line.strip())
-        __current_task_instance_session = settings.get_session()()
-    return __current_task_instance_session
-
-
-@contextlib.contextmanager
-def set_current_task_instance_session(session: Session):
-    global __current_task_instance_session
-    if __current_task_instance_session:
-        raise RuntimeError(
-            "Session already set for this task. "
-            "You can only have one 'set_current_task_session' context manager active at a time."
-        )
-    __current_task_instance_session = session
-    try:
-        yield
-    finally:
-        __current_task_instance_session = None
diff --git a/airflow-core/tests/unit/models/test_renderedtifields.py b/airflow-core/tests/unit/models/test_renderedtifields.py
index 8083459b6be85..f31ed7722c019 100644
--- a/airflow-core/tests/unit/models/test_renderedtifields.py
+++ b/airflow-core/tests/unit/models/test_renderedtifields.py
@@ -41,7 +41,6 @@
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk import task as task_decorator
 from airflow.utils.state import TaskInstanceState
-from airflow.utils.task_instance_session import set_current_task_instance_session

 from tests_common.test_utils.asserts import assert_queries_count
 from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields
@@ -250,32 +249,29 @@ def test_delete_old_records(
         Test that old records are deleted from rendered_task_instance_fields table
         for a given task_id and dag_id.
         """
-        with set_current_task_instance_session(session=session):
-            with dag_maker("test_delete_old_records") as dag:
-                task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
-            rtif_list = []
-            for num in range(rtif_num):
-                dr = dag_maker.create_dagrun(
-                    run_id=str(num), logical_date=dag.start_date + timedelta(days=num)
-                )
-                ti = dr.task_instances[0]
-                ti.task = task
-                rtif_list.append(RTIF(ti))
+        with dag_maker("test_delete_old_records") as dag:
+            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
+        rtif_list = []
+        for num in range(rtif_num):
+            dr = dag_maker.create_dagrun(run_id=str(num), logical_date=dag.start_date + timedelta(days=num))
+            ti = dr.task_instances[0]
+            ti.task = task
+            rtif_list.append(RTIF(ti))

-            session.add_all(rtif_list)
-            session.flush()
+        session.add_all(rtif_list)
+        session.flush()

-            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
+        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

-            for rtif in rtif_list:
-                assert rtif in result
+        for rtif in rtif_list:
+            assert rtif in result

-            assert rtif_num == len(result)
+        assert rtif_num == len(result)

-            with assert_queries_count(expected_query_count):
-                RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
-            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
-            assert remaining_rtifs == len(result)
+        with assert_queries_count(expected_query_count):
+            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
+        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
+        assert remaining_rtifs == len(result)

     @pytest.mark.parametrize(
         ("num_runs", "num_to_keep", "remaining_rtifs", "expected_query_count"),
@@ -292,35 +288,32 @@ def test_delete_old_records_mapped(
         Test that old records are deleted from rendered_task_instance_fields table
         for a given task_id and dag_id with mapped tasks.
         """
-        with set_current_task_instance_session(session=session):
-            with dag_maker("test_delete_old_records", session=session, serialized=True) as dag:
-                mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])
-            for num in range(num_runs):
-                dr = dag_maker.create_dagrun(
-                    run_id=f"run_{num}", logical_date=dag.start_date + timedelta(days=num)
-                )
+        with dag_maker("test_delete_old_records", session=session, serialized=True) as dag:
+            mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])
+        for num in range(num_runs):
+            dr = dag_maker.create_dagrun(
+                run_id=f"run_{num}", logical_date=dag.start_date + timedelta(days=num)
+            )

-                TaskMap.expand_mapped_task(
-                    dag.task_dict[mapped.task_id], dr.run_id, session=dag_maker.session
-                )
-                session.refresh(dr)
-                for ti in dr.task_instances:
-                    ti.task = mapped
-                    session.add(RTIF(ti))
-                session.flush()
+            TaskMap.expand_mapped_task(dag.task_dict[mapped.task_id], dr.run_id, session=dag_maker.session)
+            session.refresh(dr)
+            for ti in dr.task_instances:
+                ti.task = mapped
+                session.add(RTIF(ti))
+            session.flush()

-            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
-            assert len(result) == num_runs * 2
+        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
+        assert len(result) == num_runs * 2

-            with assert_queries_count(expected_query_count):
-                RTIF.delete_old_records(
-                    task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session
-                )
-            result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all()
-            rtif_num_runs = Counter(rtif.run_id for rtif in result)
-            assert len(rtif_num_runs) == remaining_rtifs
-            # Check that we have _all_ the data for each row
-            assert len(result) == remaining_rtifs * 2
+        with assert_queries_count(expected_query_count):
+            RTIF.delete_old_records(
+                task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session
+            )
+        result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all()
+        rtif_num_runs = Counter(rtif.run_id for rtif in result)
+        assert len(rtif_num_runs) == remaining_rtifs
+        # Check that we have _all_ the data for each row
+        assert len(result) == remaining_rtifs * 2

     def test_write(self, dag_maker):
         """
diff --git a/providers/standard/tests/unit/standard/decorators/test_python.py b/providers/standard/tests/unit/standard/decorators/test_python.py
index 391d4ae7f4d0a..cfc5df5ab2ae2 100644
--- a/providers/standard/tests/unit/standard/decorators/test_python.py
+++ b/providers/standard/tests/unit/standard/decorators/test_python.py
@@ -25,7 +25,6 @@
 from airflow.exceptions import AirflowException, XComNotFound
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
-from airflow.utils.task_instance_session import set_current_task_instance_session

 from tests_common.test_utils.version_compat import (
     AIRFLOW_V_3_0_1,
@@ -818,10 +817,49 @@ def task2(arg1, arg2): ...
     assert set(unmapped.op_kwargs) == {"arg1", "arg2"}


+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
 def test_mapped_render_template_fields(dag_maker, session):
     @task_decorator
     def fn(arg1, arg2): ...

+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output)
+
+    dr = dag_maker.create_dagrun()
+    ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+    ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session)
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=1,
+            keys=None,
+        )
+    )
+    session.flush()
+
+    mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
+    mapped_ti.map_index = 0
+    assert isinstance(mapped_ti.task, MappedOperator)
+    mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
+    assert isinstance(mapped_ti.task, BaseOperator)
+
+    assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
+    assert mapped_ti.task.op_kwargs["arg2"] == "fn"
+
+
+@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
+def test_mapped_render_template_fields_af2(dag_maker, session):
+    from airflow.utils.task_instance_session import set_current_task_instance_session
+
+    @task_decorator
+    def fn(arg1, arg2): ...
+
     with set_current_task_instance_session(session=session):
         with dag_maker(session=session):
             task1 = BaseOperator(task_id="op1")
diff --git a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
index aa3be7a3cd8f3..28b5164c626f2 100644
--- a/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
+++ b/providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
@@ -20,8 +20,6 @@

 import pytest

-from airflow.utils.task_instance_session import set_current_task_instance_session
-
 pytest.importorskip("weaviate")

 from airflow.providers.weaviate.operators.weaviate import (
@@ -87,10 +85,9 @@ def test_partial_batch_hook_params(self, dag_maker, session):

         dr = dag_maker.create_dagrun()
         tis = dr.get_task_instances(session=session)
-        with set_current_task_instance_session(session=session):
-            for ti in tis:
-                ti.render_templates()
-                assert ti.task.hook_params == {"baz": "biz"}
+        for ti in tis:
+            ti.render_templates()
+            assert ti.task.hook_params == {"baz": "biz"}


 class TestWeaviateDocumentIngestOperator:
@@ -147,7 +144,6 @@ def test_partial_hook_params(self, dag_maker, session):

         dr = dag_maker.create_dagrun()
         tis = dr.get_task_instances(session=session)
-        with set_current_task_instance_session(session=session):
-            for ti in tis:
-                ti.render_templates()
-                assert ti.task.hook_params == {"baz": "biz"}
+        for ti in tis:
+            ti.render_templates()
+            assert ti.task.hook_params == {"baz": "biz"}
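Usage after this change (an illustrative sketch, not part of the diff): `LazySelectSequence` no longer falls back to a thread-local session, so every caller must thread one through explicitly. The query below is invented for demonstration; only `LazySelectSequence.from_select` and its now-required `session` keyword come from the diff above.

```python
from sqlalchemy import select

from airflow.models.taskinstance import TaskInstance
from airflow.utils.db import LazySelectSequence
from airflow.utils.session import create_session

with create_session() as session:
    # `session` is now a required keyword argument; the implicit
    # get_current_task_instance_session() fallback no longer exists.
    seq = LazySelectSequence.from_select(
        select(TaskInstance.task_id),  # hypothetical query, for illustration only
        order_by=[TaskInstance.task_id],
        session=session,
    )
    print(len(seq))  # evaluated lazily against the session passed above
```

Note that `__setstate__` also stops restoring a session, so a deserialized `LazySelectSequence` presumably has its session supplied by the surrounding code rather than a global.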