Skip to content

Commit

Permalink
Refactored DBMigration to accept DB Urls instead of DB keys
Browse files Browse the repository at this point in the history
  • Loading branch information
puehringer committed Mar 3, 2020
1 parent 15c21e7 commit 8b3f898
Showing 1 changed file with 45 additions and 34 deletions.
79 changes: 45 additions & 34 deletions tdp_core/dbmigration.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,31 +22,23 @@ class DBMigration(object):
DBMigration object stores the required arguments to execute commands using Alembic.
"""

def __init__(self, id: str, db_key: str, script_location: str, auto_upgrade: bool=False):
def __init__(self, id: str, db_url: str, script_location: str, auto_upgrade: bool=False):
"""
Initializes a new migration object and optionally carries out an upgrade.
:param str id: ID of the migration object
:param str db_key: Key of the engine (coming from tdp_core#db)
:param str db_url: DB connection url
:param str script_location: Location of the base directory (containing env.py and the versions directory)
:param bool auto_upgrade: True if the migration should automatically upgrade the database to head
"""
if not id or not db_url or not script_location:
raise ValueError('Empty id or db_url or script_location')

self.id: str = id
self.db_key: str = db_key
self.db_url: str = db_url
self.script_location: str = script_location
self.auto_upgrade: bool = auto_upgrade
self.custom_commands: Dict[str, str] = dict()

missing_fields = []
if not self.id:
missing_fields.append('id')
if not self.script_location:
missing_fields.append('scriptLocation')
if not self.db_key:
missing_fields.append('dbKey')

if len(missing_fields) > 0:
raise ValueError('No {} defined for DBMigration {} - is your configuration up to date?'.format(', '.join(missing_fields), self.id or '<UNKNOWN>'))

# Because we can't easily pass "-1" as npm argument, we add a custom command for that without the space
self.add_custom_command('downgrade-(\d+)', 'downgrade -{}')

Expand Down Expand Up @@ -76,20 +68,21 @@ def add_custom_command(self, pattern: str, target: str):
def remove_custom_command(self, origin: str):
self.custom_commands.pop(origin, None)

def get_custom_command(self, arguments: List[str]) -> List[str]:
def get_custom_command(self, arguments: List[str] = []) -> List[str]:
"""
Returns the rewritten command if it matches the pattern of a custom command.
:param List[str] arguments: Argument to rewrite.
"""
# Join the list with spaces
arguments = ' '.join(arguments)
# For all the command patterns we have ..
for key, value in self.custom_commands.items():
# .. check if we can match the command pattern with the given string
matched = re.match(f"{key}$", arguments)
if matched:
# If we have a match, call format with the captured groups and split by ' '
return value.format(*matched.groups()).split(' ')
if arguments:
# Join the list with spaces
arguments = ' '.join(arguments)
# For all the command patterns we have ..
for key, value in self.custom_commands.items():
# .. check if we can match the command pattern with the given string
matched = re.match(f"{key}$", arguments)
if matched:
# If we have a match, call format with the captured groups and split by ' '
return value.format(*matched.groups()).split(' ')
return None

def execute(self, arguments: List[str] = []) -> bool:
Expand All @@ -99,10 +92,6 @@ def execute(self, arguments: List[str] = []) -> bool:
Example usage: migration.execute(['upgrade', 'head']) upgrades to the database to head.
"""
# Check if engine exists
if self.db_key not in engines:
raise ValueError('No engine called {} found for DBMigration {} - aborting migration'.format(self.db_key, self.id))

# Rewrite command if possible
rewritten_arguments = self.get_custom_command(arguments)
if rewritten_arguments:
Expand All @@ -115,13 +104,10 @@ def execute(self, arguments: List[str] = []) -> bool:
# Parse the options (incl. validation)
options = cmd_parser.parser.parse_args(arguments)

# Retrieve engine
engine = engines.engine(self.db_key)

# Inject options in the configuration object
alembic_cfg.cmd_opts = options
alembic_cfg.set_main_option('script_location', self.script_location)
alembic_cfg.set_main_option('sqlalchemy.url', str(engine.url))
alembic_cfg.set_main_option('sqlalchemy.url', self.db_url)
alembic_cfg.set_main_option('migration_id', self.id)

# Run the command
Expand All @@ -137,6 +123,8 @@ class DBMigrationManager(object):
- configKey: Key of the configuration entry (i.e. <app_name>.migration)
- id: ID of the migration for logging purposes (passed to DBManager)
- dbKey: Key of the engine used for the migration (passed to DBManager)
- dbUrl: URL of the db connection used for the migration (passed to DBManager)
- Either dbKey or dbUrl is required, with dbUrl having precedence
- scriptLocation: Location of the alembic root folder (passed to DBManager)
- autoUpgrade: Flag which auto-upgrades to the latest revision (passed to DBManager)
Expand All @@ -159,12 +147,35 @@ def __init__(self, plugins: List[AExtensionDesc] = []):
# Priority of assignments: Configuration File -> Plugin Definition
id = config.get('id') or (p.id if hasattr(p, 'id') else None)
db_key = config.get('dbKey') or (p.dbKey if hasattr(p, 'dbKey') else None)
db_url = config.get('dbUrl') or (p.dbUrl if hasattr(p, 'dbUrl') else None)
script_location = config.get('scriptLocation') or (p.scriptLocation if hasattr(p, 'scriptLocation') else None)
auto_upgrade = config.get('autoUpgrade') if type(config.get('autoUpgrade')) == bool else \
(p.autoUpgrade if hasattr(p, 'autoUpgrade') and type(p.autoUpgrade) == bool else False)

# Validate the plugin description
missing_fields = []
if not id:
missing_fields.append('id')
if not script_location:
missing_fields.append('scriptLocation')
if not db_key and not db_url:
missing_fields.append('dbUrl or dbKey')

if len(missing_fields) > 0:
raise ValueError('No {} defined for DBMigration {} - is your configuration up to date?'.format(', '.join(missing_fields), id or '<UNKNOWN>'))

if db_key and db_url:
_log.info(f'Both dbKey and dbUrl defined for DBMigration {id} - falling back to dbUrl')
elif db_key:
# Check if engine exists
if db_key not in engines:
raise ValueError(f'No engine called {db_key} found for DBMigration {id} - is your configuration up to date?')

# Retrieve engine and store string as db url
db_url = str(engines.engine(db_key).url)

# Create new migration
migration = DBMigration(id, db_key, script_location, auto_upgrade)
migration = DBMigration(id, db_url, script_location, auto_upgrade)

# Store migration
self._migrations[migration.id] = migration
Expand All @@ -174,7 +185,7 @@ def __contains__(self, item):

def __getitem__(self, item):
if item not in self:
raise NotImplementedError('missing db migration: ' + item)
raise NotImplementedError('Missing DBMigration: ' + item)
return self._migrations[item]

def __len__(self):
Expand Down

0 comments on commit 8b3f898

Please sign in to comment.