diff --git a/airflow-core/src/airflow/cli/commands/db_command.py b/airflow-core/src/airflow/cli/commands/db_command.py index 0681639ad2130..ad9bfc393adf9 100644 --- a/airflow-core/src/airflow/cli/commands/db_command.py +++ b/airflow-core/src/airflow/cli/commands/db_command.py @@ -50,29 +50,39 @@ def resetdb(args): db.resetdb(skip_init=args.skip_init) -def _get_version_revision( - version: str, recursion_limit: int = 10, revision_heads_map: dict[str, str] | None = None -) -> str | None: +def _get_version_revision(version: str, revision_heads_map: dict[str, str] | None = None) -> str | None: """ - Recursively search for the revision of the given version in revision_heads_map. + Search for the revision of the given version in revision_heads_map. This searches given revision_heads_map for the revision of the given version, recursively searching for the previous version if the given version is not found. + + ``revision_heads_map`` must already be sorted in the dict in ascending order for this function to work. No + checks are made that this is true """ if revision_heads_map is None: revision_heads_map = _REVISION_HEADS_MAP + # Exact match found, we can just return it if version in revision_heads_map: return revision_heads_map[version] + try: - major, minor, patch = map(int, version.split(".")) + wanted = tuple(map(int, version.split("."))) except ValueError: return None - new_version = f"{major}.{minor}.{patch - 1}" - recursion_limit -= 1 - if recursion_limit <= 0: - # Prevent infinite recursion as I can't imagine 10 successive versions without migration + + # Else, we walk backwards in the revision map until we find a version that is < the target + for revision, head in reversed(revision_heads_map.items()): + try: + current = tuple(map(int, revision.split("."))) + except ValueError: + log.debug("Unable to parse HEAD revision", exc_info=True) + return None + + if current < wanted: + return head + else: return None - return _get_version_revision(new_version, recursion_limit) def run_db_migrate_command(args, command, revision_heads_map: dict[str, str]): diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py index cd0fea8974d5d..1c686050d665a 100644 --- a/airflow-core/src/airflow/utils/db.py +++ b/airflow-core/src/airflow/utils/db.py @@ -1189,21 +1189,21 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session: config = _get_alembic_config() # Check if downgrade is less than 3.0.0 and requires that `ab_user` fab table is present if _revision_greater(config, _REVISION_HEADS_MAP["2.10.3"], to_revision): - unitest_mode = conf.getboolean("core", "unit_test_mode") - if unitest_mode: - try: - from airflow.providers.fab.auth_manager.models.db import FABDBManager - - dbm = FABDBManager(session) - dbm.initdb() - except ImportError: - log.warning("Import error occurred while importing FABDBManager. Skipping the check.") - return - if not inspect(settings.engine).has_table("ab_user") and not unitest_mode: - raise AirflowException( - "Downgrade to revision less than 3.0.0 requires that `ab_user` table is present. " - "Please add FabDBManager to [core] external_db_managers and run fab migrations before proceeding" + try: + from airflow.providers.fab.auth_manager.models.db import FABDBManager + except ImportError: + # Raise the error with a new message + raise RuntimeError( + "Import error occurred while importing FABDBManager. We need that to exist before we can " + "downgrade to <3.0.0" ) + dbm = FABDBManager(session) + if hasattr(dbm, "reset_to_2_x"): + dbm.reset_to_2_x() + else: + # Older version before we added that function, it only has a single migration so we can just + # created + dbm.create_db_from_orm() with create_global_lock(session=session, lock=DBLocks.MIGRATIONS): if show_sql_only: log.warning("Generating sql scripts for manual migration.") diff --git a/airflow-core/src/airflow/utils/db_manager.py b/airflow-core/src/airflow/utils/db_manager.py index c746cdaca08d9..37d09b4a60a62 100644 --- a/airflow-core/src/airflow/utils/db_manager.py +++ b/airflow-core/src/airflow/utils/db_manager.py @@ -216,12 +216,6 @@ def upgradedb(self, session): m = manager(session) m.upgradedb() - def downgrade(self, session): - """Downgrade the external database managers.""" - for manager in self._managers: - m = manager(session) - m.downgrade() - def drop_tables(self, session, connection): """Drop the external database managers.""" for manager in self._managers: diff --git a/airflow-core/tests/unit/cli/commands/test_db_command.py b/airflow-core/tests/unit/cli/commands/test_db_command.py index e910cc1e3148b..eeac6aa16885d 100644 --- a/airflow-core/tests/unit/cli/commands/test_db_command.py +++ b/airflow-core/tests/unit/cli/commands/test_db_command.py @@ -658,3 +658,19 @@ def test_confirm_in_drop_archived_records_command(self, mock_drop_archived_recor ) db_command.drop_archived(args) mock_drop_archived_records.assert_called_once_with(table_names=None, needs_confirm=expected) + + +def test_get_version_revision(): + heads: dict[str, str] = { + "2.10.0": "22ed7efa9da2", + "2.10.3": "5f2621c13b39", + "3.0.0": "29ce7909c52b", + "3.0.3": "fe199e1abd77", + "3.1.0": "808787349f22", + } + + assert db_command._get_version_revision("3.1.0", heads) == "808787349f22" + assert db_command._get_version_revision("3.1.1", heads) == "808787349f22" + assert db_command._get_version_revision("2.11.1", heads) == "5f2621c13b39" + assert db_command._get_version_revision("2.10.1", heads) == "22ed7efa9da2" + assert db_command._get_version_revision("2.0.0", heads) is None diff --git a/airflow-core/tests/unit/utils/test_db.py b/airflow-core/tests/unit/utils/test_db.py index 66cddbbd5ed43..4a0d80f907fb9 100644 --- a/airflow-core/tests/unit/utils/test_db.py +++ b/airflow-core/tests/unit/utils/test_db.py @@ -33,10 +33,8 @@ from sqlalchemy import Column, Integer, MetaData, Table, select from airflow import settings -from airflow.exceptions import AirflowException from airflow.models import Base as airflow_base from airflow.utils.db import ( - _REVISION_HEADS_MAP, AutocommitEngineForMySQL, LazySelectSequence, _get_alembic_config, @@ -380,18 +378,6 @@ def scalar(self, stmt): assert bool(lss) is False - @conf_vars({("core", "unit_test_mode"): "False"}) - def test_downgrade_raises_if_lower_than_v3_0_0_and_no_ab_user(self, mocker): - mock_inspect = mocker.patch("airflow.utils.db.inspect") - mock_inspect.return_value.has_table.return_value = False - msg = ( - "Downgrade to revision less than 3.0.0 requires that `ab_user` table is present. " - "Please add FabDBManager to [core] external_db_managers and run fab migrations before " - "proceeding" - ) - with pytest.raises(AirflowException, match=re.escape(msg)): - downgrade(to_revision=_REVISION_HEADS_MAP["2.7.0"]) - class TestAutocommitEngineForMySQL: """Test the AutocommitEngineForMySQL context manager.""" diff --git a/devel-common/src/tests_common/test_utils/db.py b/devel-common/src/tests_common/test_utils/db.py index 5248e361d1888..d05fb99e5f39f 100644 --- a/devel-common/src/tests_common/test_utils/db.py +++ b/devel-common/src/tests_common/test_utils/db.py @@ -18,6 +18,7 @@ from __future__ import annotations import json +import os from tempfile import gettempdir from typing import TYPE_CHECKING @@ -101,9 +102,26 @@ def initial_db_init(): from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS db.resetdb() + if AIRFLOW_V_3_0_PLUS: - db.downgrade(to_revision="5f2621c13b39") - db.upgradedb(to_revision="head") + try: + from airflow.providers.fab.auth_manager.models.db import FABDBManager + except ModuleNotFoundError: + # Reasons it might fail: we're in a provider bundle without FAB, or we're on a version of Python + # where FAB isn't yet supported + pass + else: + if os.getenv("TEST_GROUP") != "providers": + # If we loaded the provider, and we're running core (or running via breeze where TEST_GROUP + # isn't specified) run the downgrade+upgrade to ensure migrations are in sync with Model + # classes + db.downgrade(to_revision="5f2621c13b39") + db.upgradedb(to_revision="head") + else: + # Just create the tables so they are there + with create_session() as session: + FABDBManager(session).create_db_from_orm() + session.commit() else: from flask import Flask diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py index 1a8d66af3707d..d46df36af5d0b 100644 --- a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py +++ b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py @@ -59,6 +59,11 @@ def create_db_from_orm(self): super().create_db_from_orm() _get_flask_db(settings.SQL_ALCHEMY_CONN).create_all() + def reset_to_2_x(self): + self.create_db_from_orm() + # And ensure it's at the oldest version + self.downgrade(_REVISION_HEADS_MAP["1.4.0"]) + def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False): """Upgrade the database.""" if from_revision and not show_sql_only: diff --git a/providers/fab/tests/unit/fab/db_manager/test_fab_db_manager.py b/providers/fab/tests/unit/fab/db_manager/test_fab_db_manager.py index 04fbcfa69edbe..3dc67bf0c9f8a 100644 --- a/providers/fab/tests/unit/fab/db_manager/test_fab_db_manager.py +++ b/providers/fab/tests/unit/fab/db_manager/test_fab_db_manager.py @@ -22,7 +22,7 @@ from sqlalchemy import Table from airflow.exceptions import AirflowException -from airflow.utils.db import downgrade, initdb +from airflow.utils.db import initdb from airflow.utils.db_manager import RunDBManager from tests_common.test_utils.config import conf_vars @@ -57,27 +57,12 @@ def test_defining_table_same_name_as_airflow_table_name_raises(self): run_db_manager.validate() metadata._remove_table("dag_run", None) - @mock.patch.object(RunDBManager, "downgrade") @mock.patch.object(RunDBManager, "upgradedb") @mock.patch.object(RunDBManager, "initdb") - def test_init_db_calls_rundbmanager(self, mock_initdb, mock_upgrade_db, mock_downgrade_db, session): + def test_init_db_calls_rundbmanager(self, mock_initdb, mock_upgrade_db, session): initdb(session=session) mock_initdb.assert_called() mock_initdb.assert_called_once_with(session) - mock_downgrade_db.assert_not_called() - - @mock.patch.object(RunDBManager, "downgrade") - @mock.patch.object(RunDBManager, "upgradedb") - @mock.patch.object(RunDBManager, "initdb") - @mock.patch("alembic.command") - def test_downgrade_dont_call_rundbmanager( - self, mock_alembic_command, mock_initdb, mock_upgrade_db, mock_downgrade_db, session - ): - downgrade(to_revision="base") - mock_alembic_command.downgrade.assert_called_once_with(mock.ANY, revision="base", sql=False) - mock_upgrade_db.assert_not_called() - mock_initdb.assert_not_called() - mock_downgrade_db.assert_not_called() @conf_vars( {("database", "external_db_managers"): "airflow.providers.fab.auth_manager.models.db.FABDBManager"} @@ -93,9 +78,7 @@ def test_rundbmanager_calls_dbmanager_methods(self, mock_fabdb_manager, session) # upgradedb ext_db.upgradedb(session=session) fabdb_manager.upgradedb.assert_called_once() - # downgrade - ext_db.downgrade(session=session) - mock_fabdb_manager.return_value.downgrade.assert_called_once() + # drop_tables connection = mock.MagicMock() ext_db.drop_tables(session, connection) mock_fabdb_manager.return_value.drop_tables.assert_called_once_with(connection)