diff --git a/airflow-core/pyproject.toml b/airflow-core/pyproject.toml
index 3f3cfd996e49e..911892d563b32 100644
--- a/airflow-core/pyproject.toml
+++ b/airflow-core/pyproject.toml
@@ -130,8 +130,8 @@ dependencies = [
     "rich-argparse>=1.0.0",
     "rich>=13.6.0",
     "setproctitle>=1.3.3",
-    # The issue tracking deprecations for sqlalchemy 2 is https://github.com/apache/airflow/issues/28723
-    "sqlalchemy[asyncio]>=1.4.49",
+    # SQLAlchemy >=2.0.36 fixes Python 3.13 TypingOnly import AssertionError caused by new typing attributes (__static_attributes__, __firstlineno__)
+    "sqlalchemy[asyncio]>=2.0.36",
     "sqlalchemy-jsonfield>=1.0",
     "sqlalchemy-utils>=0.41.2",
     "svcs>=25.1.0",
diff --git a/airflow-core/src/airflow/utils/db.py b/airflow-core/src/airflow/utils/db.py
index 8b9d095deb394..e2d7ff638784e 100644
--- a/airflow-core/src/airflow/utils/db.py
+++ b/airflow-core/src/airflow/utils/db.py
@@ -682,18 +682,39 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         settings.configure_orm()
 
 
-def _create_db_from_orm(session):
-    """Create database tables from ORM models and stamp alembic version."""
+def _create_db_from_orm_mysql(session) -> None:
+    """Create database tables from ORM models for MySQL."""
+    from alembic import command
+
     from airflow.models.base import Base
 
-    log.info("Creating Airflow database tables from the ORM")
+    # MySQL: Commit session to release metadata locks before DDL
+    log.info("MySQL: Committing session to release metadata locks")
+    session.commit()
 
-    # Debug setup if requested
-    _setup_debug_logging_if_needed()
+    engine = session.get_bind().engine
+    log.info("Creating tables (MySQL)")
+    Base.metadata.create_all(engine)
+
+    log.info("Getting alembic config")
+    config = _get_alembic_config()
+
+    with AutocommitEngineForMySQL():
+        log.info("Stamping migration head")
+        command.stamp(config, "head")
+
+    log.info("Airflow database tables created")
 
-    log.info("Creating context")
+
+def _create_db_from_orm_default(session) -> None:
+    """Create database tables from ORM models for PostgreSQL/SQLite."""
+    from alembic import command
+
+    from airflow.models.base import Base
+
+    # PostgreSQL / SQLite: Use transactional global lock
+    log.info("Creating global lock context")
     with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
-        log.info("Binding engine")
         engine = session.get_bind().engine
         log.info("Pool status: %s", engine.pool.status())
@@ -703,14 +724,21 @@ def _create_db_from_orm(session):
         log.info("Getting alembic config")
         config = _get_alembic_config()
 
-        # Use AUTOCOMMIT for DDL to avoid metadata lock issues
-        with AutocommitEngineForMySQL():  # TODO: enable for sqlite too
-            from alembic import command
+        log.info("Stamping migration head")
+        command.stamp(config, "head")
 
-            log.info("Stamping migration head")
-            command.stamp(config, "head")
+    log.info("Airflow database tables created")
 
-    log.info("Airflow database tables created")
+
+def _create_db_from_orm(session):
+    """Create database tables from ORM models and stamp alembic version."""
+    log.info("Creating Airflow database tables from the ORM")
+    _setup_debug_logging_if_needed()
+
+    if get_dialect_name(session) == "mysql":
+        _create_db_from_orm_mysql(session)
+    else:
+        _create_db_from_orm_default(session)
 
 
 def _setup_debug_logging_if_needed():
@@ -718,24 +746,79 @@ def _setup_debug_logging_if_needed():
     if not os.environ.get("SQLALCHEMY_ENGINE_DEBUG"):
         return
 
+    import atexit
     import faulthandler
-    import threading
+    from contextlib import suppress
 
     # Enable SQLA debug logging
     logging.getLogger("sqlalchemy.engine").setLevel(logging.DEBUG)
 
-    # Enable Fault Handler
+    # Enable faulthandler for debugging long-running threads and deadlocks,
+    # but disable it before interpreter shutdown to avoid segfaults during
+    # cleanup (especially with SQLAlchemy 2.0 + pytest teardown)
     faulthandler.enable(file=sys.stderr, all_threads=True)
 
-    # Print Active Threads and Stack Traces Periodically
-    def dump_stacks():
-        while True:
-            for thread_id, frame in sys._current_frames().items():
-                log.info("\nThread %s stack:", thread_id)
-                traceback.print_stack(frame)
-            time.sleep(300)
+    # Cancel any pending traceback dumps and disable faulthandler before exit
+    # to prevent it from interfering with C extension cleanup
+    def cleanup_faulthandler():
+        with suppress(Exception):
+            faulthandler.cancel_dump_traceback_later()
+        with suppress(Exception):
+            faulthandler.disable()
+
+    atexit.register(cleanup_faulthandler)
+
+    # Set up periodic traceback dumps for debugging hanging tests/threads
+    faulthandler.dump_traceback_later(timeout=300, repeat=True, file=sys.stderr)
+
 
-    threading.Thread(target=dump_stacks, daemon=True).start()
+@contextlib.contextmanager
+def _mysql_lock_session_for_migration(original_session: Session) -> Generator[Session, None, None]:
+    """
+    Create a MySQL-specific lock session for migration operations.
+
+    This context manager:
+    1. Commits the original session to release metadata locks
+    2. Creates a new session bound to the engine
+    3. Ensures the session is properly closed on exit
+
+    :param original_session: The original session to commit
+    :return: A new session suitable for use with create_global_lock
+    """
+    from sqlalchemy.orm import Session as SASession
+
+    log.info("MySQL: Committing session to release metadata locks")
+    original_session.commit()
+
+    lock_session = SASession(bind=settings.engine)
+    try:
+        yield lock_session
+    finally:
+        lock_session.close()
+
+
+@contextlib.contextmanager
+def _single_connection_pool() -> Generator[None, None, None]:
+    """
+    Temporarily reconfigure ORM to use exactly one connection.
+
+    This is needed for migrations because some database engines hang forever
+    trying to ALTER TABLEs when multiple connections exist in the pool.
+
+    Saves and restores the AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE environment variable.
+    """
+    import sqlalchemy.pool
+
+    previous_pool_size = os.environ.get("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE")
+    try:
+        os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = "1"
+        settings.reconfigure_orm(pool_class=sqlalchemy.pool.SingletonThreadPool)
+        yield
+    finally:
+        os.environ.pop("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE", None)
+        if previous_pool_size is not None:
+            os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = previous_pool_size
+        settings.reconfigure_orm()
 
 
 @provide_session
@@ -1064,6 +1147,38 @@ def _revisions_above_min_for_offline(config, revisions) -> None:
         )
 
 
+def _run_upgradedb(config, to_revision: str | None, session: Session) -> None:
+    """Run database upgrade with appropriate locking for the dialect."""
+    from alembic import command
+
+    is_mysql = settings.get_engine().dialect.name == "mysql"
+    dialect_label = " (MySQL)" if is_mysql else ""
+    log.info("Migrating the Airflow database%s", dialect_label)
+
+    # MySQL needs a separate lock session; others use the original session
+    session_cm: contextlib.AbstractContextManager[Session] = (
+        _mysql_lock_session_for_migration(session) if is_mysql else contextlib.nullcontext(session)
+    )
+
+    with (
+        session_cm as work_session,
+        create_global_lock(session=work_session, lock=DBLocks.MIGRATIONS),
+        _single_connection_pool(),
+    ):
+        command.upgrade(config, revision=to_revision or "heads")
+
+        current_revision = _get_current_revision(session=work_session)
+        with _configured_alembic_environment() as env:
+            source_heads = env.script.get_heads()
+
+        if current_revision == source_heads[0]:
+            external_db_manager = RunDBManager()
+            external_db_manager.upgradedb(work_session)
+
+        add_default_pool_if_not_exists(session=work_session)
+        synchronize_log_template(session=work_session)
+
+
 @provide_session
 def upgradedb(
     *,
@@ -1073,7 +1188,7 @@ def upgradedb(
     session: Session = NEW_SESSION,
 ):
     """
-    Upgrades the DB.
+    Upgrade the DB.
 
     :param to_revision: Optional Alembic revision ID to upgrade *to*.
         If omitted, upgrades to latest revision.
@@ -1086,9 +1201,9 @@ def upgradedb(
     if from_revision and not show_sql_only:
         raise AirflowException("`from_revision` only supported with `sql_only=True`.")
 
-    # alembic adds significant import time, so we import it lazily
     if not settings.SQL_ALCHEMY_CONN:
         raise RuntimeError("The settings.SQL_ALCHEMY_CONN not set. This is a critical assertion.")
+    from alembic import command
 
     import_all_models()
@@ -1110,7 +1225,7 @@ def upgradedb(
         _revisions_above_min_for_offline(config=config, revisions=[from_revision, to_revision])
         _offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}")
-        return  # only running sql; our job is done
+        return
 
     errors_seen = False
     for err in _check_migration_errors(session=session):
@@ -1123,38 +1238,56 @@ def upgradedb(
         exit(1)
 
     if not _get_current_revision(session=session) and not to_revision:
-        # Don't load default connections
         # New DB; initialize and exit
         initdb(session=session)
         return
 
-    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
-        import sqlalchemy.pool
-
-        log.info("Migrating the Airflow database")
-        val = os.environ.get("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE")
-        try:
-            # Reconfigure the ORM to use _EXACTLY_ one connection, otherwise some db engines hang forever
-            # trying to ALTER TABLEs
-            os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = "1"
-            settings.reconfigure_orm(pool_class=sqlalchemy.pool.SingletonThreadPool)
-            command.upgrade(config, revision=to_revision or "heads")
-            current_revision = _get_current_revision(session=session)
-            with _configured_alembic_environment() as env:
-                source_heads = env.script.get_heads()
-            if current_revision == source_heads[0]:
-                # Only run external DB upgrade migration if user upgraded to heads
+    _run_upgradedb(config, to_revision, session)
+
+
+def _resetdb_mysql(session: Session) -> None:
+    """Drop all Airflow tables for MySQL."""
+    from sqlalchemy.orm import Session as SASession
+
+    # MySQL: Release metadata locks and use AUTOCOMMIT for DDL
+    log.info("MySQL: Releasing metadata locks before DDL operations")
+    session.commit()
+    session.close()
+
+    # Use create_global_lock for migration safety (now handles MySQL with AUTOCOMMIT)
+    engine = settings.get_engine()
+    lock_session = SASession(bind=engine)
+    try:
+        with (
+            create_global_lock(session=lock_session, lock=DBLocks.MIGRATIONS),
+            engine.connect() as connection,
+        ):
+            ddl_conn = connection.execution_options(isolation_level="AUTOCOMMIT")
+
+            drop_airflow_models(ddl_conn)
+            drop_airflow_moved_tables(ddl_conn)
+            log.info("Dropped all Airflow tables")
+
+            # Use a raw Session to avoid scoped-session issues
+            work_session = SASession(bind=ddl_conn)
+            try:
                 external_db_manager = RunDBManager()
-                external_db_manager.upgradedb(session)
+                external_db_manager.drop_tables(work_session, ddl_conn)
+            finally:
+                work_session.close()
+    finally:
+        lock_session.close()
 
-        finally:
-            if val is None:
-                os.environ.pop("AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE")
-            else:
-                os.environ["AIRFLOW__DATABASE__SQL_ALCHEMY_MAX_SIZE"] = val
-            settings.reconfigure_orm()
-    add_default_pool_if_not_exists(session=session)
-    synchronize_log_template(session=session)
 
+def _resetdb_default(session: Session) -> None:
+    """Drop all Airflow tables for PostgreSQL/SQLite."""
+    connection = settings.get_engine().connect()
+    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS), connection.begin():
+        drop_airflow_models(connection)
+        drop_airflow_moved_tables(connection)
+        log.info("Dropped all Airflow tables")
+        external_db_manager = RunDBManager()
+        external_db_manager.drop_tables(session, connection)
 
 
 @provide_session
@@ -1166,17 +1299,19 @@ def resetdb(session: Session = NEW_SESSION, skip_init: bool = False):
 
     import_all_models()
 
-    connection = settings.engine.connect()
-
-    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS), connection.begin():
-        drop_airflow_models(connection)
-        drop_airflow_moved_tables(connection)
-        log.info("Dropped all Airflow tables")
-        external_db_manager = RunDBManager()
-        external_db_manager.drop_tables(session, connection)
+    if get_dialect_name(session) == "mysql":
+        _resetdb_mysql(session)
+    else:
+        _resetdb_default(session)
 
     if not skip_init:
-        initdb(session=session)
+        # Create a fresh non-scoped session for initdb since the original was closed (MySQL)
+        # or used (Postgres). Using scoped=False ensures we get a new session even if the
+        # scoped session registry still holds the old, closed session.
+        from airflow.utils.session import create_session
+
+        with create_session(scoped=False) as new_session:
+            initdb(session=new_session)
 
 
 @provide_session
@@ -1206,18 +1341,32 @@ def downgrade(*, to_revision, from_revision=None, show_sql_only=False, session:
     log.info("Attempting downgrade to revision %s", to_revision)
     config = _get_alembic_config()
 
+    # If downgrading to a version below 3.0.0, we need to handle the FAB provider
     if _revision_greater(config, _REVISION_HEADS_MAP["2.10.3"], to_revision):
         _handle_fab_downgrade(session=session)
-    with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
+
+    # Determine which session to use for the migration operations
+    if get_dialect_name(session) == "mysql":
+        # MySQL: Commit session to release metadata locks before Alembic DDL
+        session_cm: contextlib.AbstractContextManager[Session] = _mysql_lock_session_for_migration(session)
+    else:
+        # PostgreSQL / SQLite: Use original session
+        session_cm = contextlib.nullcontext(session)
+
+    with (
+        session_cm as work_session,
+        create_global_lock(session=work_session, lock=DBLocks.MIGRATIONS),
+    ):
         if show_sql_only:
             log.warning("Generating sql scripts for manual migration.")
             if not from_revision:
-                from_revision = _get_current_revision(session)
+                from_revision = _get_current_revision(work_session)
             revision_range = f"{from_revision}:{to_revision}"
             _offline_migration(command.downgrade, config=config, revision=revision_range)
         else:
-            log.info("Applying downgrade migrations to Airflow database.")
+            dialect_label = " (MySQL)" if get_dialect_name(work_session) == "mysql" else ""
+            log.info("Applying downgrade migrations to Airflow database%s.", dialect_label)
             command.downgrade(config, revision=to_revision, sql=show_sql_only)
@@ -1259,14 +1408,15 @@ def _handle_fab_downgrade(*, session: Session) -> None:
             fab_version,
         )
         return
-    connection = settings.get_engine().connect()
-    insp = inspect(connection)
-    if not fab_version and insp.has_table("ab_user"):
-        log.info(
-            "FAB migration version not found, but FAB tables exist. "
-            "FAB provider is not required for downgrade.",
-        )
-        return
+
+    with settings.get_engine().connect() as connection:
+        insp = inspect(connection)
+        if not fab_version and insp.has_table("ab_user"):
+            log.info(
+                "FAB migration version not found, but FAB tables exist. "
+                "FAB provider is not required for downgrade.",
+            )
+            return
 
     # FAB db version is different or not found - require the FAB provider
     try:
@@ -1276,6 +1426,12 @@ def _handle_fab_downgrade(*, session: Session) -> None:
             "Import error occurred while importing FABDBManager. The apache-airflow-provider-fab package must be installed before we can "
             "downgrade to <3.0.0."
         )
+
+    # For MySQL: commit session to release metadata locks before FABDBManager operations
+    if get_dialect_name(session) == "mysql":
+        log.info("MySQL: Committing session to release metadata locks before FAB operations")
+        session.commit()
+
     dbm = FABDBManager(session)
     if hasattr(dbm, "reset_to_2_x"):
         dbm.reset_to_2_x()
@@ -1342,58 +1498,107 @@ def __str__(self):
 
 
 @contextlib.contextmanager
-def create_global_lock(
-    session: Session,
-    lock: DBLocks,
-    lock_timeout: int = 1800,
+def _create_global_lock_mysql(lock: DBLocks, lock_timeout: int) -> Generator[None, None, None]:
+    """
+    Create a global advisory lock for MySQL.
+
+    Uses a dedicated AUTOCOMMIT connection because:
+    - GET_LOCK is session-level, not transaction-level
+    - DDL operations cause implicit commits that would break transaction wrappers
+    """
+    lock_conn = settings.get_engine().connect()
+    try:
+        lock_conn = lock_conn.execution_options(isolation_level="AUTOCOMMIT")
+
+        # GET_LOCK returns: 1 = acquired, 0 = timeout, NULL = error
+        lock_result = lock_conn.execute(
+            text("SELECT GET_LOCK(:lock_name, :timeout)"),
+            {"lock_name": str(lock), "timeout": lock_timeout},
+        ).scalar()
+
+        if lock_result != 1:
+            raise RuntimeError(
+                f"Could not acquire MySQL advisory lock '{lock}'. "
+                f"Result: {lock_result}. Another process may be holding the lock."
+            )
+
+        try:
+            yield
+        finally:
+            lock_conn.execute(text("SELECT RELEASE_LOCK(:lock_name)"), {"lock_name": str(lock)})
+    finally:
+        lock_conn.close()
+
+
+@contextlib.contextmanager
+def _create_global_lock_postgresql(
+    session: Session, lock: DBLocks, lock_timeout: int
 ) -> Generator[None, None, None]:
-    """Contextmanager that will create and teardown a global db lock."""
+    """Create a global advisory lock for PostgreSQL using transactional advisory locks."""
     bind = session.get_bind()
     if hasattr(bind, "connect"):
         conn = bind.connect()
+        owns_connection = True
     else:
         conn = bind
-    dialect_name = get_dialect_name(session)
+        owns_connection = False
+
     try:
-        if dialect_name == "postgresql":
-            if _USE_PSYCOPG3:
-                # psycopg3 doesn't support parameters for `SET`. Use `set_config` instead.
-                # The timeout value must be passed as a string of milliseconds.
-                conn.execute(
-                    text("SELECT set_config('lock_timeout', :timeout, false)"),
-                    {"timeout": str(lock_timeout)},
-                )
-                conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
-            else:
-                conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
-                conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
-        elif (
-            dialect_name == "mysql"
-            and conn.dialect.server_version_info
-            and conn.dialect.server_version_info >= (5, 6)
-        ):
-            conn.execute(text("SELECT GET_LOCK(:id, :timeout)"), {"id": str(lock), "timeout": lock_timeout})
+        if _USE_PSYCOPG3:
+            conn.execute(
+                text("SELECT set_config('lock_timeout', :timeout, false)"),
+                {"timeout": str(lock_timeout)},
+            )
+            conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
+        else:
+            conn.execute(text("SET LOCK_TIMEOUT to :timeout"), {"timeout": lock_timeout})
+            conn.execute(text("SELECT pg_advisory_lock(:id)"), {"id": lock.value})
         yield
     finally:
-        if dialect_name == "postgresql":
-            if _USE_PSYCOPG3:
-                # Use set_config() to reset the timeout to its default (0 = off/wait forever).
-                conn.execute(text("SELECT set_config('lock_timeout', '0', false)"))
-            else:
-                conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
-            result = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
-            if result is None:
-                raise RuntimeError("Error releasing DB lock!")
-            (unlocked,) = result
-            if not unlocked:
-                raise RuntimeError("Error releasing DB lock!")
-        elif (
-            dialect_name == "mysql"
-            and conn.dialect.server_version_info
-            and conn.dialect.server_version_info >= (5, 6)
-        ):
-            conn.execute(text("select RELEASE_LOCK(:id)"), {"id": str(lock)})
+        if _USE_PSYCOPG3:
+            conn.execute(text("SELECT set_config('lock_timeout', '0', false)"))
+        else:
+            conn.execute(text("SET LOCK_TIMEOUT TO DEFAULT"))
+
+        result = conn.execute(text("SELECT pg_advisory_unlock(:id)"), {"id": lock.value}).fetchone()
+
+        if result is None:
+            raise RuntimeError("Error releasing DB lock!")
+        (unlocked,) = result
+        if not unlocked:
+            raise RuntimeError("Error releasing DB lock!")
+
+        if owns_connection:
+            conn.close()
+
+
+@contextlib.contextmanager
+def create_global_lock(
+    session: Session,
+    lock: DBLocks,
+    lock_timeout: int = 1800,
+) -> Generator[None, None, None]:
+    """
+    Contextmanager that will create and teardown a global db lock.
+
+    For MySQL, uses a dedicated AUTOCOMMIT connection because:
+    - GET_LOCK is session-level, not transaction-level
+    - DDL operations cause implicit commits that would break transaction wrappers
+
+    For PostgreSQL, uses transactional advisory locks as before.
+    """
+    dialect_name = get_dialect_name(session)
+
+    if dialect_name == "mysql":
+        with _create_global_lock_mysql(lock, lock_timeout):
+            yield
+    elif dialect_name == "postgresql":
+        with _create_global_lock_postgresql(session, lock, lock_timeout):
+            yield
+    else:
+        # SQLite and others: no advisory lock support
+        yield
 
 
 def compare_type(context, inspected_column, metadata_column, inspected_type, metadata_type):
diff --git a/airflow-core/src/airflow/utils/db_manager.py b/airflow-core/src/airflow/utils/db_manager.py
index ca3e181c1dbf8..93959b5bf68bf 100644
--- a/airflow-core/src/airflow/utils/db_manager.py
+++ b/airflow-core/src/airflow/utils/db_manager.py
@@ -27,6 +27,7 @@
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.sqlalchemy import get_dialect_name
 
 if TYPE_CHECKING:
     from alembic.script import ScriptDirectory
@@ -48,6 +49,21 @@ def __init__(self, session):
         super().__init__()
         self.session = session
 
+    def _release_metadata_locks_if_needed(self) -> None:
+        """
+        Release MySQL metadata locks by committing the session.
+
+        MySQL requires metadata locks to be released before DDL operations.
+        This is done by committing the current transaction.
+        This method is a no-op for non-MySQL databases.
+        """
+        if get_dialect_name(self.session) != "mysql":
+            return
+
+        self.log.debug("MySQL: Releasing metadata locks for DDL operations")
+        self.session.commit()
+        self.log.debug("MySQL: Session committed, metadata locks released")
+
     def get_alembic_config(self):
         from alembic.config import Config
@@ -90,6 +106,7 @@ def check_migration(self):
     def create_db_from_orm(self):
         """Create database from ORM."""
         self.log.info("Creating %s tables from the ORM", self.__class__.__name__)
+        self._release_metadata_locks_if_needed()
         engine = self.session.get_bind().engine
         self.metadata.create_all(engine)
         config = self.get_alembic_config()
@@ -107,6 +124,8 @@ def drop_tables(self, connection):
     def resetdb(self, skip_init=False):
         from airflow.utils.db import DBLocks, create_global_lock
 
+        self._release_metadata_locks_if_needed()
+
         connection = settings.engine.connect()
 
         with create_global_lock(self.session, lock=DBLocks.MIGRATIONS), connection.begin():
@@ -117,6 +136,7 @@ def resetdb(self, skip_init=False):
 
     def initdb(self):
         """Initialize the database."""
+        self._release_metadata_locks_if_needed()
         db_exists = self.get_current_revision()
         if db_exists:
             self.upgradedb()
@@ -127,6 +147,8 @@ def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False):
         """Upgrade the database."""
         self.log.info("Upgrading the %s database", self.__class__.__name__)
 
+        self._release_metadata_locks_if_needed()
+
         config = self.get_alembic_config()
         command.upgrade(config, revision=to_revision or "heads", sql=show_sql_only)
         self.log.info("Migrated the %s database", self.__class__.__name__)
diff --git a/providers/fab/docs/index.rst b/providers/fab/docs/index.rst
index d22c3877b5cda..0344a9f01c4ed 100644
--- a/providers/fab/docs/index.rst
+++ b/providers/fab/docs/index.rst
@@ -116,7 +116,6 @@ PIP package                  Version required
 ``flask-session``            ``>=0.8.0; python_version < "3.13"``
 ``msgpack``                  ``>=1.0.0; python_version < "3.13"``
 ``flask-sqlalchemy``         ``>=3.0.5; python_version < "3.13"``
-``sqlalchemy``               ``>=1.4.36,<2; python_version < "3.13"``
 ``flask-wtf``                ``>=1.1.0; python_version < "3.13"``
 ``connexion[flask]``         ``>=2.14.2,<3.0; python_version < "3.13"``
 ``jmespath``                 ``>=0.7.0; python_version < "3.13"``
diff --git a/providers/fab/pyproject.toml b/providers/fab/pyproject.toml
index 3944d207069ab..1e5b5c250b553 100644
--- a/providers/fab/pyproject.toml
+++ b/providers/fab/pyproject.toml
@@ -73,12 +73,9 @@ dependencies = [
     # In particular, make sure any breaking changes, for example any new methods, are accounted for.
     "flask-appbuilder==5.0.1; python_version < '3.13'",
     "flask-login>=0.6.2; python_version < '3.13'",
-    # Flask-Session 0.6 add new arguments into the SqlAlchemySessionInterface constructor as well as
-    # all parameters now are mandatory which make AirflowDatabaseSessionInterface incompatible with this version.
"flask-session>=0.8.0; python_version < '3.13'", "msgpack>=1.0.0; python_version < '3.13'", "flask-sqlalchemy>=3.0.5; python_version < '3.13'", - "sqlalchemy>=1.4.36,<2; python_version < '3.13'", "flask-wtf>=1.1.0; python_version < '3.13'", "connexion[flask]>=2.14.2,<3.0; python_version < '3.13'", "jmespath>=0.7.0; python_version < '3.13'", diff --git a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py index fee62bb56704a..ee1c3a2291d4f 100644 --- a/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py +++ b/providers/fab/tests/unit/fab/auth_manager/test_fab_auth_manager.py @@ -888,20 +888,26 @@ def test_get_db_manager(self, auth_manager): @mock.patch("airflow.utils.db.drop_airflow_models") @mock.patch("airflow.utils.db.drop_airflow_moved_tables") @mock.patch("airflow.utils.db.initdb") -@mock.patch("airflow.settings.engine.connect") +@mock.patch("airflow.settings.engine") def test_resetdb( - mock_connect, + mock_engine, mock_init, mock_drop_moved, mock_drop_airflow, mock_fabdb_manager, skip_init, ): + # Mock as non-MySQL to use the simpler PostgreSQL/SQLite path + mock_engine.dialect.name = "postgresql" + mock_connect = mock_engine.connect.return_value + session_mock = MagicMock() resetdb(session_mock, skip_init=skip_init) - mock_drop_airflow.assert_called_once_with(mock_connect.return_value) - mock_drop_moved.assert_called_once_with(mock_connect.return_value) + + # In the non-MySQL path, drop functions are called with the raw connection + mock_drop_airflow.assert_called_once_with(mock_connect) + mock_drop_moved.assert_called_once_with(mock_connect) if skip_init: mock_init.assert_not_called() else: - mock_init.assert_called_once_with(session=session_mock) + mock_init.assert_called_once()