Skip to content

Commit

Permalink
database: add filter to insert logic
Browse files Browse the repository at this point in the history
Turns out clients can be generating more
fields/values than the database/cherrypy server
know about owing to separate MTT based projects
contributing code.

Fixes open-mpi#614

Signed-off-by: Howard Pritchard <howardp@lanl.gov>
  • Loading branch information
hppritcha committed Feb 16, 2018
1 parent 0bb5b7b commit c69ed64
Showing 1 changed file with 26 additions and 12 deletions.
38 changes: 26 additions & 12 deletions server/php/cherrypy/src/webapp/db_pgv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,39 +559,53 @@ def _get_nextval(self, seq_name):
def _select_insert(self, table, table_id, stmt_fields, stmt_values):
found_id = -1

cursor = self.get_cursor()

cursor.execute("SELECT * FROM %s LIMIT 0" % table)
all_fields_for_table = [d[0] for d in cursor.description]

#
# Build the SELECT and INSERT statements
#
select_stmt = "\nSELECT %s FROM %s \n" % (table_id, table)
insert_stmt = "\nINSERT INTO %s \n (%s" % (table, table_id)
insert_stmt_values = []

#
# we filter because the client may have sent us stuff our database has
# no clue about
#
count = 0
for field in stmt_fields:
insert_stmt = insert_stmt + ", " + field

if count == 0:
select_stmt = select_stmt + " WHERE "
else:
select_stmt = select_stmt + " AND "
select_stmt = select_stmt + field + " = %s"
if field in all_fields_for_table:
insert_stmt = insert_stmt + ", " + field
insert_stmt_values.append(stmt_values[count])
if count == 0:
select_stmt = select_stmt + " WHERE "
else:
select_stmt = select_stmt + " AND "
select_stmt = select_stmt + field + " = %s"
count += 1

select_stmt = select_stmt + "\n ORDER BY " + table_id + " ASC LIMIT 1"

insert_stmt = insert_stmt + ") \nVALUES ("
insert_stmt = insert_stmt + " %s"
for value in stmt_values:
insert_stmt = insert_stmt + ", %s"
for field in stmt_fields:
if field in all_fields_for_table:
insert_stmt = insert_stmt + ", %s"
insert_stmt = insert_stmt + ")"

#
# Try the select to see if we need to insert
#
#self._logger.debug(select_stmt)
#self._logger.debug(insert_stmt)
#self._logger.debug(str(insert_stmt_values))

cursor = self.get_cursor()

values = tuple(stmt_values)
values = tuple(insert_stmt_values)
cursor.execute( select_stmt, values )
rows = cursor.fetchone()
if rows is not None:
Expand All @@ -607,8 +621,8 @@ def _select_insert(self, table, table_id, stmt_fields, stmt_values):
self._logger.debug( ", ".join(str(x) for x in values) )
found_id = self._get_nextval( "%s_%s_seq" % (table, table_id))

stmt_values.insert(0, found_id)
values = tuple(stmt_values)
insert_stmt_values.insert(0, found_id)
values = tuple(insert_stmt_values)
cursor.execute( insert_stmt, values )
# Make sure to commit after every INSERT
self._connection.commit()
Expand Down

0 comments on commit c69ed64

Please sign in to comment.