Skip to content

Commit

Permalink
fix: linux-setup mariadb json data types (#714)
Browse files Browse the repository at this point in the history
  • Loading branch information
devrimyatar authored Jan 26, 2022
1 parent 9071db4 commit 4c21be2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 21 deletions.
20 changes: 12 additions & 8 deletions jans-linux-setup/setup_app/installers/rdbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(self):
self.install_type = InstallOption.OPTONAL
self.install_var = 'rdbm_install'
self.register_progess()

self.qchar = '`' if Config.rdbm_type in ('mysql', 'spanner') else '"'
self.output_dir = os.path.join(Config.outputFolder, Config.rdbm_type)

def install(self):
Expand Down Expand Up @@ -111,14 +111,21 @@ def get_sql_col_type(self, attrname, table=None):

return data_type


def get_col_def(self, attrname, sql_tbl_name):
data_type = self.get_sql_col_type(attrname, sql_tbl_name)
col_def = '{0}{1}{0} {2}'.format(self.qchar, attrname, data_type)
if Config.rdbm_type == 'mysql' and data_type == 'JSON':
col_def += ' comment "json"'
return col_def

def create_tables(self, jans_schema_files):
self.logIt("Creating tables for {}".format(jans_schema_files))
qchar = '`' if Config.rdbm_type in ('mysql', 'spanner') else '"'
tables = []
all_schema = {}
all_attribs = {}
column_add = 'COLUMN ' if Config.rdbm_type == 'spanner' else ''
alter_table_sql_cmd = 'ALTER TABLE %s{}%s ADD %s{};' % (qchar, qchar, column_add)
alter_table_sql_cmd = 'ALTER TABLE %s{}%s ADD %s{};' % (self.qchar, self.qchar, column_add)

for jans_schema_fn in jans_schema_files:
jans_schema = base.readJsonFile(jans_schema_fn)
Expand Down Expand Up @@ -161,9 +168,7 @@ def create_tables(self, jans_schema_files):
continue

cols_.append(attrname)
data_type = self.get_sql_col_type(attrname, sql_tbl_name)

col_def = '{0}{1}{0} {2}'.format(qchar, attrname, data_type)
col_def = self.get_col_def(attrname, sql_tbl_name)
sql_tbl_cols.append(col_def)

if not self.dbUtils.table_exists(sql_tbl_name):
Expand All @@ -180,8 +185,7 @@ def create_tables(self, jans_schema_files):
for attrname in all_attribs:
attr = all_attribs[attrname]
if attr.get('sql', {}).get('add_table'):
data_type = self.get_sql_col_type(attrname, sql_tbl_name)
col_def = '{0}{1}{0} {2}'.format(qchar, attrname, data_type)
col_def = self.get_col_def(attrname, sql_tbl_name)
sql_cmd = alter_table_sql_cmd.format(attr['sql']['add_table'], col_def)

if Config.rdbm_type == 'spanner':
Expand Down
20 changes: 7 additions & 13 deletions jans-linux-setup/setup_app/utils/db_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class DBUtils:
Base = None
session = None
cbm = None
mariadb = False

def bind(self, use_ssl=True, force=False):

Expand Down Expand Up @@ -121,12 +120,6 @@ def sqlconnection(self, log=True):
self.metadata = sqlalchemy.MetaData()
myconn = self.session.connection()

# are we running on MariaDB?
query = myconn.execute("select version()")
result = query.first()
if result and 'mariadb' in result[0].lower():
self.mariadb = True

base.logIt("{} Connection was successful".format(Config.rdbm_type.upper()))

return True, self.session
Expand All @@ -138,8 +131,6 @@ def sqlconnection(self, log=True):

@property
def json_dialects_instance(self):
if self.mariadb:
return sqlalchemy.dialects.mysql.LONGTEXT
return sqlalchemy.dialects.mysql.json.JSON if Config.rdbm_type == 'mysql' else sqlalchemy.dialects.postgresql.json.JSON

def mysqlconnection(self, log=True):
Expand Down Expand Up @@ -687,6 +678,12 @@ def rdm_automapper(self, force=False):
self.Base = sqlalchemy.ext.automap.automap_base(metadata=self.metadata)
self.Base.prepare()

# fix JSON type for mariadb
for tbl in self.Base.classes:
for col in tbl.__table__.columns:
if isinstance(col.type, sqlalchemy.dialects.mysql.LONGTEXT) and col.comment.lower() == 'json':
col.type = sqlalchemy.dialects.mysql.json.JSON()

base.logIt("Reflected tables {}".format(list(self.metadata.tables.keys())))

def get_sqlalchObj_for_dn(self, dn):
Expand Down Expand Up @@ -891,10 +888,7 @@ def import_ldif(self, ldif_files, bucket=None, force=None):
sqlalchObj = sqlalchCls()

for v in vals:
vval = vals[v]
if self.mariadb and isinstance(vval, dict):
vval = json.dumps(vals[v])
setattr(sqlalchObj, v, vval)
setattr(sqlalchObj, v, vals[v])

base.logIt("Adding {}".format(sqlalchObj.doc_id))
self.session.add(sqlalchObj)
Expand Down

0 comments on commit 4c21be2

Please sign in to comment.