Merged
55 changes: 46 additions & 9 deletions airflow/utils/db_cleanup.py
@@ -68,6 +68,7 @@ class _TableConfig:
in the table. to ignore certain records even if they are the latest in the table, you can
supply additional filters here (e.g. externally triggered dag runs)
:param keep_last_group_by: if keeping the last record, can keep the last record for each group
:param dependent_tables: list of tables that have a foreign-key relationship with this table
"""

table_name: str
@@ -76,6 +77,10 @@ class _TableConfig:
keep_last: bool = False
keep_last_filters: Any | None = None
keep_last_group_by: Any | None = None
# We explicitly list these tables instead of detecting foreign keys automatically,
# because the relationships are unlikely to change and the number of tables is small.
# Relying on automation here would increase complexity and reduce maintainability.
dependent_tables: list[str] | None = None

def __post_init__(self):
self.recency_column = column(self.recency_column_name)
@@ -107,20 +112,29 @@ def readable_config(self):
keep_last=True,
keep_last_filters=[column("external_trigger") == false()],
keep_last_group_by=["dag_id"],
dependent_tables=["task_instance"],
),
_TableConfig(table_name="dataset_event", recency_column_name="timestamp"),
_TableConfig(table_name="import_error", recency_column_name="timestamp"),
_TableConfig(table_name="log", recency_column_name="dttm"),
_TableConfig(table_name="sla_miss", recency_column_name="timestamp"),
_TableConfig(table_name="task_fail", recency_column_name="start_date"),
_TableConfig(table_name="task_instance", recency_column_name="start_date"),
_TableConfig(
table_name="task_instance",
recency_column_name="start_date",
dependent_tables=["task_instance_history", "xcom"],
),
_TableConfig(table_name="task_instance_history", recency_column_name="start_date"),
_TableConfig(table_name="task_reschedule", recency_column_name="start_date"),
_TableConfig(table_name="xcom", recency_column_name="timestamp"),
_TableConfig(table_name="callback_request", recency_column_name="created_at"),
_TableConfig(table_name="celery_taskmeta", recency_column_name="date_done"),
_TableConfig(table_name="celery_tasksetmeta", recency_column_name="date_done"),
_TableConfig(table_name="trigger", recency_column_name="created_date"),
_TableConfig(
table_name="trigger",
recency_column_name="created_date",
dependent_tables=["task_instance"],
),
]

if conf.get("webserver", "session_backend") == "database":
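Since the `dependent_tables` entries above are maintained by hand rather than derived from the schema, a quick way to double-check them against the live metadata database is SQLAlchemy's inspector. A rough sketch, not part of this PR; the connection URL and helper name are placeholders:

from sqlalchemy import create_engine, inspect

# Placeholder URL; point this at the real Airflow metadata database.
engine = create_engine("postgresql+psycopg2://airflow:airflow@localhost/airflow")
inspector = inspect(engine)

def tables_referencing(parent: str) -> set[str]:
    """Tables holding a foreign key that points at `parent`."""
    return {
        table
        for table in inspector.get_table_names()
        for fk in inspector.get_foreign_keys(table)
        if fk.get("referred_table") == parent
    }

# e.g. tables_referencing("dag_run") should cover the dependent_tables
# listed for dag_run above.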
@@ -363,17 +377,37 @@ def _suppress_with_logging(table, session):
session.rollback()


def _effective_table_names(*, table_names: list[str] | None):
def _effective_table_names(*, table_names: list[str] | None) -> tuple[list[str], dict[str, _TableConfig]]:
desired_table_names = set(table_names or config_dict)
effective_config_dict = {k: v for k, v in config_dict.items() if k in desired_table_names}
effective_table_names = set(effective_config_dict)
if desired_table_names != effective_table_names:
outliers = desired_table_names - effective_table_names

outliers = desired_table_names - set(config_dict.keys())
if outliers:
logger.warning(
"The following table(s) are not valid choices and will be skipped: %s", sorted(outliers)
"The following table(s) are not valid choices and will be skipped: %s",
sorted(outliers),
)
if not effective_table_names:
desired_table_names = desired_table_names - outliers

visited: set[str] = set()
effective_table_names: list[str] = []

def collect_deps(table: str):
if table in visited:
return
visited.add(table)
config = config_dict[table]
for dep in config.dependent_tables or []:
collect_deps(dep)
effective_table_names.append(table)

for table_name in desired_table_names:
collect_deps(table_name)

effective_config_dict = {n: config_dict[n] for n in effective_table_names}

if not effective_config_dict:
raise SystemExit("No tables selected for db cleanup. Please choose valid table names.")

return effective_table_names, effective_config_dict
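The dependency walk above is a small post-order traversal: each requested table pulls in its `dependent_tables` first, and `visited` keeps shared dependents (e.g. `task_instance`, reachable from both `dag_run` and `trigger`) from being added twice. A self-contained sketch of the same logic, with the map copied from the `_TableConfig` entries above:

# Minimal stand-alone version of the resolution logic (illustrative only).
DEPENDENTS = {
    "dag_run": ["task_instance"],
    "task_instance": ["task_instance_history", "xcom"],
    "trigger": ["task_instance"],
}

def resolve(requested: list[str]) -> list[str]:
    visited: set[str] = set()
    ordered: list[str] = []

    def collect(table: str) -> None:
        if table in visited:
            return
        visited.add(table)
        for dep in DEPENDENTS.get(table, []):
            collect(dep)
        ordered.append(table)  # dependents land before their parent

    for table in requested:
        collect(table)
    return ordered

print(resolve(["dag_run"]))
# ['task_instance_history', 'xcom', 'task_instance', 'dag_run']

So selecting only dag_run for cleanup now also brings task_instance, task_instance_history, and xcom into effective_config_dict, which is what the integration test below checks for the populated subset.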


@@ -421,6 +455,8 @@ def run_cleanup(
:param session: Session representing connection to the metadata database.
"""
clean_before_timestamp = timezone.coerce_datetime(clean_before_timestamp)

# Resolve all tables to clean: the requested tables plus their FK-dependent tables
effective_table_names, effective_config_dict = _effective_table_names(table_names=table_names)
if dry_run:
print("Performing dry run for db cleanup.")
@@ -432,6 +468,7 @@
if not dry_run and confirm:
_confirm_delete(date=clean_before_timestamp, tables=sorted(effective_table_names))
existing_tables = reflect_tables(tables=None, session=session).tables

for table_name, table_config in effective_config_dict.items():
if table_name in existing_tables:
with _suppress_with_logging(table_name, session):
47 changes: 46 additions & 1 deletion tests/utils/test_db_cleanup.py
@@ -26,7 +26,7 @@

import pendulum
import pytest
from sqlalchemy import text
from sqlalchemy import inspect, text
from sqlalchemy.exc import OperationalError
from sqlalchemy.ext.declarative import DeclarativeMeta

@@ -269,6 +269,51 @@ def test__cleanup_table(self, table_name, date_add_kwargs, expected_to_delete, e
else:
raise Exception("unexpected")

@pytest.mark.parametrize(
"table_name, expected_archived",
[
(
"dag_run",
{"dag_run", "task_instance"}, # Only these are populated
),
],
)
def test_run_cleanup_archival_integration(self, table_name, expected_archived):
"""
Integration test that verifies:
1. FK-dependent tables are resolved recursively via _effective_table_names().
2. run_cleanup() archives only tables with data.
3. Archive tables are not created for empty dependent tables.
"""
base_date = pendulum.datetime(2022, 1, 1, tz="UTC")
num_tis = 5

# Create test data for DAG Run and TIs
if table_name in {"dag_run", "task_instance"}:
create_tis(base_date=base_date, num_tis=num_tis, external_trigger=False)

clean_before_date = base_date.add(days=10)

with create_session() as session:
run_cleanup(
clean_before_timestamp=clean_before_date,
table_names=[table_name],
dry_run=False,
confirm=False,
session=session,
)

# Inspect archive tables created
inspector = inspect(session.bind)
archive_tables = {
name for name in inspector.get_table_names() if name.startswith(ARCHIVE_TABLE_PREFIX)
}
actual_archived = {t.split("__", 1)[-1].split("__")[0] for t in archive_tables}

assert (
expected_archived <= actual_archived
), f"Expected archive tables not found: {expected_archived - actual_archived}"
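The set comprehension above recovers the source table name from each archive table. Assuming the usual `<prefix>__<table>__<timestamp>` naming implied by `ARCHIVE_TABLE_PREFIX`, the parsing works like this (the concrete name below is made up for illustration):

# Hypothetical archive table name, just to show the split logic used in the test.
name = "_airflow_deleted__dag_run__20220101000000"
print(name.split("__", 1)[-1].split("__")[0])  # -> "dag_run"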

@pytest.mark.parametrize(
"skip_archive, expected_archives",
[pytest.param(True, 0, id="skip_archive"), pytest.param(False, 1, id="do_archive")],