Skip to content

Commit

Permalink
Merge pull request #98 from Pogchamp-company/develop
Browse files Browse the repository at this point in the history
Version 1.6.0
  • Loading branch information
RustyGuard authored Jan 29, 2025
2 parents 87e969b + de2ab82 commit 43fa838
Show file tree
Hide file tree
Showing 13 changed files with 264 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/black.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: psf/black@stable
- uses: psf/black@25.1.0
2 changes: 1 addition & 1 deletion .github/workflows/test_on_push.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ on:

jobs:
run_tests:
runs-on: [ubuntu-latest]
runs-on: [ubuntu-22.04]
strategy:
matrix:
sqlalchemy: [ "1.4", "2.0" ]
Expand Down
9 changes: 8 additions & 1 deletion alembic_postgresql_enum/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
from .compare_dispatch import compare_enums
from .compare_dispatch import compare_enums as _
from .get_enum_data import ColumnType, TableReference
from .configuration import set_configuration, Config

__all__ = (
"ColumnType",
"TableReference",
"set_configuration",
"Config",
)
6 changes: 5 additions & 1 deletion alembic_postgresql_enum/get_enum_data/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,15 @@ def is_column_type_import_needed(self):
@property
def table_name_with_schema(self):
if self.table_schema:
prefix = f"{self.table_schema}."
prefix = f'"{self.table_schema}".'
else:
prefix = ""
return f'{prefix}"{self.table_name}"'

@property
def escaped_column_name(self):
return f'"{self.column_name}"'


EnumNamesToValues = Dict[str, Tuple[str, ...]]
EnumNamesToTableReferences = Dict[str, FrozenSet[TableReference]]
Expand Down
15 changes: 7 additions & 8 deletions alembic_postgresql_enum/operations/sync_enum_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,19 +76,19 @@ def _set_enum_values(
affected_columns: List[TableReference],
enum_values_to_rename: List[Tuple[str, str]],
):
enum_type_name = f"{enum_schema}.{enum_name}"
enum_type_name = f'"{enum_schema}"."{enum_name}"'
temporary_enum_name = f"{enum_name}_old"

rename_type(connection, enum_schema, enum_name, temporary_enum_name)
create_type(connection, enum_schema, enum_name, new_values)
rename_type(connection, enum_type_name, temporary_enum_name)
create_type(connection, enum_type_name, new_values)

create_comparison_operators(connection, enum_schema, enum_name, temporary_enum_name, enum_values_to_rename)

for table_reference in affected_columns:
column_default = table_reference.existing_server_default

if column_default is not None:
drop_default(connection, table_reference.table_name_with_schema, table_reference.column_name)
drop_default(connection, table_reference)

try:
cast_old_enum_type_to_new(connection, table_reference, enum_type_name, enum_values_to_rename)
Expand All @@ -104,12 +104,11 @@ def _set_enum_values(
enum_schema, column_default, enum_name, enum_values_to_rename
)

set_default(
connection, table_reference.table_name_with_schema, table_reference.column_name, column_default
)
set_default(connection, table_reference, column_default)

drop_comparison_operators(connection, enum_schema, enum_name, temporary_enum_name)
drop_type(connection, enum_schema, temporary_enum_name)
temporary_enum_type_name = f'"{enum_schema}"."{temporary_enum_name}"'
drop_type(connection, temporary_enum_type_name)

@classmethod
def sync_enum_values(
Expand Down
20 changes: 12 additions & 8 deletions alembic_postgresql_enum/sql_commands/column_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import sqlalchemy

from alembic_postgresql_enum.get_enum_data.types import TableReference

if TYPE_CHECKING:
from sqlalchemy.engine import Connection

Expand All @@ -26,30 +28,32 @@ def get_column_default(
return default_value


def drop_default(connection: "Connection", table_name_with_schema: str, column_name: str):
def drop_default(
connection: "Connection",
table_reference: TableReference,
):
connection.execute(
sqlalchemy.text(
f"""ALTER TABLE {table_name_with_schema}
ALTER COLUMN {column_name} DROP DEFAULT"""
f"""ALTER TABLE {table_reference.table_name_with_schema}
ALTER COLUMN {table_reference.escaped_column_name} DROP DEFAULT"""
)
)


def set_default(
connection: "Connection",
table_name_with_schema: str,
column_name: str,
table_reference: TableReference,
default_value: str,
):
connection.execute(
sqlalchemy.text(
f"""ALTER TABLE {table_name_with_schema}
ALTER COLUMN {column_name} SET DEFAULT {default_value}"""
f"""ALTER TABLE {table_reference.table_name_with_schema}
ALTER COLUMN {table_reference.escaped_column_name} SET DEFAULT {default_value}"""
)
)


def rename_default_if_required(
def rename_default_if_required( # todo smells like teen shit
schema: str,
default_value: str,
enum_name: str,
Expand Down
28 changes: 13 additions & 15 deletions alembic_postgresql_enum/sql_commands/enum_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,15 @@ def cast_old_array_enum_type_to_new(
enum_type_name: str,
enum_values_to_rename: List[Tuple[str, str]],
):
cast_clause = f"{table_reference.column_name}::text[]"
cast_clause = f"{table_reference.escaped_column_name}::text[]"

for old_value, new_value in enum_values_to_rename:
cast_clause = f"""array_replace({cast_clause}, '{old_value}', '{new_value}')"""

connection.execute(
sqlalchemy.text(
f"""ALTER TABLE {table_reference.table_name_with_schema}
ALTER COLUMN {table_reference.column_name} TYPE {enum_type_name}[]
ALTER COLUMN {table_reference.escaped_column_name} TYPE {enum_type_name}[]
USING {cast_clause}::{enum_type_name}[]
"""
)
Expand All @@ -43,13 +43,13 @@ def cast_old_enum_type_to_new(
connection.execute(
sqlalchemy.text(
f"""ALTER TABLE {table_reference.table_name_with_schema}
ALTER COLUMN {table_reference.column_name} TYPE {enum_type_name}
ALTER COLUMN {table_reference.escaped_column_name} TYPE {enum_type_name}
USING CASE
{' '.join(
f"WHEN {table_reference.column_name}::text = '{old_value}' THEN '{new_value}'::{enum_type_name}"
f"WHEN {table_reference.escaped_column_name}::text = '{old_value}' THEN '{new_value}'::{enum_type_name}"
for old_value, new_value in enum_values_to_rename)}
ELSE {table_reference.column_name}::text::{enum_type_name}
ELSE {table_reference.escaped_column_name}::text::{enum_type_name}
END
"""
)
Expand All @@ -58,26 +58,24 @@ def cast_old_enum_type_to_new(
connection.execute(
sqlalchemy.text(
f"""ALTER TABLE {table_reference.table_name_with_schema}
ALTER COLUMN {table_reference.column_name} TYPE {enum_type_name}
USING {table_reference.column_name}::text::{enum_type_name}
ALTER COLUMN {table_reference.escaped_column_name} TYPE {enum_type_name}
USING {table_reference.escaped_column_name}::text::{enum_type_name}
"""
)
)


def drop_type(connection: "Connection", schema: str, type_name: str):
connection.execute(sqlalchemy.text(f"""DROP TYPE {schema}.{type_name}"""))
def drop_type(connection: "Connection", enum_type_name: str):
connection.execute(sqlalchemy.text(f"""DROP TYPE {enum_type_name}"""))


def rename_type(connection: "Connection", schema: str, type_name: str, new_type_name: str):
connection.execute(sqlalchemy.text(f"""ALTER TYPE {schema}.{type_name} RENAME TO {new_type_name}"""))
def rename_type(connection: "Connection", enum_type_name: str, new_type_name: str):
connection.execute(sqlalchemy.text(f"""ALTER TYPE {enum_type_name} RENAME TO {new_type_name}"""))


def create_type(connection: "Connection", schema: str, type_name: str, enum_values: List[str]):
def create_type(connection: "Connection", enum_type_name: str, enum_values: List[str]):
connection.execute(
sqlalchemy.text(
f"""CREATE TYPE {schema}.{type_name} AS ENUM({', '.join(f"'{value}'" for value in enum_values)})"""
)
sqlalchemy.text(f"""CREATE TYPE {enum_type_name} AS ENUM({', '.join(f"'{value}'" for value in enum_values)})""")
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "alembic-postgresql-enum"
version = "1.5.0"
version = "1.6.0"
description = "Alembic autogenerate support for creation, alteration and deletion of enums"
authors = ["RustyGuard"]
license = "MIT"
Expand Down
4 changes: 3 additions & 1 deletion tests/fixtures/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sqlalchemy
from sqlalchemy import create_engine

from tests.schemas import ANOTHER_SCHEMA_NAME, DEFAULT_SCHEMA
from tests.schemas import ANOTHER_SCHEMA_NAME, DEFAULT_SCHEMA, KEYWORD_SCHEMA_NAME

try:
import dotenv
Expand All @@ -27,6 +27,8 @@ def connection() -> Generator:
CREATE SCHEMA {DEFAULT_SCHEMA};
DROP SCHEMA IF EXISTS {ANOTHER_SCHEMA_NAME} CASCADE;
CREATE SCHEMA {ANOTHER_SCHEMA_NAME};
DROP SCHEMA IF EXISTS "{KEYWORD_SCHEMA_NAME}" CASCADE;
CREATE SCHEMA "{KEYWORD_SCHEMA_NAME}";
"""
)
)
Expand Down
1 change: 1 addition & 0 deletions tests/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
CAR_COLORS_ENUM_NAME = "car_color"

ANOTHER_SCHEMA_NAME = "another"
KEYWORD_SCHEMA_NAME = "default"


def get_schema_with_enum_variants(variants: List[str]) -> MetaData:
Expand Down
70 changes: 70 additions & 0 deletions tests/sync_enum_values/test_exotic_column_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Optional

from sqlalchemy import MetaData, Table, Column, Integer
from sqlalchemy.dialects import postgresql

from tests.schemas import USER_TABLE_NAME, USER_STATUS_ENUM_NAME
from tests.base.run_migration_test_abc import CompareAndRunTestCase


class TestExoticColumnNameRender(CompareAndRunTestCase):
"""https://github.com/Pogchamp-company/alembic-postgresql-enum/issues/95"""

old_enum_variants = ["active", "passive"]
new_enum_variants = old_enum_variants + ["banned"]

def get_database_schema(self) -> MetaData:
schema = MetaData()

Table(
USER_TABLE_NAME,
schema,
Column("id", Integer, primary_key=True),
Column(
"case",
postgresql.ENUM(*self.old_enum_variants, name=USER_STATUS_ENUM_NAME),
),
)

return schema

def get_target_schema(self) -> MetaData:
schema = MetaData()

Table(
USER_TABLE_NAME,
schema,
Column("id", Integer, primary_key=True),
Column(
"case",
postgresql.ENUM(*self.new_enum_variants, name=USER_STATUS_ENUM_NAME),
),
)

return schema

def get_expected_upgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values(
enum_schema='public',
enum_name='user_status',
new_values=['active', 'passive', 'banned'],
affected_columns=[TableReference(table_schema='public', table_name='users', column_name='case')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###
"""

def get_expected_downgrade(self) -> Optional[str]:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values(
enum_schema='public',
enum_name='user_status',
new_values=['active', 'passive'],
affected_columns=[TableReference(table_schema='public', table_name='users', column_name='case')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###
"""
71 changes: 71 additions & 0 deletions tests/sync_enum_values/test_exotic_enum_name.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from typing import Optional

from sqlalchemy import MetaData, Table, Column, Integer
from sqlalchemy.dialects import postgresql

from tests.schemas import USER_STATUS_COLUMN_NAME
from tests.schemas import USER_TABLE_NAME
from tests.base.run_migration_test_abc import CompareAndRunTestCase


class TestExoticColumnNameRender(CompareAndRunTestCase):
"""Test for enum names that are keyword in postgres"""

old_enum_variants = ["active", "passive"]
new_enum_variants = old_enum_variants + ["banned"]

def get_database_schema(self) -> MetaData:
schema = MetaData()

Table(
USER_TABLE_NAME,
schema,
Column("id", Integer, primary_key=True),
Column(
USER_STATUS_COLUMN_NAME,
postgresql.ENUM(*self.old_enum_variants, name="type"),
),
)

return schema

def get_target_schema(self) -> MetaData:
schema = MetaData()

Table(
USER_TABLE_NAME,
schema,
Column("id", Integer, primary_key=True),
Column(
USER_STATUS_COLUMN_NAME,
postgresql.ENUM(*self.new_enum_variants, name="type"),
),
)

return schema

def get_expected_upgrade(self) -> str:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values(
enum_schema='public',
enum_name='type',
new_values=['active', 'passive', 'banned'],
affected_columns=[TableReference(table_schema='public', table_name='users', column_name='status')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###
"""

def get_expected_downgrade(self) -> Optional[str]:
return f"""
# ### commands auto generated by Alembic - please adjust! ###
op.sync_enum_values(
enum_schema='public',
enum_name='type',
new_values=['active', 'passive'],
affected_columns=[TableReference(table_schema='public', table_name='users', column_name='status')],
enum_values_to_rename=[],
)
# ### end Alembic commands ###
"""
Loading

0 comments on commit 43fa838

Please sign in to comment.