From 986bb255e232cc3855b4dfdb98a20b057a55960d Mon Sep 17 00:00:00 2001 From: luolingchun Date: Tue, 17 Oct 2023 17:30:28 +0800 Subject: [PATCH] Support the `-x` flag in `env.py` --- .../templates/aioflask-multidb/env.py | 23 +++++++++++++++++-- src/flask_migrate/templates/aioflask/env.py | 17 ++++++++++++++ .../templates/flask-multidb/env.py | 23 +++++++++++++++++-- src/flask_migrate/templates/flask/env.py | 17 ++++++++++++++ 4 files changed, 76 insertions(+), 4 deletions(-) diff --git a/src/flask_migrate/templates/aioflask-multidb/env.py b/src/flask_migrate/templates/aioflask-multidb/env.py index f14bd6c..2d061eb 100644 --- a/src/flask_migrate/templates/aioflask-multidb/env.py +++ b/src/flask_migrate/templates/aioflask-multidb/env.py @@ -2,7 +2,7 @@ import logging from logging.config import fileConfig -from sqlalchemy import MetaData +from sqlalchemy import MetaData, text from flask import current_app from alembic import context @@ -135,8 +135,27 @@ def process_revision_directives(context, revision, directives): if conf_args.get("process_revision_directives") is None: conf_args["process_revision_directives"] = process_revision_directives + current_schema = context.get_x_argument(as_dictionary=True).get("schema") + for name, rec in engines.items(): - rec['sync_connection'] = conn = rec['connection']._sync_connection() + connection = rec['connection'] + if current_schema and connection.dialect.name == "postgresql": + # set search path on the connection, which ensures that + # PostgreSQL will emit all CREATE / ALTER / DROP statements + # in terms of this schema by default + connection.execute( + text('set search_path to "%s"' % current_schema) + ) + # in SQLAlchemy v2+ the search path change + # needs to be committed + connection.commit() + + # make use of non-supported SQLAlchemy attribute to ensure + # the dialect reflects tables in terms of + # the current schema name + connection.dialect.default_schema_name = current_schema + + rec['sync_connection'] = conn = connection._sync_connection() if USE_TWOPHASE: rec['transaction'] = conn.begin_twophase() else: diff --git a/src/flask_migrate/templates/aioflask/env.py b/src/flask_migrate/templates/aioflask/env.py index 3a1ece5..a6b2773 100644 --- a/src/flask_migrate/templates/aioflask/env.py +++ b/src/flask_migrate/templates/aioflask/env.py @@ -5,6 +5,7 @@ from flask import current_app from alembic import context +from sqlalchemy import text # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -88,6 +89,22 @@ def process_revision_directives(context, revision, directives): if conf_args.get("process_revision_directives") is None: conf_args["process_revision_directives"] = process_revision_directives + current_schema = context.get_x_argument(as_dictionary=True).get("schema") + + if current_schema and connection.dialect.name == "postgresql": + # set search path on the connection, which ensures that + # PostgreSQL will emit all CREATE / ALTER / DROP statements + # in terms of this schema by default + connection.execute( + text('set search_path to "%s"' % current_schema) + ) + # in SQLAlchemy v2+ the search path change needs to be committed + connection.commit() + + # make use of non-supported SQLAlchemy attribute to ensure + # the dialect reflects tables in terms of the current schema name + connection.dialect.default_schema_name = current_schema + context.configure( connection=connection, target_metadata=get_metadata(), diff --git a/src/flask_migrate/templates/flask-multidb/env.py b/src/flask_migrate/templates/flask-multidb/env.py index 31b8e72..c905b23 100644 --- a/src/flask_migrate/templates/flask-multidb/env.py +++ b/src/flask_migrate/templates/flask-multidb/env.py @@ -1,7 +1,7 @@ import logging from logging.config import fileConfig -from sqlalchemy import MetaData +from sqlalchemy import MetaData, text from flask import current_app from alembic import context @@ -158,11 +158,30 @@ def process_revision_directives(context, revision, directives): else: rec['transaction'] = conn.begin() + current_schema = context.get_x_argument(as_dictionary=True).get("schema") + try: for name, rec in engines.items(): + connection = rec['connection'] + if current_schema and connection.dialect.name == "postgresql": + # set search path on the connection, which ensures that + # PostgreSQL will emit all CREATE / ALTER / DROP statements + # in terms of this schema by default + connection.execute( + text('set search_path to "%s"' % current_schema) + ) + # in SQLAlchemy v2+ the search path change + # needs to be committed + connection.commit() + + # make use of non-supported SQLAlchemy attribute to ensure + # the dialect reflects tables in terms of + # the current schema name + connection.dialect.default_schema_name = current_schema + logger.info("Migrating database %s" % (name or '')) context.configure( - connection=rec['connection'], + connection=connection, upgrade_token="%s_upgrades" % name, downgrade_token="%s_downgrades" % name, target_metadata=get_metadata(name), diff --git a/src/flask_migrate/templates/flask/env.py b/src/flask_migrate/templates/flask/env.py index 4c97092..265174c 100644 --- a/src/flask_migrate/templates/flask/env.py +++ b/src/flask_migrate/templates/flask/env.py @@ -4,6 +4,7 @@ from flask import current_app from alembic import context +from sqlalchemy import text # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -96,7 +97,23 @@ def process_revision_directives(context, revision, directives): connectable = get_engine() + current_schema = context.get_x_argument(as_dictionary=True).get("schema") + with connectable.connect() as connection: + if current_schema and connection.dialect.name == "postgresql": + # set search path on the connection, which ensures that + # PostgreSQL will emit all CREATE / ALTER / DROP statements + # in terms of this schema by default + connection.execute( + text('set search_path to "%s"' % current_schema) + ) + # in SQLAlchemy v2+ the search path change needs to be committed + connection.commit() + + # make use of non-supported SQLAlchemy attribute to ensure + # the dialect reflects tables in terms of the current schema name + connection.dialect.default_schema_name = current_schema + context.configure( connection=connection, target_metadata=get_metadata(),