diff --git a/reorder.py b/reorder.py index fa01a04..e6732f4 100755 --- a/reorder.py +++ b/reorder.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 +import os import re import subprocess from inspect import cleandoc @@ -9,20 +10,40 @@ import psycopg2 import psycopg2.extras +def find(file, paths): -def get_dump_sql(database: str, schema: str, table: str) -> str: + for path in paths: + for root, dirs, files in os.walk(path): + if file in files: + return os.path.join(root, file) + + +def get_pg_dump() -> str: + if os.name == "nt": + file = "pg_dump.exe" + paths = os.environ["PATH"].split(";") + paths.append("C:\\Program Files\\PostgreSQL") + else: + file = "pg_dump" + paths = os.environ["PATH"].split(":") + exe = find(file, paths) + return exe + + +def get_dump_sql(database: str, schema: str, table: str, user: str) -> str: """Get SQL that would be returned with `pg_dump`""" result = subprocess.run( - ["pg_dump", "--schema-only", f"--table={schema}.{table}", database], + [get_pg_dump(), f"--username={user}", "--schema-only", f"--table={schema}.{table}", database], capture_output=True, check=True, + text=True, ) - return result.stdout.decode() + return result.stdout -def get_columns(database: str, schema: str, table: str) -> Tuple[List[str], List[str]]: +def get_columns(database: str, schema: str, table: str, user: str) -> Tuple[List[str], List[str]]: """Get columns for a table""" - sql_text = get_dump_sql(database, schema, table) + sql_text = get_dump_sql(database, schema, table, user) table_re = re.compile( fr"(?P
(?:\n|.)+CREATE TABLE {schema}.{table}\s+\(\n)(?P(?:\n|.)+?)(?P\);(?:\n|.)+)"
     )
@@ -42,10 +63,10 @@ def get_columns(database: str, schema: str, table: str) -> Tuple[List[str], List
 
 
 def get_migration_sql(
-    database: str, schema: str, table: str, columns: List[str], extras: List[str]
+    database: str, schema: str, table: str, user:str, password: str, columns: List[str], extras: List[str]
 ) -> str:
     """Get SQL command to migrate a source table into the target table"""
-    sql_text = get_dump_sql(database, schema, table)
+    sql_text = get_dump_sql(database, schema, table, user)
     table_re = re.compile(
         fr"(?P
(?:\n|.)+)(?PCREATE TABLE {schema}\.{table}\s+\(\n(?:\n|.)+?\);)(?P(?:\n|.)+)"
     )
@@ -59,7 +80,7 @@ def get_migration_sql(
             DROP CONSTRAINT {fk['constraint']};
         """
             )
-            for fk in get_foreign_keys(database, schema, table)
+            for fk in get_foreign_keys(database, schema, table, user, password)
         ]
     )
     fk_enable = "\n".join(
@@ -71,7 +92,7 @@ def get_migration_sql(
             REFERENCES {fk['schema']}.{fk['foreign_table']} ({fk['foreign_column']});
         """
             )
-            for fk in get_foreign_keys(database, schema, table)
+            for fk in get_foreign_keys(database, schema, table, user, password)
         ]
     )
     extra_features = "\n".join(
@@ -103,9 +124,9 @@ def get_migration_sql(
     )
 
 
-def get_foreign_keys(database: str, schema: str, table: str) -> List[dict]:
+def get_foreign_keys(database: str, schema: str, table: str, user: str, password: str) -> List[dict]:
     """Get foreign keys referencing a given table"""
-    with psycopg2.connect(database=database, user="postgres") as conn:
+    with psycopg2.connect(database=database, user=user, password=password) as conn:
         with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
             curs.execute(
                 """
@@ -177,6 +198,8 @@ def printcols(cols: List[str], header: Optional[str] = None) -> None:
 @click.option("--exclude", "-e", multiple=True, help="Exclude a column (can be used multiple times).")
 @click.option("--database", "-d", help="The name of the database.")
 @click.option("--schema", "-n", default="public", help="The schema of the target table.")
+@click.option("--user", "-u", default="postgres", help="User name.")
+@click.option("--password", "-p", default="", help="Password.")
 @click.option("--migrate", "-m", is_flag=True, help="Output full migration sql.")
 @click.option("--file", "-f", "output_file", type=click.File("w"), help="Write output into a file.")
 @click.argument("table")
@@ -186,6 +209,8 @@ def main(
     exclude,
     database: str,
     schema,
+    user: str,
+    password: str,
     output_file,
     table: str,
     columns: Tuple[str],
@@ -198,14 +223,14 @@ def main(
     and the last column will be placed at the end of the table. When entered as
     "... col1 col2 col3" all three columns will be placed at the end of the table.
     """
-    cols, extras = get_columns(database, schema, table)
+    cols, extras = get_columns(database, schema, table, user)
 
     if len(columns):
         target_start, target_end = sort_input_columns(list(columns))
         cols = reorder_columns(target_start, target_end, list(exclude), cols)
 
         if migrate:
-            query = get_migration_sql(database, schema, table, cols, extras)
+            query = get_migration_sql(database, schema, table, user, password, cols, extras)
 
             if output_file is not None:
                 output_file.write(query)