Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add postgresql migration #464

Merged
merged 10 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 87 additions & 3 deletions python/migrate.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from decimal import Decimal
import json
import mgp
import mysql.connector as mysql_connector
import oracledb
import pyodbc
import psycopg2
import threading

from typing import Any, Dict
Expand Down Expand Up @@ -84,8 +86,8 @@ def mysql(
def cleanup_migrate_mysql():
global mysql_dict
mysql_dict[threading.get_native_id][Constants.CURSOR] = None
mysql_dict[threading.get_native_id][Constants.CONNECTION].close()
mysql_dict[threading.get_native_id][Constants.CONNECTION].commit()
mysql_dict[threading.get_native_id][Constants.CONNECTION].close()
mysql_dict[threading.get_native_id][Constants.CONNECTION] = None
mysql_dict[threading.get_native_id][Constants.COLUMN_NAMES] = None

Expand Down Expand Up @@ -163,8 +165,8 @@ def sql_server(
def cleanup_migrate_sql_server():
global sql_server_dict
sql_server_dict[threading.get_native_id][Constants.CURSOR] = None
sql_server_dict[threading.get_native_id][Constants.CONNECTION].close()
sql_server_dict[threading.get_native_id][Constants.CONNECTION].commit()
sql_server_dict[threading.get_native_id][Constants.CONNECTION].close()
sql_server_dict[threading.get_native_id][Constants.CONNECTION] = None
sql_server_dict[threading.get_native_id][Constants.COLUMN_NAMES] = None

Expand Down Expand Up @@ -261,6 +263,85 @@ def cleanup_migrate_oracle_db():
mgp.add_batch_read_proc(oracle_db, init_migrate_oracle_db, cleanup_migrate_oracle_db)


# PostgreSQL dictionary to store connections and cursors by thread
postgres_dict = {}


def init_migrate_postgresql(
table_or_sql: str,
config: mgp.Map,
config_path: str = "",
params: mgp.Nullable[mgp.Any] = None,
):
global postgres_dict

if params:
_check_params_type(params, (list, tuple))
else:
params = []

if len(config_path) > 0:
config = _combine_config(config=config, config_path=config_path)

if _query_is_table(table_or_sql):
table_or_sql = f"SELECT * FROM {table_or_sql};"

if threading.get_native_id not in postgres_dict:
postgres_dict[threading.get_native_id] = {}

if Constants.CURSOR not in postgres_dict[threading.get_native_id]:
postgres_dict[threading.get_native_id][Constants.CURSOR] = None

if postgres_dict[threading.get_native_id][Constants.CURSOR] is None:
connection = psycopg2.connect(**config)
cursor = connection.cursor()
cursor.execute(table_or_sql, params)

postgres_dict[threading.get_native_id][Constants.CONNECTION] = connection
postgres_dict[threading.get_native_id][Constants.CURSOR] = cursor
postgres_dict[threading.get_native_id][Constants.COLUMN_NAMES] = [
column.name for column in cursor.description
]


def postgresql(
table_or_sql: str,
config: mgp.Map,
config_path: str = "",
params: mgp.Nullable[mgp.Any] = None,
) -> mgp.Record(row=mgp.Map):
"""
With migrate.postgresql you can access PostgreSQL and execute queries. The result table is converted into a stream,
and returned rows can be used to create or create graph structures. Config must be at least empty map.
If config_path is passed, every key,value pair from JSON file will overwrite any values in config file.

:param table_or_sql: Table name or an SQL query
:param config: Connection configuration parameters (as in psycopg2.connect),
:param config_path: Path to the JSON file containing configuration parameters (as in psycopg2.connect)
:param params: Optionally, queries may be parameterized. In that case, `params` provides parameter values
:return: The result table as a stream of rows
"""
global postgres_dict
cursor = postgres_dict[threading.get_native_id][Constants.CURSOR]
column_names = postgres_dict[threading.get_native_id][Constants.COLUMN_NAMES]

rows = cursor.fetchmany(Constants.BATCH_SIZE)

return [mgp.Record(row=_name_row_cells(row, column_names)) for row in rows]


def cleanup_migrate_postgresql():
global postgres_dict
postgres_dict[threading.get_native_id][Constants.CURSOR] = None
postgres_dict[threading.get_native_id][Constants.CONNECTION].commit()
postgres_dict[threading.get_native_id][Constants.CONNECTION].close()
postgres_dict[threading.get_native_id][Constants.CONNECTION] = None
postgres_dict[threading.get_native_id][Constants.COLUMN_NAMES] = None


mgp.add_batch_read_proc(postgresql, init_migrate_postgresql, cleanup_migrate_postgresql)


def _query_is_table(table_or_sql: str) -> bool:
return len(table_or_sql.split()) == 1

Expand All @@ -283,7 +364,10 @@ def _combine_config(config: mgp.Map, config_path: str) -> Dict[str, Any]:


def _name_row_cells(row_cells, column_names) -> Dict[str, Any]:
return dict(map(lambda column, value: (column, value), column_names, row_cells))
return {
column: (value if not isinstance(value, Decimal) else float(value))
for column, value in zip(column_names, row_cells)
}


def _check_params_type(params: Any, types=(dict, list, tuple)) -> None:
Expand Down
1 change: 1 addition & 0 deletions python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ gqlalchemy==1.4.1
mysql-connector-python==8.0.32
oracledb==1.2.2
pyodbc==4.0.35
psycopg2-binary==2.9.9
defusedxml==0.7.1
scipy==1.12.0
1 change: 1 addition & 0 deletions python/requirements_no_ml.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ gqlalchemy==1.4.1
mysql-connector-python==8.0.32
oracledb==1.2.2
pyodbc==4.0.35
psycopg2-binary==2.9.9
defusedxml==0.7.1
scipy==1.12.0
Loading