Skip to content

Commit

Permalink
Merge pull request #3 from JiCiT/master
Browse files Browse the repository at this point in the history
add user and password options; make more Windows friendly
  • Loading branch information
GammaGames authored Nov 4, 2021
2 parents d77c5d6 + fd0ec38 commit c8a5080
Showing 1 changed file with 38 additions and 13 deletions.
51 changes: 38 additions & 13 deletions reorder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python3

import os
import re
import subprocess
from inspect import cleandoc
Expand All @@ -9,20 +10,40 @@
import psycopg2
import psycopg2.extras

def find(file, paths):

def get_dump_sql(database: str, schema: str, table: str) -> str:
for path in paths:
for root, dirs, files in os.walk(path):
if file in files:
return os.path.join(root, file)


def get_pg_dump() -> str:
if os.name == "nt":
file = "pg_dump.exe"
paths = os.environ["PATH"].split(";")
paths.append("C:\\Program Files\\PostgreSQL")
else:
file = "pg_dump"
paths = os.environ["PATH"].split(":")
exe = find(file, paths)
return exe


def get_dump_sql(database: str, schema: str, table: str, user: str) -> str:
"""Get SQL that would be returned with `pg_dump`"""
result = subprocess.run(
["pg_dump", "--schema-only", f"--table={schema}.{table}", database],
[get_pg_dump(), f"--username={user}", "--schema-only", f"--table={schema}.{table}", database],
capture_output=True,
check=True,
text=True,
)
return result.stdout.decode()
return result.stdout


def get_columns(database: str, schema: str, table: str) -> Tuple[List[str], List[str]]:
def get_columns(database: str, schema: str, table: str, user: str) -> Tuple[List[str], List[str]]:
"""Get columns for a table"""
sql_text = get_dump_sql(database, schema, table)
sql_text = get_dump_sql(database, schema, table, user)
table_re = re.compile(
fr"(?P<pre>(?:\n|.)+CREATE TABLE {schema}.{table}\s+\(\n)(?P<rows>(?:\n|.)+?)(?P<post>\);(?:\n|.)+)"
)
Expand All @@ -42,10 +63,10 @@ def get_columns(database: str, schema: str, table: str) -> Tuple[List[str], List


def get_migration_sql(
database: str, schema: str, table: str, columns: List[str], extras: List[str]
database: str, schema: str, table: str, user:str, password: str, columns: List[str], extras: List[str]
) -> str:
"""Get SQL command to migrate a source table into the target table"""
sql_text = get_dump_sql(database, schema, table)
sql_text = get_dump_sql(database, schema, table, user)
table_re = re.compile(
fr"(?P<pre>(?:\n|.)+)(?P<table>CREATE TABLE {schema}\.{table}\s+\(\n(?:\n|.)+?\);)(?P<post>(?:\n|.)+)"
)
Expand All @@ -59,7 +80,7 @@ def get_migration_sql(
DROP CONSTRAINT {fk['constraint']};
"""
)
for fk in get_foreign_keys(database, schema, table)
for fk in get_foreign_keys(database, schema, table, user, password)
]
)
fk_enable = "\n".join(
Expand All @@ -71,7 +92,7 @@ def get_migration_sql(
REFERENCES {fk['schema']}.{fk['foreign_table']} ({fk['foreign_column']});
"""
)
for fk in get_foreign_keys(database, schema, table)
for fk in get_foreign_keys(database, schema, table, user, password)
]
)
extra_features = "\n".join(
Expand Down Expand Up @@ -103,9 +124,9 @@ def get_migration_sql(
)


def get_foreign_keys(database: str, schema: str, table: str) -> List[dict]:
def get_foreign_keys(database: str, schema: str, table: str, user: str, password: str) -> List[dict]:
"""Get foreign keys referencing a given table"""
with psycopg2.connect(database=database, user="postgres") as conn:
with psycopg2.connect(database=database, user=user, password=password) as conn:
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute(
"""
Expand Down Expand Up @@ -177,6 +198,8 @@ def printcols(cols: List[str], header: Optional[str] = None) -> None:
@click.option("--exclude", "-e", multiple=True, help="Exclude a column (can be used multiple times).")
@click.option("--database", "-d", help="The name of the database.")
@click.option("--schema", "-n", default="public", help="The schema of the target table.")
@click.option("--user", "-u", default="postgres", help="User name.")
@click.option("--password", "-p", default="", help="Password.")
@click.option("--migrate", "-m", is_flag=True, help="Output full migration sql.")
@click.option("--file", "-f", "output_file", type=click.File("w"), help="Write output into a file.")
@click.argument("table")
Expand All @@ -186,6 +209,8 @@ def main(
exclude,
database: str,
schema,
user: str,
password: str,
output_file,
table: str,
columns: Tuple[str],
Expand All @@ -198,14 +223,14 @@ def main(
and the last column will be placed at the end of the table. When entered as
"... col1 col2 col3" all three columns will be placed at the end of the table.
"""
cols, extras = get_columns(database, schema, table)
cols, extras = get_columns(database, schema, table, user)

if len(columns):
target_start, target_end = sort_input_columns(list(columns))
cols = reorder_columns(target_start, target_end, list(exclude), cols)

if migrate:
query = get_migration_sql(database, schema, table, cols, extras)
query = get_migration_sql(database, schema, table, user, password, cols, extras)

if output_file is not None:
output_file.write(query)
Expand Down

0 comments on commit c8a5080

Please sign in to comment.