@@ -65,7 +65,9 @@
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun as DR
from airflow.models.expandinput import NotFullyPopulated
from airflow.models.log import Log
from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
@@ -91,6 +93,30 @@
log = structlog.get_logger(__name__)


def _add_log(
session: SessionDep,
event: TaskInstanceState,
dag_id: str,
task_id: str,
run_id: str,
map_index: int,
try_number: int,
) -> None:
"""Add task instance state change to the audit log."""
session.add(
Log(
event=event,
task_instance=TaskInstanceKey(
dag_id=dag_id,
task_id=task_id,
run_id=run_id,
map_index=map_index,
try_number=try_number,
),
)
)


@ti_id_router.patch(
"/{task_instance_id}/run",
status_code=status.HTTP_200_OK,
@@ -212,6 +238,15 @@ def ti_run(

try:
result = session.execute(query)
_add_log(
session=session,
event=TaskInstanceState.RUNNING,
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
map_index=ti.map_index,
try_number=ti.try_number,
)
log.info("Task instance state updated", rows_affected=getattr(result, "rowcount", 0))

dr = (
@@ -354,13 +389,20 @@ def ti_update_state(
bind_contextvars(ti_id=ti_id_str)
log.debug("Updating task instance state", new_state=ti_patch_payload.state)

old = select(TI.state, TI.try_number, TI.max_tries, TI.dag_id).where(TI.id == ti_id_str).with_for_update()
old = (
select(TI.state, TI.try_number, TI.max_tries, TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
.where(TI.id == ti_id_str)
.with_for_update()
)
try:
(
previous_state,
try_number,
max_tries,
dag_id,
task_id,
run_id,
map_index,
) = session.execute(old).one()
log.debug(
"Retrieved current task instance state",
@@ -422,11 +464,21 @@ def ti_update_state(
# https://fastapi.tiangolo.com/tutorial/handling-errors/#install-custom-exception-handlers
try:
result = session.execute(query)
_add_log(
session=session,
event=updated_state,
dag_id=dag_id,
task_id=task_id,
run_id=run_id,
map_index=map_index,
try_number=try_number,
)
log.info(
"Task instance state updated",
new_state=updated_state,
rows_affected=getattr(result, "rowcount", 0),
)

except SQLAlchemyError as e:
log.error("Error updating Task Instance state", error=str(e))
raise HTTPException(
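A note on the _add_log helper above: it relies on Log's constructor copying the identifying fields off the TaskInstanceKey passed as task_instance. Below is a minimal sketch of that assumption, with a hypothetical make_audit_row spelling out the equivalent explicit construction; the column names match the queries used in the tests later in this diff, and the keyword-argument form of Log is itself an assumption, not confirmed by the diff.

from airflow.models.log import Log
from airflow.models.taskinstancekey import TaskInstanceKey

def make_audit_row(event: str, key: TaskInstanceKey) -> Log:
    # Hypothetical equivalent of Log(event=event, task_instance=key),
    # assuming Log also accepts its columns as keyword arguments.
    return Log(
        event=event,
        dag_id=key.dag_id,
        task_id=key.task_id,
        run_id=key.run_id,
        map_index=key.map_index,
        try_number=key.try_number,
    )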
airflow-core/src/airflow/jobs/scheduler_job_runner.py (2 additions, 0 deletions)
@@ -779,6 +779,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
)
.execution_options(synchronize_session=False)
)
# Add queued-state events to the audit log
session.add_all([Log(event=TaskInstanceState.QUEUED, task_instance=ti) for ti in executable_tis])

for ti in executable_tis:
ti.emit_state_change_metric(TaskInstanceState.QUEUED)
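The scheduler-side insert above stages all QUEUED audit rows on the scheduler's existing session, so they commit, or roll back, together with the queueing itself. A minimal sketch of that pattern in plain SQLAlchemy; the session and the make_row callable are stand-ins, not Airflow APIs:

from collections.abc import Callable, Iterable
from sqlalchemy.orm import Session

def audit_in_same_txn(session: Session, items: Iterable, make_row: Callable) -> None:
    # Stage one audit row per item in the caller's open transaction.
    # Nothing hits the database until the caller flushes or commits,
    # so the audit rows land atomically with the state changes.
    session.add_all([make_row(item) for item in items])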
@@ -34,6 +34,7 @@
from airflow.exceptions import AirflowSkipException
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel
from airflow.models.log import Log
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.providers.standard.operators.empty import EmptyOperator
@@ -44,6 +45,7 @@
from tests_common.test_utils.db import (
clear_db_assets,
clear_db_dags,
clear_db_logs,
clear_db_runs,
clear_db_serialized_dags,
clear_rendered_ti_fields,
@@ -130,11 +132,13 @@ def setup_method(self):
clear_db_runs()
clear_db_serialized_dags()
clear_db_dags()
clear_db_logs()

def teardown_method(self):
clear_db_runs()
clear_db_serialized_dags()
clear_db_dags()
clear_db_logs()

@pytest.mark.parametrize(
("max_tries", "should_retry"),
@@ -848,15 +852,67 @@ def test_ti_run_with_triggering_user_name(
assert dag_run["run_id"] == "test"
assert dag_run["state"] == "running"

def test_ti_run_creates_audit_log_entry(
self,
client,
session,
create_task_instance,
time_machine,
):
"""Test that calling ti_run creates an audit log entry for the task instance RUNNING state."""
instant_str = "2026-01-16T12:00:00Z"
instant = timezone.parse(instant_str)
time_machine.move_to(instant, tick=False)

ti = create_task_instance(
task_id="test_ti_run_creates_audit_log_running",
state=State.QUEUED,
dagrun_state=DagRunState.RUNNING,
session=session,
start_date=instant,
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/run",
json={
"state": "running",
"hostname": "test-hostname",
"unixname": "test-user",
"pid": 100,
"start_date": instant_str,
},
)

assert response.status_code == 200

audit_log_entry = session.scalar(
select(Log).where(
Log.dag_id == ti.dag_id,
Log.task_id == ti.task_id,
Log.run_id == ti.run_id,
Log.event == TaskInstanceState.RUNNING.value,
)
)

assert audit_log_entry is not None, "Audit log entry should be inserted for RUNNING state"
assert audit_log_entry.dag_id == ti.dag_id
assert audit_log_entry.task_id == ti.task_id
assert audit_log_entry.run_id == ti.run_id
assert audit_log_entry.map_index == ti.map_index
assert audit_log_entry.try_number == ti.try_number


class TestTIUpdateState:
def setup_method(self):
clear_db_assets()
clear_db_runs()
clear_db_logs()

def teardown_method(self):
clear_db_assets()
clear_db_runs()
clear_db_logs()

@pytest.mark.parametrize(
("state", "end_date", "expected_state"),
@@ -1118,8 +1174,12 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta
mock.patch(
"airflow.api_fastapi.common.db.common.Session.execute",
side_effect=[
mock.Mock(one=lambda: ("running", 1, 0, "dag")), # First call returns "queued"
mock.Mock(one=lambda: ("running", 1, 0, "dag")), # Second call returns "queued"
mock.Mock(
one=lambda: ("running", 1, 0, "dag", "task_id", "run_id", -1)
), # First call returns "running"
mock.Mock(
one=lambda: ("running", 1, 0, "dag", "task_id", "run_id", -1)
), # Second call returns "running"
SQLAlchemyError("Database error"), # Last call raises an error
],
),
@@ -1516,6 +1576,98 @@ def test_ti_update_state_to_failed_with_fail_fast(self, client, session, dag_mak
ti1 = session.get(TaskInstance, ti1.id)
assert ti1.state == State.FAILED

@pytest.mark.parametrize(
"terminal_state",
[
pytest.param(State.SUCCESS, id=State.SUCCESS),
pytest.param(State.FAILED, id=State.FAILED),
pytest.param(State.SKIPPED, id=State.SKIPPED),
],
)
def test_ti_update_state_creates_audit_log_for_terminal_states(
self,
client,
session,
create_task_instance,
terminal_state,
):
"""Test that calling ti_update_state creates an audit log entry for the terminal state."""
ti = create_task_instance(
task_id=f"test_ti_update_state_creates_audit_log_{terminal_state}",
state=State.RUNNING,
start_date=DEFAULT_START_DATE,
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": terminal_state,
"end_date": DEFAULT_END_DATE.isoformat(),
},
)

assert response.status_code == 204

audit_log_entry = session.scalar(
select(Log).where(
Log.dag_id == ti.dag_id,
Log.task_id == ti.task_id,
Log.run_id == ti.run_id,
Log.event == terminal_state,
)
)

assert audit_log_entry is not None, f"Audit log entry should be inserted for {terminal_state} state"
assert audit_log_entry.dag_id == ti.dag_id
assert audit_log_entry.task_id == ti.task_id
assert audit_log_entry.run_id == ti.run_id
assert audit_log_entry.map_index == ti.map_index
assert audit_log_entry.try_number == ti.try_number

def test_ti_update_state_creates_audit_log_for_deferred_state(
self,
client,
session,
create_task_instance,
):
"""Test that calling ti_update_state for deferred state creates an audit log entry."""
ti = create_task_instance(
task_id="test_ti_update_state_creates_audit_log_deferred",
state=State.RUNNING,
start_date=DEFAULT_START_DATE,
)
session.commit()

response = client.patch(
f"/execution/task-instances/{ti.id}/state",
json={
"state": "deferred",
"trigger_kwargs": {"moment": "2026-01-16T00:00:00Z"},
"trigger_timeout": "P1D",
"classpath": "my-classpath",
"next_method": "execute_callback",
},
)

assert response.status_code == 204

audit_log_entry = session.scalar(
select(Log).where(
Log.dag_id == ti.dag_id,
Log.task_id == ti.task_id,
Log.run_id == ti.run_id,
Log.event == TaskInstanceState.DEFERRED.value,
)
)

assert audit_log_entry is not None, "Audit log entry should be inserted for DEFERRED state"
assert audit_log_entry.dag_id == ti.dag_id
assert audit_log_entry.task_id == ti.task_id
assert audit_log_entry.run_id == ti.run_id
assert audit_log_entry.map_index == ti.map_index
assert audit_log_entry.try_number == ti.try_number


class TestTISkipDownstream:
def setup_method(self):
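The tests above assert each audit row in isolation. A hedged follow-on, not part of this diff, could check ordering across both endpoints, under the assumption that Log rows carry an autoincrementing id:

from sqlalchemy import select
from airflow.models.log import Log

def state_trail(session, ti):
    # Illustrative helper: events recorded for one task instance, oldest first.
    rows = session.scalars(
        select(Log)
        .where(Log.dag_id == ti.dag_id, Log.task_id == ti.task_id, Log.run_id == ti.run_id)
        .order_by(Log.id)
    ).all()
    return [row.event for row in rows]

Used after ti_run followed by a terminal ti_update_state, state_trail(session, ti) should read ["running", "success"].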
airflow-core/tests/unit/jobs/test_scheduler_job.py (39 additions, 0 deletions)
@@ -104,6 +104,7 @@
clear_db_deadline,
clear_db_import_errors,
clear_db_jobs,
clear_db_logs,
clear_db_pools,
clear_db_runs,
clear_db_teams,
@@ -201,6 +202,7 @@ def _clean_db():
clear_db_pools()
clear_db_import_errors()
clear_db_jobs()
clear_db_logs()
clear_db_assets()
clear_db_deadline()
clear_db_callbacks()
@@ -1000,6 +1002,43 @@ def test_find_executable_task_instances_backfill(self, dag_maker):
assert {x.key for x in queued_tis} == {ti_non_backfill.key, ti_backfill.key}
session.rollback()

def test_executable_task_instances_to_queued_creates_audit_log(self, dag_maker, session):
"""Test that queuing tasks creates audit log entries."""
dag_id = "SchedulerJobTest.test_executable_task_instances_to_queued_creates_audit_log"
task_id = "dummy"

with dag_maker(dag_id=dag_id, session=session):
EmptyOperator(task_id=task_id)

scheduler_job = Job()
self.job_runner = SchedulerJobRunner(job=scheduler_job)

dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
ti = dr.get_task_instance(task_id, session)
ti.state = State.SCHEDULED
session.merge(ti)
session.flush()

# Queue the task
queued_tis = self.job_runner._executable_task_instances_to_queued(max_tis=8, session=session)
session.flush()

assert len(queued_tis) == 1

# Verify audit log was created
audit_log_entry = session.scalar(
select(Log).where(
Log.dag_id == dag_id,
Log.task_id == task_id,
Log.event == TaskInstanceState.QUEUED.value,
)
)

assert audit_log_entry is not None, "Audit log entry should be created for QUEUED state"
assert audit_log_entry.dag_id == ti.dag_id
assert audit_log_entry.task_id == ti.task_id
assert audit_log_entry.run_id == ti.run_id

def test_find_executable_task_instances_pool(self, dag_maker):
dag_id = "SchedulerJobTest.test_find_executable_task_instances_pool"
task_id_1 = "dummy"
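The new clear_db_logs() calls in both test modules keep audit assertions isolated, since the log table is otherwise shared across tests. A plausible implementation, mirroring the other clear_db_* helpers in tests_common.test_utils.db and shown here as an assumption rather than the actual helper:

from airflow.models.log import Log
from airflow.utils.session import create_session

def clear_db_logs():
    # Delete every audit-log row so subsequent assertions start from an empty table.
    with create_session() as session:
        session.query(Log).delete()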