3 changes: 1 addition & 2 deletions airflow-core/src/airflow/cli/commands/task_command.py
@@ -54,7 +54,6 @@
 from airflow.utils.providers_configuration_loader import providers_configuration_loaded
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.state import DagRunState, State
-from airflow.utils.task_instance_session import set_current_task_instance_session
 from airflow.utils.types import DagRunTriggeredByType, DagRunType

 if TYPE_CHECKING:
@@ -441,7 +440,7 @@ def task_render(args, dag: DAG | None = None) -> None:
         create_if_necessary="memory",
     )

-    with create_session() as session, set_current_task_instance_session(session=session):
+    with create_session() as session:
         context = ti.get_template_context(session=session)
         task = dag.get_task(args.task_id)
         # TODO (GH-52141): After sdk separation, ti.get_template_context() would
8 changes: 3 additions & 5 deletions airflow-core/src/airflow/utils/db.py
@@ -60,7 +60,6 @@
 from airflow.utils.db_manager import RunDBManager
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import get_dialect_name
-from airflow.utils.task_instance_session import get_current_task_instance_session

 USE_PSYCOPG3: bool
 try:
@@ -1546,7 +1545,7 @@ class LazySelectSequence(Sequence[T]):

     _select_asc: Select
     _select_desc: Select
-    _session: Session = attrs.field(kw_only=True, factory=get_current_task_instance_session)
+    _session: Session
     _len: int | None = attrs.field(init=False, default=None)

     @classmethod
@@ -1555,15 +1554,15 @@ def from_select(
         select: Select,
         *,
         order_by: Sequence[ColumnElement],
-        session: Session | None = None,
+        session: Session,
     ) -> Self:
         s1 = select
         for col in order_by:
             s1 = s1.order_by(col.asc())
         s2 = select
         for col in order_by:
             s2 = s2.order_by(col.desc())
-        return cls(s1, s2, session=session or get_current_task_instance_session())
+        return cls(s1, s2, session=session)

     @staticmethod
     def _rebuild_select(stmt: TextClause) -> Select:
@@ -1603,7 +1602,6 @@ def __setstate__(self, state: Any) -> None:
         s1, s2, self._len = state
         self._select_asc = self._rebuild_select(text(s1))
         self._select_desc = self._rebuild_select(text(s2))
-        self._session = get_current_task_instance_session()

     def __bool__(self) -> bool:
         return check_query_exists(self._select_asc, session=self._session)
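With the implicit fallback removed, every caller of LazySelectSequence.from_select has to pass in the session the lazy sequence should be bound to. A minimal sketch of the new calling convention (hypothetical query; real usage goes through Airflow's concrete lazy-sequence subclasses, which also supply the row-processing hooks):

from sqlalchemy import select

from airflow.models.taskinstance import TaskInstance
from airflow.utils.db import LazySelectSequence
from airflow.utils.session import create_session

with create_session() as session:
    # Hypothetical query: collect the task_ids of one DAG lazily.
    stmt = select(TaskInstance.task_id).where(TaskInstance.dag_id == "example_dag")
    # session is now a required keyword argument; there is no implicit
    # fallback to the removed get_current_task_instance_session().
    seq = LazySelectSequence.from_select(
        stmt,
        order_by=[TaskInstance.task_id],
        session=session,
    )
    # Length checks and iteration now run against the explicitly bound session.
    print(len(seq))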
60 changes: 0 additions & 60 deletions airflow-core/src/airflow/utils/task_instance_session.py

This file was deleted.
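For readers without the old file handy: the deleted module held a process-wide registry of "the session belonging to the task instance currently being processed", which code paths like LazySelectSequence (above) and the template-rendering tests (below) relied on implicitly. A rough, from-memory sketch of its shape (an approximation for orientation, not the exact 60 deleted lines):

# Approximate shape of the removed airflow/utils/task_instance_session.py.
from __future__ import annotations

import contextlib

from sqlalchemy.orm import Session

__current_task_instance_session: Session | None = None


def get_current_task_instance_session() -> Session:
    # Return the registered session; fall back to a fresh one if nothing was registered.
    global __current_task_instance_session
    if __current_task_instance_session is None:
        from airflow import settings

        __current_task_instance_session = settings.Session()
    return __current_task_instance_session


@contextlib.contextmanager
def set_current_task_instance_session(session: Session):
    # Register *session* for the duration of the block, then clear it again.
    global __current_task_instance_session
    __current_task_instance_session = session
    try:
        yield
    finally:
        __current_task_instance_session = None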

89 changes: 41 additions & 48 deletions airflow-core/tests/unit/models/test_renderedtifields.py
@@ -41,7 +41,6 @@
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk import task as task_decorator
 from airflow.utils.state import TaskInstanceState
-from airflow.utils.task_instance_session import set_current_task_instance_session

 from tests_common.test_utils.asserts import assert_queries_count
 from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields
@@ -250,32 +249,29 @@ def test_delete_old_records(
         Test that old records are deleted from rendered_task_instance_fields table
         for a given task_id and dag_id.
         """
-        with set_current_task_instance_session(session=session):
-            with dag_maker("test_delete_old_records") as dag:
-                task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
-            rtif_list = []
-            for num in range(rtif_num):
-                dr = dag_maker.create_dagrun(
-                    run_id=str(num), logical_date=dag.start_date + timedelta(days=num)
-                )
-                ti = dr.task_instances[0]
-                ti.task = task
-                rtif_list.append(RTIF(ti))
+        with dag_maker("test_delete_old_records") as dag:
+            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
+        rtif_list = []
+        for num in range(rtif_num):
+            dr = dag_maker.create_dagrun(run_id=str(num), logical_date=dag.start_date + timedelta(days=num))
+            ti = dr.task_instances[0]
+            ti.task = task
+            rtif_list.append(RTIF(ti))

-            session.add_all(rtif_list)
-            session.flush()
+        session.add_all(rtif_list)
+        session.flush()

-            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
+        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

-            for rtif in rtif_list:
-                assert rtif in result
+        for rtif in rtif_list:
+            assert rtif in result

-            assert rtif_num == len(result)
+        assert rtif_num == len(result)

-            with assert_queries_count(expected_query_count):
-                RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
-            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
-            assert remaining_rtifs == len(result)
+        with assert_queries_count(expected_query_count):
+            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
+        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
+        assert remaining_rtifs == len(result)

     @pytest.mark.parametrize(
         ("num_runs", "num_to_keep", "remaining_rtifs", "expected_query_count"),
Expand All @@ -292,35 +288,32 @@ def test_delete_old_records_mapped(
Test that old records are deleted from rendered_task_instance_fields table
for a given task_id and dag_id with mapped tasks.
"""
with set_current_task_instance_session(session=session):
with dag_maker("test_delete_old_records", session=session, serialized=True) as dag:
mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])
for num in range(num_runs):
dr = dag_maker.create_dagrun(
run_id=f"run_{num}", logical_date=dag.start_date + timedelta(days=num)
)
with dag_maker("test_delete_old_records", session=session, serialized=True) as dag:
mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])
for num in range(num_runs):
dr = dag_maker.create_dagrun(
run_id=f"run_{num}", logical_date=dag.start_date + timedelta(days=num)
)

TaskMap.expand_mapped_task(
dag.task_dict[mapped.task_id], dr.run_id, session=dag_maker.session
)
session.refresh(dr)
for ti in dr.task_instances:
ti.task = mapped
session.add(RTIF(ti))
session.flush()
TaskMap.expand_mapped_task(dag.task_dict[mapped.task_id], dr.run_id, session=dag_maker.session)
session.refresh(dr)
for ti in dr.task_instances:
ti.task = mapped
session.add(RTIF(ti))
session.flush()

result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
assert len(result) == num_runs * 2
result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
assert len(result) == num_runs * 2

with assert_queries_count(expected_query_count):
RTIF.delete_old_records(
task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session
)
result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all()
rtif_num_runs = Counter(rtif.run_id for rtif in result)
assert len(rtif_num_runs) == remaining_rtifs
# Check that we have _all_ the data for each row
assert len(result) == remaining_rtifs * 2
with assert_queries_count(expected_query_count):
RTIF.delete_old_records(
task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session
)
result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all()
rtif_num_runs = Counter(rtif.run_id for rtif in result)
assert len(rtif_num_runs) == remaining_rtifs
# Check that we have _all_ the data for each row
assert len(result) == remaining_rtifs * 2

def test_write(self, dag_maker):
"""
40 changes: 39 additions & 1 deletion providers/standard/tests/unit/standard/decorators/test_python.py
@@ -25,7 +25,6 @@
 from airflow.exceptions import AirflowException, XComNotFound
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
-from airflow.utils.task_instance_session import set_current_task_instance_session

 from tests_common.test_utils.version_compat import (
     AIRFLOW_V_3_0_1,
@@ -818,10 +817,49 @@ def task2(arg1, arg2): ...
     assert set(unmapped.op_kwargs) == {"arg1", "arg2"}


+@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
 def test_mapped_render_template_fields(dag_maker, session):
     @task_decorator
     def fn(arg1, arg2): ...

+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = fn.partial(arg2="{{ ti.task_id }}").expand(arg1=task1.output)
+
+    dr = dag_maker.create_dagrun()
+    ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+    ti.xcom_push(key=XCOM_RETURN_KEY, value=["{{ ds }}"], session=session)
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=1,
+            keys=None,
+        )
+    )
+    session.flush()
+
+    mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
+    mapped_ti.map_index = 0
+    assert isinstance(mapped_ti.task, MappedOperator)
+    mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
+    assert isinstance(mapped_ti.task, BaseOperator)
+
+    assert mapped_ti.task.op_kwargs["arg1"] == "{{ ds }}"
+    assert mapped_ti.task.op_kwargs["arg2"] == "fn"
+
+
+@pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Different test for AF 2")
+def test_mapped_render_template_fields_af2(dag_maker, session):
+    from airflow.utils.task_instance_session import set_current_task_instance_session
+
+    @task_decorator
+    def fn(arg1, arg2): ...
+
     with set_current_task_instance_session(session=session):
         with dag_maker(session=session):
             task1 = BaseOperator(task_id="op1")
16 changes: 6 additions & 10 deletions providers/weaviate/tests/unit/weaviate/operators/test_weaviate.py
@@ -20,8 +20,6 @@

 import pytest

-from airflow.utils.task_instance_session import set_current_task_instance_session
-
 pytest.importorskip("weaviate")

 from airflow.providers.weaviate.operators.weaviate import (
@@ -87,10 +85,9 @@ def test_partial_batch_hook_params(self, dag_maker, session):

         dr = dag_maker.create_dagrun()
         tis = dr.get_task_instances(session=session)
-        with set_current_task_instance_session(session=session):
-            for ti in tis:
-                ti.render_templates()
-                assert ti.task.hook_params == {"baz": "biz"}
+        for ti in tis:
+            ti.render_templates()
+            assert ti.task.hook_params == {"baz": "biz"}


 class TestWeaviateDocumentIngestOperator:
@@ -147,7 +144,6 @@ def test_partial_hook_params(self, dag_maker, session):

         dr = dag_maker.create_dagrun()
         tis = dr.get_task_instances(session=session)
-        with set_current_task_instance_session(session=session):
-            for ti in tis:
-                ti.render_templates()
-                assert ti.task.hook_params == {"baz": "biz"}
+        for ti in tis:
+            ti.render_templates()
+            assert ti.task.hook_params == {"baz": "biz"}
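The same mechanical change repeats across the touched test suites: drop the wrapping set_current_task_instance_session(...) block and render against the explicit session fixture. A generic sketch of the resulting pattern (hypothetical test name and DAG id; pytest fixtures dag_maker and session assumed):

def test_render_templates_with_explicit_session(dag_maker, session):
    with dag_maker("example_dag", session=session):
        ...  # define operators here

    dr = dag_maker.create_dagrun()
    for ti in dr.get_task_instances(session=session):
        # No surrounding set_current_task_instance_session(session=session) any more.
        ti.render_templates()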