Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions airflow-core/src/airflow/cli/commands/db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,29 +50,39 @@ def resetdb(args):
db.resetdb(skip_init=args.skip_init)


def _get_version_revision(
version: str, recursion_limit: int = 10, revision_heads_map: dict[str, str] | None = None
) -> str | None:
def _get_version_revision(version: str, revision_heads_map: dict[str, str] | None = None) -> str | None:
"""
Recursively search for the revision of the given version in revision_heads_map.
Search for the revision of the given version in revision_heads_map.

This searches given revision_heads_map for the revision of the given version, recursively
searching for the previous version if the given version is not found.

``revision_heads_map`` must already be sorted in the dict in ascending order for this function to work. No
checks are made that this is true
"""
if revision_heads_map is None:
revision_heads_map = _REVISION_HEADS_MAP
# Exact match found, we can just return it
if version in revision_heads_map:
return revision_heads_map[version]

try:
major, minor, patch = map(int, version.split("."))
wanted = tuple(map(int, version.split(".")))
except ValueError:
return None
new_version = f"{major}.{minor}.{patch - 1}"
recursion_limit -= 1
if recursion_limit <= 0:
# Prevent infinite recursion as I can't imagine 10 successive versions without migration

# Else, we walk backwards in the revision map until we find a version that is < the target
for revision, head in reversed(revision_heads_map.items()):
try:
current = tuple(map(int, revision.split(".")))
except ValueError:
log.debug("Unable to parse HEAD revision", exc_info=True)
return None

if current < wanted:
return head
else:
return None
return _get_version_revision(new_version, recursion_limit)


def run_db_migrate_command(args, command, revision_heads_map: dict[str, str]):
Expand Down
28 changes: 14 additions & 14 deletions airflow-core/src/airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,21 +1189,21 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
config = _get_alembic_config()
# Check if downgrade is less than 3.0.0 and requires that `ab_user` fab table is present
if _revision_greater(config, _REVISION_HEADS_MAP["2.10.3"], to_revision):
unitest_mode = conf.getboolean("core", "unit_test_mode")
if unitest_mode:
try:
from airflow.providers.fab.auth_manager.models.db import FABDBManager

dbm = FABDBManager(session)
dbm.initdb()
except ImportError:
log.warning("Import error occurred while importing FABDBManager. Skipping the check.")
return
if not inspect(settings.engine).has_table("ab_user") and not unitest_mode:
raise AirflowException(
"Downgrade to revision less than 3.0.0 requires that `ab_user` table is present. "
"Please add FabDBManager to [core] external_db_managers and run fab migrations before proceeding"
try:
from airflow.providers.fab.auth_manager.models.db import FABDBManager
except ImportError:
# Raise the error with a new message
raise RuntimeError(
"Import error occurred while importing FABDBManager. We need that to exist before we can "
"downgrade to <3.0.0"
)
dbm = FABDBManager(session)
if hasattr(dbm, "reset_to_2_x"):
dbm.reset_to_2_x()
else:
# Older version before we added that function, it only has a single migration so we can just
# created
dbm.create_db_from_orm()
with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
if show_sql_only:
log.warning("Generating sql scripts for manual migration.")
Expand Down
6 changes: 0 additions & 6 deletions airflow-core/src/airflow/utils/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,6 @@ def upgradedb(self, session):
m = manager(session)
m.upgradedb()

def downgrade(self, session):
"""Downgrade the external database managers."""
for manager in self._managers:
m = manager(session)
m.downgrade()

def drop_tables(self, session, connection):
"""Drop the external database managers."""
for manager in self._managers:
Expand Down
16 changes: 16 additions & 0 deletions airflow-core/tests/unit/cli/commands/test_db_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,3 +658,19 @@ def test_confirm_in_drop_archived_records_command(self, mock_drop_archived_recor
)
db_command.drop_archived(args)
mock_drop_archived_records.assert_called_once_with(table_names=None, needs_confirm=expected)


def test_get_version_revision():
heads: dict[str, str] = {
"2.10.0": "22ed7efa9da2",
"2.10.3": "5f2621c13b39",
"3.0.0": "29ce7909c52b",
"3.0.3": "fe199e1abd77",
"3.1.0": "808787349f22",
}

assert db_command._get_version_revision("3.1.0", heads) == "808787349f22"
assert db_command._get_version_revision("3.1.1", heads) == "808787349f22"
assert db_command._get_version_revision("2.11.1", heads) == "5f2621c13b39"
assert db_command._get_version_revision("2.10.1", heads) == "22ed7efa9da2"
assert db_command._get_version_revision("2.0.0", heads) is None
14 changes: 0 additions & 14 deletions airflow-core/tests/unit/utils/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@
from sqlalchemy import Column, Integer, MetaData, Table, select

from airflow import settings
from airflow.exceptions import AirflowException
from airflow.models import Base as airflow_base
from airflow.utils.db import (
_REVISION_HEADS_MAP,
AutocommitEngineForMySQL,
LazySelectSequence,
_get_alembic_config,
Expand Down Expand Up @@ -380,18 +378,6 @@ def scalar(self, stmt):

assert bool(lss) is False

@conf_vars({("core", "unit_test_mode"): "False"})
def test_downgrade_raises_if_lower_than_v3_0_0_and_no_ab_user(self, mocker):
mock_inspect = mocker.patch("airflow.utils.db.inspect")
mock_inspect.return_value.has_table.return_value = False
msg = (
"Downgrade to revision less than 3.0.0 requires that `ab_user` table is present. "
"Please add FabDBManager to [core] external_db_managers and run fab migrations before "
"proceeding"
)
with pytest.raises(AirflowException, match=re.escape(msg)):
downgrade(to_revision=_REVISION_HEADS_MAP["2.7.0"])


class TestAutocommitEngineForMySQL:
"""Test the AutocommitEngineForMySQL context manager."""
Expand Down
22 changes: 20 additions & 2 deletions devel-common/src/tests_common/test_utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

import json
import os
from tempfile import gettempdir
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -101,9 +102,26 @@ def initial_db_init():
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS

db.resetdb()

if AIRFLOW_V_3_0_PLUS:
db.downgrade(to_revision="5f2621c13b39")
db.upgradedb(to_revision="head")
try:
from airflow.providers.fab.auth_manager.models.db import FABDBManager
except ModuleNotFoundError:
# Reasons it might fail: we're in a provider bundle without FAB, or we're on a version of Python
# where FAB isn't yet supported
pass
else:
if os.getenv("TEST_GROUP") != "providers":
# If we loaded the provider, and we're running core (or running via breeze where TEST_GROUP
# isn't specified) run the downgrade+upgrade to ensure migrations are in sync with Model
# classes
db.downgrade(to_revision="5f2621c13b39")
db.upgradedb(to_revision="head")
else:
# Just create the tables so they are there
with create_session() as session:
FABDBManager(session).create_db_from_orm()
session.commit()
else:
from flask import Flask

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def create_db_from_orm(self):
super().create_db_from_orm()
_get_flask_db(settings.SQL_ALCHEMY_CONN).create_all()

def reset_to_2_x(self):
self.create_db_from_orm()
# And ensure it's at the oldest version
self.downgrade(_REVISION_HEADS_MAP["1.4.0"])

def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False):
"""Upgrade the database."""
if from_revision and not show_sql_only:
Expand Down
23 changes: 3 additions & 20 deletions providers/fab/tests/unit/fab/db_manager/test_fab_db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from sqlalchemy import Table

from airflow.exceptions import AirflowException
from airflow.utils.db import downgrade, initdb
from airflow.utils.db import initdb
from airflow.utils.db_manager import RunDBManager

from tests_common.test_utils.config import conf_vars
Expand Down Expand Up @@ -57,27 +57,12 @@ def test_defining_table_same_name_as_airflow_table_name_raises(self):
run_db_manager.validate()
metadata._remove_table("dag_run", None)

@mock.patch.object(RunDBManager, "downgrade")
@mock.patch.object(RunDBManager, "upgradedb")
@mock.patch.object(RunDBManager, "initdb")
def test_init_db_calls_rundbmanager(self, mock_initdb, mock_upgrade_db, mock_downgrade_db, session):
def test_init_db_calls_rundbmanager(self, mock_initdb, mock_upgrade_db, session):
initdb(session=session)
mock_initdb.assert_called()
mock_initdb.assert_called_once_with(session)
mock_downgrade_db.assert_not_called()

@mock.patch.object(RunDBManager, "downgrade")
@mock.patch.object(RunDBManager, "upgradedb")
@mock.patch.object(RunDBManager, "initdb")
@mock.patch("alembic.command")
def test_downgrade_dont_call_rundbmanager(
self, mock_alembic_command, mock_initdb, mock_upgrade_db, mock_downgrade_db, session
):
downgrade(to_revision="base")
mock_alembic_command.downgrade.assert_called_once_with(mock.ANY, revision="base", sql=False)
mock_upgrade_db.assert_not_called()
mock_initdb.assert_not_called()
mock_downgrade_db.assert_not_called()

@conf_vars(
{("database", "external_db_managers"): "airflow.providers.fab.auth_manager.models.db.FABDBManager"}
Expand All @@ -93,9 +78,7 @@ def test_rundbmanager_calls_dbmanager_methods(self, mock_fabdb_manager, session)
# upgradedb
ext_db.upgradedb(session=session)
fabdb_manager.upgradedb.assert_called_once()
# downgrade
ext_db.downgrade(session=session)
mock_fabdb_manager.return_value.downgrade.assert_called_once()
# drop_tables
connection = mock.MagicMock()
ext_db.drop_tables(session, connection)
mock_fabdb_manager.return_value.drop_tables.assert_called_once_with(connection)
Loading