Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,34 +729,15 @@ def create_default_connections(session: Session = NEW_SESSION):
)


def _get_flask_db(sql_database_uri):
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface

flask_app = Flask(__name__)
flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(flask_app)
AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
return db


def _create_db_from_orm(session):
log.info("Creating Airflow database tables from the ORM")
from alembic import command

from airflow.models.base import Base

def _create_flask_session_tbl(sql_database_uri):
db = _get_flask_db(sql_database_uri)
db.create_all()

with create_global_lock(session=session, lock=DBLocks.MIGRATIONS):
engine = session.get_bind().engine
Base.metadata.create_all(engine)
_create_flask_session_tbl(engine.url)
# stamp the migration head
config = _get_alembic_config()
command.stamp(config, "head")
Expand Down Expand Up @@ -1254,8 +1235,6 @@ def drop_airflow_models(connection):
from airflow.models.base import Base

Base.metadata.drop_all(connection)
db = _get_flask_db(connection.engine.url)
db.drop_all()
# alembic adds significant import time, so we import it lazily
from alembic.migration import MigrationContext

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@
}


def _get_flask_db(sql_database_uri):
from flask import Flask
from flask_sqlalchemy import SQLAlchemy

from airflow.providers.fab.www.session import AirflowDatabaseSessionInterface

flask_app = Flask(__name__)
flask_app.config["SQLALCHEMY_DATABASE_URI"] = sql_database_uri
flask_app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db = SQLAlchemy(flask_app)
AirflowDatabaseSessionInterface(app=flask_app, db=db, table="session", key_prefix="")
return db


class FABDBManager(BaseDBManager):
"""Manages FAB database."""

Expand All @@ -40,6 +54,10 @@ class FABDBManager(BaseDBManager):
alembic_file = (PACKAGE_DIR / "alembic.ini").as_posix()
supports_table_dropping = True

def _create_db_from_orm(self):
super()._create_db_from_orm()
_get_flask_db(settings.SQL_ALCHEMY_CONN).create_all()

def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False):
"""Upgrade the database."""
if from_revision and not show_sql_only:
Expand Down Expand Up @@ -68,11 +86,6 @@ def upgradedb(self, to_revision=None, from_revision=None, show_sql_only=False):
_offline_migration(command.upgrade, config, f"{from_revision}:{to_revision}")
return # only running sql; our job is done

if not self.get_current_revision():
# New DB; initialize and exit
self.initdb()
return

command.upgrade(config, revision=to_revision or "heads")

def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
Expand Down Expand Up @@ -104,3 +117,7 @@ def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
else:
self.log.info("Applying FAB downgrade migrations.")
command.downgrade(config, revision=to_revision, sql=show_sql_only)

def drop_tables(self, connection):
super().drop_tables(connection)
_get_flask_db(settings.SQL_ALCHEMY_CONN).drop_all()
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,12 @@ def test_sqlite_offline_upgrade_raises_with_revision(self, mock_gcr, session):

@mock.patch("airflow.utils.db_manager.inspect")
@mock.patch.object(FABDBManager, "metadata")
def test_drop_tables(self, mock_metadata, mock_inspect, session):
@mock.patch("airflow.providers.fab.auth_manager.models.db._get_flask_db")
def test_drop_tables(self, mock__get_flask_db, mock_metadata, mock_inspect, session):
manager = FABDBManager(session)
connection = mock.MagicMock()
manager.drop_tables(connection)
mock__get_flask_db.return_value.drop_all.assert_called_once_with()
mock_metadata.drop_all.assert_called_once_with(connection)

@pytest.mark.parametrize("skip_init", [True, False])
Expand Down