diff --git a/python/migrate.py b/python/migrate.py
index 055c067af..3e112dac4 100644
--- a/python/migrate.py
+++ b/python/migrate.py
@@ -1,8 +1,10 @@
+from decimal import Decimal
 import json
 import mgp
 import mysql.connector as mysql_connector
 import oracledb
 import pyodbc
+import psycopg2
 import threading
 from typing import Any, Dict
 
@@ -84,8 +86,8 @@ def mysql(
 def cleanup_migrate_mysql():
     global mysql_dict
     mysql_dict[threading.get_native_id][Constants.CURSOR] = None
-    mysql_dict[threading.get_native_id][Constants.CONNECTION].close()
     mysql_dict[threading.get_native_id][Constants.CONNECTION].commit()
+    mysql_dict[threading.get_native_id][Constants.CONNECTION].close()
     mysql_dict[threading.get_native_id][Constants.CONNECTION] = None
     mysql_dict[threading.get_native_id][Constants.COLUMN_NAMES] = None
 
@@ -163,8 +165,8 @@ def sql_server(
 def cleanup_migrate_sql_server():
     global sql_server_dict
     sql_server_dict[threading.get_native_id][Constants.CURSOR] = None
-    sql_server_dict[threading.get_native_id][Constants.CONNECTION].close()
     sql_server_dict[threading.get_native_id][Constants.CONNECTION].commit()
+    sql_server_dict[threading.get_native_id][Constants.CONNECTION].close()
     sql_server_dict[threading.get_native_id][Constants.CONNECTION] = None
     sql_server_dict[threading.get_native_id][Constants.COLUMN_NAMES] = None
 
@@ -261,6 +263,85 @@ def cleanup_migrate_oracle_db():
 mgp.add_batch_read_proc(oracle_db, init_migrate_oracle_db, cleanup_migrate_oracle_db)
 
 
+# PostgreSQL dictionary to store connections and cursors by thread
+postgres_dict = {}
+
+
+def init_migrate_postgresql(
+    table_or_sql: str,
+    config: mgp.Map,
+    config_path: str = "",
+    params: mgp.Nullable[mgp.Any] = None,
+):
+    global postgres_dict
+
+    if params:
+        _check_params_type(params, (list, tuple))
+    else:
+        params = []
+
+    if len(config_path) > 0:
+        config = _combine_config(config=config, config_path=config_path)
+
+    if _query_is_table(table_or_sql):
+        table_or_sql = f"SELECT * FROM {table_or_sql};"
+
+    if threading.get_native_id not in postgres_dict:
+        postgres_dict[threading.get_native_id] = {}
+
+    if Constants.CURSOR not in postgres_dict[threading.get_native_id]:
+        postgres_dict[threading.get_native_id][Constants.CURSOR] = None
+
+    if postgres_dict[threading.get_native_id][Constants.CURSOR] is None:
+        connection = psycopg2.connect(**config)
+        cursor = connection.cursor()
+        cursor.execute(table_or_sql, params)
+
+        postgres_dict[threading.get_native_id][Constants.CONNECTION] = connection
+        postgres_dict[threading.get_native_id][Constants.CURSOR] = cursor
+        postgres_dict[threading.get_native_id][Constants.COLUMN_NAMES] = [
+            column.name for column in cursor.description
+        ]
+
+
+def postgresql(
+    table_or_sql: str,
+    config: mgp.Map,
+    config_path: str = "",
+    params: mgp.Nullable[mgp.Any] = None,
+) -> mgp.Record(row=mgp.Map):
+    """
+    With migrate.postgresql you can access PostgreSQL and execute queries. The result table is converted into a stream,
+    and the returned rows can be used to create or merge graph structures. Config must be at least an empty map.
+    If config_path is passed, every key-value pair from the JSON file will overwrite the corresponding value in config.
+
+    :param table_or_sql: Table name or an SQL query
+    :param config: Connection configuration parameters (as in psycopg2.connect)
+    :param config_path: Path to the JSON file containing configuration parameters (as in psycopg2.connect)
+    :param params: Optionally, queries may be parameterized. In that case, `params` provides parameter values
+    :return: The result table as a stream of rows
+    """
+    global postgres_dict
+    cursor = postgres_dict[threading.get_native_id][Constants.CURSOR]
+    column_names = postgres_dict[threading.get_native_id][Constants.COLUMN_NAMES]
+
+    rows = cursor.fetchmany(Constants.BATCH_SIZE)
+
+    return [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]
+
+
+def cleanup_migrate_postgresql():
+    global postgres_dict
+    postgres_dict[threading.get_native_id][Constants.CURSOR] = None
+    postgres_dict[threading.get_native_id][Constants.CONNECTION].commit()
+    postgres_dict[threading.get_native_id][Constants.CONNECTION].close()
+    postgres_dict[threading.get_native_id][Constants.CONNECTION] = None
+    postgres_dict[threading.get_native_id][Constants.COLUMN_NAMES] = None
+
+
+mgp.add_batch_read_proc(postgresql, init_migrate_postgresql, cleanup_migrate_postgresql)
+
+
 def _query_is_table(table_or_sql: str) -> bool:
     return len(table_or_sql.split()) == 1
 
@@ -283,7 +364,10 @@ def _combine_config(config: mgp.Map, config_path: str) -> Dict[str, Any]:
 
 
 def _name_row_cells(row_cells, column_names) -> Dict[str, Any]:
-    return dict(map(lambda column, value: (column, value), column_names, row_cells))
+    return {
+        column: (value if not isinstance(value, Decimal) else float(value))
+        for column, value in zip(column_names, row_cells)
+    }
 
 
 def _check_params_type(params: Any, types=(dict, list, tuple)) -> None:
diff --git a/python/requirements.txt b/python/requirements.txt
index 45cc459da..c64bae23c 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -12,5 +12,6 @@ gqlalchemy==1.4.1
 mysql-connector-python==8.0.32
 oracledb==1.2.2
 pyodbc==4.0.35
+psycopg2-binary==2.9.9
 defusedxml==0.7.1
 scipy==1.12.0
diff --git a/python/requirements_no_ml.txt b/python/requirements_no_ml.txt
index 607ba4b2b..67a1686f6 100644
--- a/python/requirements_no_ml.txt
+++ b/python/requirements_no_ml.txt
@@ -10,5 +10,6 @@ gqlalchemy==1.4.1
 mysql-connector-python==8.0.32
 oracledb==1.2.2
 pyodbc==4.0.35
+psycopg2-binary==2.9.9
 defusedxml==0.7.1
 scipy==1.12.0
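
For reference, below is a minimal standalone sketch of the batching pattern the new procedure relies on: psycopg2's `fetchmany` plus the `Decimal`-to-float conversion added to `_name_row_cells`. It runs outside the `mgp` batch-procedure machinery; the `BATCH_SIZE` constant, the `stream_rows` helper, and the connection parameters (`demo_db`, `example_table`, credentials) are hypothetical placeholders, not values from this change.

```python
# Standalone sketch (not part of the diff): stream a PostgreSQL result set in
# batches and convert Decimal cells to float, mirroring the logic above.
from decimal import Decimal

import psycopg2

BATCH_SIZE = 1000  # stand-in for Constants.BATCH_SIZE


def stream_rows(table_or_sql, config, params=None):
    """Yield result rows as dicts keyed by column name."""
    connection = psycopg2.connect(**config)
    try:
        cursor = connection.cursor()
        cursor.execute(table_or_sql, params)
        column_names = [column.name for column in cursor.description]
        while True:
            rows = cursor.fetchmany(BATCH_SIZE)
            if not rows:
                break
            for row in rows:
                # NUMERIC/DECIMAL columns arrive as decimal.Decimal; convert
                # them to float so the values fit a plain map of primitives.
                yield {
                    name: float(value) if isinstance(value, Decimal) else value
                    for name, value in zip(column_names, row)
                }
        connection.commit()
    finally:
        connection.close()


if __name__ == "__main__":
    # Hypothetical connection parameters, for illustration only.
    config = {"host": "localhost", "dbname": "demo_db",
              "user": "memgraph", "password": "password"}
    for row in stream_rows("SELECT * FROM example_table;", config):
        print(row)
```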