diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 70e6955c39b6f..7055a1c76012f 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -729,34 +729,15 @@ def create_default_connections(session: Session = NEW_SESSION):
     )
 
 
-def _get_flask_db(sql_database_uri):
-    from flask import Flask
-    from flask_sqlalchemy import SQLAlchemy
-
-    from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface
-
-    flask_app = Flask(__name__)
-    flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
-    flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
-    db = SQLAlchemy(flask_app)
-    AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
-    return db
-
-
 def _create_db_from_orm(session):
     log.info("Creating Airflow database tables from the ORM")
     from alembic import command
 
     from airflow.models.base import Base
 
-    def _create_flask_session_tbl(sql_database_uri):
-        db = _get_flask_db(sql_database_uri)
-        db.create_all()
-
     with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
         engine = session.get_bind().engine
         Base.metadata.create_all(engine)
-        _create_flask_session_tbl(engine.url)
         # stamp the migration head
         config = _get_alembic_config()
         command.stamp(config, "head")
@@ -1254,8 +1235,6 @@ def drop_airflow_models(connection):
     from airflow.models.base import Base
 
     Base.metadata.drop_all(connection)
-    db = _get_flask_db(connection.engine.url)
-    db.drop_all()
     # alembic adds significant import time, so we import it lazily
     from alembic.migration import MigrationContext
 
diff --git a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py
index ce0efef55a1cd..5e2c53977458c 100644
--- a/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py
+++ b/providers/fab/src/airflow/providers/fab/auth_manager/models/db.py
@@ -31,6 +31,20 @@
 }
 
 
+def _get_flask_db(sql_database_uri):
+    from flask import Flask
+    from flask_sqlalchemy import SQLAlchemy
+
+    from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface
+
+    flask_app = Flask(__name__)
+    flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
+    flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
+    db = SQLAlchemy(flask_app)
+    AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
+    return db
+
+
 class FABDBManager(BaseDBManager):
     """Manages FAB database."""
 
@@ -40,6 +54,10 @@ class FABDBManager(BaseDBManager):
     alembic_file = (PACKAGE_DIR / "alembic.ini").as_posix()
     supports_table_dropping = True
 
+    def _create_db_from_orm(self):
+        super()._create_db_from_orm()
+        _get_flask_db(settings.SQL_ALCHEMY_CONN).create_all()
+
     def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False):
         """Upgrade the database."""
         if from_revision and not show_sql_only:
@@ -68,11 +86,6 @@ def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False):
             _offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}")
             return  # only running sql; our job is done
 
-        if not self.get_current_revision():
-            # New DB; initialize and exit
-            self.initdb()
-            return
-
         command.upgrade(config, revision=to_revision or "heads")
 
     def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
@@ -104,3 +117,7 @@ def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
         else:
             self.log.info("Applying FAB downgrade migrations.")
             command.downgrade(config, revision=to_revision, sql=show_sql_only)
+
+    def drop_tables(self, connection):
+        super().drop_tables(connection)
+        _get_flask_db(settings.SQL_ALCHEMY_CONN).drop_all()
diff --git a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
index f0920ebb1512b..50eaf9450e961 100644
--- a/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
+++ b/providers/fab/tests/unit/fab/auth_manager/models/test_db.py
@@ -110,10 +110,12 @@ def test_sqlite_offline_upgrade_raises_with_revision(self, mock_gcr, session):
 
     @mock.patch("airflow.utils.db_manager.inspect")
     @mock.patch.object(FABDBManager, "metadata")
-    def test_drop_tables(self, mock_metadata, mock_inspect, session):
+    @mock.patch("airflow.providers.fab.auth_manager.models.db._get_flask_db")
+    def test_drop_tables(self, mock__get_flask_db, mock_metadata, mock_inspect, session):
         manager = FABDBManager(session)
         connection = mock.MagicMock()
         manager.drop_tables(connection)
+        mock__get_flask_db.return_value.drop_all.assert_called_once_with()
         mock_metadata.drop_all.assert_called_once_with(connection)
 
     @pytest.mark.parametrize("skip_init", [True, False])
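
Editor's note (not part of the patch): the change moves the Flask "session" table lifecycle out of airflow/utils/db.py and behind FABDBManager's _create_db_from_orm and drop_tables overrides, each of which calls super() first and then creates or drops the session table via the relocated _get_flask_db helper. Below is a minimal, self-contained sketch of that override pattern. SessionStore, BaseManager, and FabLikeManager are hypothetical stand-ins for illustration only, not the real BaseDBManager or Flask-SQLAlchemy objects used in the diff.

class SessionStore:
    """Hypothetical stand-in for the Flask-SQLAlchemy handle returned by _get_flask_db()."""

    def __init__(self) -> None:
        self.tables: set[str] = set()

    def create_all(self) -> None:
        # In the real helper this creates the "session" table via Flask-SQLAlchemy.
        self.tables.add("session")

    def drop_all(self) -> None:
        self.tables.discard("session")


class BaseManager:
    """Hypothetical stand-in for BaseDBManager: owns the core metadata tables."""

    def __init__(self) -> None:
        self.core_tables_created = False

    def _create_db_from_orm(self) -> None:
        self.core_tables_created = True

    def drop_tables(self, connection) -> None:
        self.core_tables_created = False


class FabLikeManager(BaseManager):
    """Mirrors the override pattern in the diff: super() first, session table second."""

    def __init__(self, session_store: SessionStore) -> None:
        super().__init__()
        self.session_store = session_store

    def _create_db_from_orm(self) -> None:
        super()._create_db_from_orm()      # core tables, as before
        self.session_store.create_all()    # then the web session table

    def drop_tables(self, connection) -> None:
        super().drop_tables(connection)    # core tables, as before
        self.session_store.drop_all()      # then the web session table


if __name__ == "__main__":
    store = SessionStore()
    manager = FabLikeManager(store)
    manager._create_db_from_orm()
    assert manager.core_tables_created and "session" in store.tables
    manager.drop_tables(connection=None)
    assert not manager.core_tables_created and "session" not in store.tables

The sketch also shows why the unit test in the diff can assert on _get_flask_db's return value with a mock: the provider-side overrides are the only remaining call sites for the session-table create_all/drop_all.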