Merged
2 changes: 1 addition & 1 deletion airflow-core/docs/img/airflow_erd.sha256
@@ -1 +1 @@
7d6ced1b0a60a60c192ccc102b750c2d893d1c388625f8a90bb95ce457d0d9c4
5fb3647b8dc66e0c22e13797f2b3b05a9c09ac9e4a6d0d82b03f1556197fa219
1,670 changes: 835 additions & 835 deletions airflow-core/docs/img/airflow_erd.svg
4 changes: 3 additions & 1 deletion airflow-core/docs/migrations-ref.rst
@@ -39,7 +39,9 @@ Here's the list of all the Database Migrations that are executed via when you ru
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| Revision ID | Revises ID | Airflow Version | Description |
+=========================+==================+===================+==============================================================+
| ``53ff648b8a26`` (head) | ``a5a3e5eb9b8d`` | ``3.2.0`` | Add revoked_token table. |
| ``f8c9d7e6b5a4`` (head) | ``53ff648b8a26`` | ``3.2.0`` | Standardize UUID column format for non-PostgreSQL databases. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``53ff648b8a26`` | ``a5a3e5eb9b8d`` | ``3.2.0`` | Add revoked_token table. |
+-------------------------+------------------+-------------------+--------------------------------------------------------------+
| ``a5a3e5eb9b8d`` | ``82dbd68e6171`` | ``3.2.0`` | Make external_executor_id TEXT to allow for longer |
| | | | external_executor_ids. |
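For context on what a migration like ``f8c9d7e6b5a4`` typically involves, here is a minimal Alembic sketch, assuming the goal is to move non-PostgreSQL backends onto SQLAlchemy's native ``Uuid`` type. The revision IDs are taken from the table above, but the body, table name, and column are illustrative, not the actual migration.

    # Illustrative Alembic migration sketch (assumptions noted above).
    from alembic import op
    import sqlalchemy as sa

    revision = "f8c9d7e6b5a4"
    down_revision = "53ff648b8a26"


    def upgrade():
        # PostgreSQL already stores these columns as native UUIDs; other
        # backends are converted to the portable CHAR(36) form of sa.Uuid.
        if op.get_bind().dialect.name != "postgresql":
            with op.batch_alter_table("task_instance") as batch_op:  # hypothetical table
                batch_op.alter_column(
                    "id",
                    type_=sa.Uuid(native_uuid=False),
                    existing_nullable=False,
                )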
1 change: 0 additions & 1 deletion airflow-core/pyproject.toml
@@ -137,7 +137,6 @@ dependencies = [
"setproctitle>=1.3.3",
# SQLAlchemy >=2.0.36 fixes Python 3.13 TypingOnly import AssertionError caused by new typing attributes (__static_attributes__, __firstlineno__)
"sqlalchemy[asyncio]>=2.0.36",
"sqlalchemy-utils>=0.41.2",
"svcs>=25.1.0",
"tabulate>=0.9.0",
"tenacity>=8.3.0",
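Dropping ``sqlalchemy-utils`` is consistent with the models now using SQLAlchemy 2.0's built-in ``Uuid`` type. A minimal sketch of what such a column declaration could look like; the model here is an illustrative stand-in, not Airflow's actual declaration.

    import uuid

    from sqlalchemy import Uuid
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class TaskInstance(Base):  # illustrative stand-in for the real model
        __tablename__ = "task_instance"
        # sa.Uuid renders as a native UUID column on PostgreSQL and CHAR(36)
        # elsewhere, removing the need for sqlalchemy_utils.UUIDType.
        id: Mapped[uuid.UUID] = mapped_column(Uuid, primary_key=True)
        task_id: Mapped[str]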
@@ -19,6 +19,7 @@
from collections.abc import Iterable
from datetime import datetime
from typing import Annotated, Any
from uuid import UUID

from pydantic import (
AliasPath,
@@ -42,7 +43,7 @@
class TaskInstanceResponse(BaseModel):
"""TaskInstance serializer for responses."""

id: str
id: UUID
task_id: str
dag_id: str
run_id: str = Field(alias="dag_run_id")
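Typing ``id`` as ``UUID`` changes validation and serialization, not the wire format: Pydantic v2 still accepts UUID strings as input and emits them as strings in JSON. A small sketch of that behavior, trimmed to the relevant fields:

    import uuid

    from pydantic import BaseModel


    class TaskInstanceResponse(BaseModel):  # trimmed illustration of the change
        id: uuid.UUID
        task_id: str


    ti = TaskInstanceResponse(id="0195d0a0-98a1-7e4e-b1a7-0d4e9fb7b3a1", task_id="hello")
    assert isinstance(ti.id, uuid.UUID)  # valid strings are coerced; malformed ids raise ValidationError
    print(ti.model_dump_json())          # the id serializes back to a plain string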
@@ -2545,6 +2545,7 @@ components:
properties:
id:
type: string
format: uuid
title: Id
task_id:
type: string
@@ -12582,6 +12582,7 @@ components:
properties:
id:
type: string
format: uuid
title: Id
task_id:
type: string
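In the OpenAPI schema this is a validation and documentation hint rather than a type change: the field stays ``type: string``, and ``format: uuid`` tells clients the string can be parsed as a UUID. For example:

    import uuid

    payload = {"id": "0195d0a0-98a1-7e4e-b1a7-0d4e9fb7b3a1", "task_id": "hello"}
    ti_id = uuid.UUID(payload["id"])  # raises ValueError if the id is malformed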
@@ -114,6 +114,6 @@ async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str
.join(DagModel, DagModel.dag_id == TaskInstance.dag_id)
.join(DagBundleModel, DagBundleModel.name == DagModel.bundle_name)
.join(DagBundleModel.teams)
.where(TaskInstance.id == str(token.id))
.where(TaskInstance.id == token.id)
)
return await session.scalar(stmt)
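Dropping the ``str(...)`` cast works because a ``Uuid``-typed column adapts bound ``uuid.UUID`` values to each dialect's storage format. A minimal sketch, reusing the illustrative TaskInstance model from the pyproject note above:

    import uuid

    from sqlalchemy import select

    token_id = uuid.uuid4()
    # The Uuid column type handles the per-dialect conversion of the bound
    # value, so no manual str() round-trip is needed in the query.
    stmt = select(TaskInstance).where(TaskInstance.id == token_id)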
@@ -60,11 +60,10 @@ def upsert_hitl_detail(
This happens when a task instance is cleared after a response has been received.
This design ensures that each task instance has only one HITLDetail.
"""
ti_id_str = str(task_instance_id)
hitl_detail_model = session.scalar(select(HITLDetail).where(HITLDetail.ti_id == ti_id_str))
hitl_detail_model = session.scalar(select(HITLDetail).where(HITLDetail.ti_id == task_instance_id))
if not hitl_detail_model:
hitl_detail_model = HITLDetail(
ti_id=ti_id_str,
ti_id=task_instance_id,
options=payload.options,
subject=payload.subject,
body=payload.body,
@@ -109,15 +108,14 @@ def update_hitl_detail(
session: SessionDep,
) -> HITLDetailResponse:
"""Update the response part of a Human-in-the-loop detail for a specific Task Instance."""
ti_id_str = str(task_instance_id)
hitl_detail_model_result = session.execute(
select(HITLDetail).where(HITLDetail.ti_id == ti_id_str)
select(HITLDetail).where(HITLDetail.ti_id == task_instance_id)
).scalar()
hitl_detail_model = _check_hitl_detail_exists(hitl_detail_model_result)
if hitl_detail_model.response_received:
raise HTTPException(
status.HTTP_409_CONFLICT,
f"Human-in-the-loop detail for Task Instance with id {ti_id_str} already exists.",
f"Human-in-the-loop detail for Task Instance with id {task_instance_id} already exists.",
)

hitl_detail_model.responded_by = None
@@ -138,9 +136,8 @@ def get_hitl_detail(
session: SessionDep,
) -> HITLDetailResponse:
"""Get Human-in-the-loop detail for a specific Task Instance."""
ti_id_str = str(task_instance_id)
hitl_detail_model_result = session.execute(
select(HITLDetail).where(HITLDetail.ti_id == ti_id_str),
select(HITLDetail).where(HITLDetail.ti_id == task_instance_id),
).scalar()
hitl_detail_model = _check_hitl_detail_exists(hitl_detail_model_result)
return HITLDetailResponse.from_hitl_detail_orm(hitl_detail_model)
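Distilled, the upsert above keeps exactly one HITLDetail row per task instance, now keyed directly by the UUID. A sketch of that shape, assuming Airflow's ``HITLDetail`` model is importable; this is illustrative, not the verbatim source:

    import uuid

    from sqlalchemy import select
    from sqlalchemy.orm import Session


    def upsert_detail(session: Session, task_instance_id: uuid.UUID, payload):
        # Look up the single row for this TI, or create it if missing.
        detail = session.scalar(select(HITLDetail).where(HITLDetail.ti_id == task_instance_id))
        if detail is None:
            detail = HITLDetail(
                ti_id=task_instance_id,  # a uuid.UUID, bound without str() casting
                options=payload.options,
                subject=payload.subject,
                body=payload.body,
            )
            session.add(detail)
        return detail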
@@ -109,9 +109,7 @@ def ti_run(

This endpoint is used to start a TaskInstance that is in the QUEUED state.
"""
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)
bind_contextvars(ti_id=ti_id_str)
bind_contextvars(ti_id=str(task_instance_id))
log.debug(
"Starting task instance run",
hostname=ti_run_payload.hostname,
@@ -140,7 +138,7 @@
column("next_kwargs", JSON),
)
.select_from(TI)
.where(TI.id == ti_id_str)
.where(TI.id == task_instance_id)
.with_for_update()
)
try:
@@ -164,7 +162,7 @@
data.pop("start_date")
log.debug("Removed start_date from update as task is resuming from deferral")

query = update(TI).where(TI.id == ti_id_str).values(data)
query = update(TI).where(TI.id == task_instance_id).values(data)

previous_state = ti.state

@@ -244,7 +242,9 @@ def ti_run(

xcom_keys = list(session.scalars(xcom_query))
task_reschedule_count = (
session.scalar(select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == ti_id_str))
session.scalar(
select(func.count(TaskReschedule.id)).where(TaskReschedule.ti_id == task_instance_id)
)
or 0
)

@@ -293,12 +293,14 @@ def ti_update_state(
Not all state transitions are valid, and transitioning to some states requires extra information to be
passed along. (Check out the datamodels for details, the rendered docs might not reflect this accurately)
"""
# We only use UUID above for validation purposes
ti_id_str = str(task_instance_id)
bind_contextvars(ti_id=ti_id_str)
bind_contextvars(ti_id=str(task_instance_id))
log.debug("Updating task instance state", new_state=ti_patch_payload.state)

old = select(TI.state, TI.try_number, TI.max_tries, TI.dag_id).where(TI.id == ti_id_str).with_for_update()
old = (
select(TI.state, TI.try_number, TI.max_tries, TI.dag_id)
.where(TI.id == task_instance_id)
.with_for_update()
)
try:
(
previous_state,
@@ -338,12 +340,12 @@

# We exclude_unset to avoid updating fields that are not set in the payload
data = ti_patch_payload.model_dump(exclude={"task_outlets", "outlet_events"}, exclude_unset=True)
query = update(TI).where(TI.id == ti_id_str).values(data)
query = update(TI).where(TI.id == task_instance_id).values(data)

try:
query, updated_state = _create_ti_state_update_query_and_update_state(
ti_patch_payload=ti_patch_payload,
ti_id_str=ti_id_str,
task_instance_id=task_instance_id,
session=session,
query=query,
dag_id=dag_id,
@@ -355,7 +357,7 @@
"Error updating Task Instance state. Setting the task to failed.",
payload=ti_patch_payload,
)
ti = session.get(TI, ti_id_str, with_for_update=True)
ti = session.get(TI, task_instance_id, with_for_update=True)
if session.bind is not None:
query = TI.duration_expression_update(timezone.utcnow(), query, session.bind)
query = query.values(state=(updated_state := TaskInstanceState.FAILED))
@@ -398,14 +400,14 @@ def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, dag_bag:
def _create_ti_state_update_query_and_update_state(
*,
ti_patch_payload: TIStateUpdate,
ti_id_str: str,
task_instance_id: UUID,
query: Update,
session: SessionDep,
dag_bag: DagBagDep,
dag_id: str,
) -> tuple[Update, TaskInstanceState]:
if isinstance(ti_patch_payload, (TITerminalStatePayload, TIRetryStatePayload, TISuccessStatePayload)):
ti = session.get(TI, ti_id_str, with_for_update=True)
ti = session.get(TI, task_instance_id, with_for_update=True)
updated_state = TaskInstanceState(ti_patch_payload.state.value)
if session.bind is not None:
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
@@ -450,7 +452,7 @@ def _create_ti_state_update_query_and_update_state(
# TODO: HANDLE execution timeout later as it requires a call to the DB
# either get it from the serialised DAG or get it from the API

query = update(TI).where(TI.id == ti_id_str)
query = update(TI).where(TI.id == task_instance_id)

# Store next_kwargs directly (already serialized by worker)
query = query.values(
@@ -477,7 +479,7 @@
mysql_timestamp_max=_MYSQL_TIMESTAMP_MAX,
)
data = ti_patch_payload.model_dump(exclude={"reschedule_date"}, exclude_unset=True)
query = update(TI).where(TI.id == ti_id_str).values(data)
query = update(TI).where(TI.id == task_instance_id).values(data)
if session.bind is not None:
query = TI.duration_expression_update(timezone.utcnow(), query, session.bind)
query = query.values(state=TaskInstanceState.FAILED)
@@ -486,19 +488,19 @@
# in SQLA2. The task is marked as FAILED regardless.
return query, TaskInstanceState.FAILED

# We can directly use ti_id_str instead of fetching the TaskInstance object to avoid SQLA2
# We can directly use task_instance_id instead of fetching the TaskInstance object to avoid SQLA2
# lock contention issues when the TaskInstance row is already locked from before.
actual_start_date = timezone.utcnow()
session.add(
TaskReschedule(
ti_id_str,
task_instance_id,
actual_start_date,
ti_patch_payload.end_date,
ti_patch_payload.reschedule_date,
)
)

query = update(TI).where(TI.id == ti_id_str)
query = update(TI).where(TI.id == task_instance_id)
# calculate the duration for TI table too
if session.bind is not None:
query = TI.duration_expression_update(ti_patch_payload.end_date, query, session.bind)
@@ -524,14 +526,13 @@ def ti_skip_downstream(
ti_patch_payload: TISkippedDownstreamTasksStatePayload,
session: SessionDep,
):
ti_id_str = str(task_instance_id)
bind_contextvars(ti_id=ti_id_str)
bind_contextvars(ti_id=str(task_instance_id))
log.info("Skipping downstream tasks", task_count=len(ti_patch_payload.tasks))

now = timezone.utcnow()
tasks = ti_patch_payload.tasks

query_result = session.execute(select(TI.dag_id, TI.run_id).where(TI.id == ti_id_str))
query_result = session.execute(select(TI.dag_id, TI.run_id).where(TI.id == task_instance_id))
row_result = query_result.fetchone()
if row_result is None:
raise HTTPException(
@@ -572,14 +573,13 @@ def ti_heartbeat(
session: SessionDep,
):
"""Update the heartbeat of a TaskInstance to mark it as alive & still running."""
ti_id_str = str(task_instance_id)
bind_contextvars(ti_id=ti_id_str)
bind_contextvars(ti_id=str(task_instance_id))
log.debug("Processing heartbeat", hostname=ti_payload.hostname, pid=ti_payload.pid)

# Hot path: since heartbeating a task is a very common operation, we try to minimize the number of queries
# and DB round trips as much as possible.

old = select(TI.state, TI.hostname, TI.pid).where(TI.id == ti_id_str).with_for_update()
old = select(TI.state, TI.hostname, TI.pid).where(TI.id == task_instance_id).with_for_update()

try:
(previous_state, hostname, pid) = session.execute(old).one()
@@ -626,7 +626,7 @@ def ti_heartbeat(
)

# Update the last heartbeat time!
session.execute(update(TI).where(TI.id == ti_id_str).values(last_heartbeat_at=timezone.utcnow()))
session.execute(update(TI).where(TI.id == task_instance_id).values(last_heartbeat_at=timezone.utcnow()))
log.debug("Heartbeat updated", state=previous_state)


@@ -651,11 +651,10 @@
session: SessionDep,
):
"""Add an RTIF entry for a task instance, sent by the worker."""
ti_id_str = str(task_instance_id)
bind_contextvars(ti_id=ti_id_str)
bind_contextvars(ti_id=str(task_instance_id))
log.info("Updating RenderedTaskInstanceFields", field_count=len(put_rtif_payload))

task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
task_instance = session.scalar(select(TI).where(TI.id == task_instance_id))
if not task_instance:
log.error("Task Instance not found")
raise HTTPException(
@@ -681,8 +680,7 @@
session: SessionDep,
):
"""Update rendered_map_index for a task instance, sent by the worker during task execution."""
ti_id_str = str(task_instance_id)
bind_contextvars(ti_id=ti_id_str)
bind_contextvars(ti_id=str(task_instance_id))

if not rendered_map_index:
log.error("rendered_map_index cannot be empty")
@@ -693,7 +691,7 @@

log.debug("Updating rendered_map_index", length=len(rendered_map_index))

query = update(TI).where(TI.id == ti_id_str).values(rendered_map_index=rendered_map_index)
query = update(TI).where(TI.id == task_instance_id).values(rendered_map_index=rendered_map_index)
result = session.execute(query)

result = cast("CursorResult[Any]", result)
@@ -720,11 +718,10 @@

The data from this endpoint is used to get values for Task Context.
"""
ti_id_str = str(task_instance_id)
bind_contextvars(ti_id=ti_id_str)
bind_contextvars(ti_id=str(task_instance_id))
log.debug("Retrieving previous successful DAG run")

task_instance = session.scalar(select(TI).where(TI.id == ti_id_str))
task_instance = session.scalar(select(TI).where(TI.id == task_instance_id))
if not task_instance or not task_instance.logical_date:
log.debug("No task instance or logical date found")
return PrevSuccessfulDagRunResponse()
@@ -977,10 +974,9 @@
dag_bag: DagBagDep,
) -> InactiveAssetsResponse:
"""Validate whether there're inactive assets in inlets and outlets of a given task instance."""
ti_id_str = str(task_instance_id)
bind_contextvars(ti_id=ti_id_str)
bind_contextvars(ti_id=str(task_instance_id))

ti = session.scalar(select(TI).where(TI.id == ti_id_str))
ti = session.scalar(select(TI).where(TI.id == task_instance_id))
if not ti:
log.error("Task Instance not found")
raise HTTPException(
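The recurring pattern across the handlers above: structlog's logging context still wants a plain string, while SQLAlchemy comparisons now take the UUID directly. A condensed sketch, with ``TI`` standing in for Airflow's TaskInstance model:

    import uuid

    from sqlalchemy import update
    from structlog.contextvars import bind_contextvars


    def heartbeat(session, task_instance_id: uuid.UUID, now) -> None:
        bind_contextvars(ti_id=str(task_instance_id))  # logging context as text
        # The query itself binds the UUID without a str() cast.
        session.execute(
            update(TI).where(TI.id == task_instance_id).values(last_heartbeat_at=now)
        )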
@@ -39,7 +39,7 @@ def get_start_date(task_instance_id: UUID, session: SessionDep) -> UtcDateTime |
"""Get the first reschedule date if found, None if no records exist."""
start_date = session.scalar(
select(TaskReschedule.start_date)
.where(TaskReschedule.ti_id == str(task_instance_id))
.where(TaskReschedule.ti_id == task_instance_id)
.order_by(TaskReschedule.id.asc())
.limit(1)
)