Skip to content

Commit

Permalink
Merge pull request #257 from informatics-isi-edu/constraint_name_fixes
Browse files Browse the repository at this point in the history
fixes for constraint name handling when creating keys and foreign keys
  • Loading branch information
karlcz authored Oct 23, 2024
2 parents 137db5a + 1b67874 commit 4adb20c
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 51 deletions.
132 changes: 83 additions & 49 deletions ermrest/model/key.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from .. import exception
from ..util import sql_identifier, sql_literal, constraint_exists, OrderedFrozenSet
from .misc import frozendict, AltDict, AclDict, DynaclDict, keying, annotatable, cache_rights, hasacls, hasdynacls, enforce_63byte_id, make_id
from .misc import frozendict, AltDict, AclDict, DynaclDict, keying, annotatable, cache_rights, hasacls, hasdynacls, enforce_63byte_id, make_id, make_constraint_name
from .name import _keyref_join_str, _keyref_join_sql

@annotatable
Expand Down Expand Up @@ -201,24 +201,17 @@ def fromjson(table, keysdoc):
for key in Unique.fromjson_single(table, keydoc):
yield key

def _constraint_name_exists(self, cur, sname, name):
return constraint_exists(cur, name)

def _find_new_constraint_name(self, cur, sname):
n = None
while True:
name = make_id(self.table.name, [c.name for c in self.columns], (('key%d' % n) if n else 'key'))
if not self._constraint_name_exists(cur, sname, name):
break
if n is None:
n = 1
else:
n += 1
return (sname, name)

def add(self, conn, cur):
if not self.constraint_name:
self.constraint_name = self._find_new_constraint_name(cur, self.table.schema.name)
self.constraint_name = (
self.table.schema.name,
make_constraint_name(
cur,
self.table.name,
[c.name for c in self.columns],
'key',
)
)
self.table.alter_table(
conn, cur,
'ADD %s' % self.sql_def(),
Expand Down Expand Up @@ -278,7 +271,7 @@ def has_right(self, aclname, roles=None):
return True

@annotatable
@keying('pseudo_key', {"pkey_rid": ('text', lambda self: self.rid)})
@keying('pseudo_key', {"key_rid": ('text', lambda self: self.rid)})
class PseudoUnique (object):
"""A pseudo-uniqueness constraint."""

Expand Down Expand Up @@ -389,27 +382,39 @@ def prejson(self):
'names': [ self.constraint_name ],
}

def _constraint_name_exists(self, cur, name):
@staticmethod
def _pseudo_key_exists(cur, name):
cur.execute("""
SELECT True
FROM _ermrest.known_pseudo_keys
WHERE constraint_name = %(constraint_name)s;
""" % {
'constraint_name': name,
'constraint_name': sql_literal(name),
})
return cur.fetchone()[0]
for row in cur:
return row[0]
return False

def add(self, conn, cur):
self.table.enforce_right('owner') # since we don't use alter_table which enforces for real keys
if not self.constraint_name:
self.constraint_name = self._find_new_constraint_name(cur, "")
self.constraint_name = (
"",
make_constraint_name(
cur,
self.table.name,
[c.name for c in self.columns],
'key',
probefunc=self._pseudo_key_exists,
)
)
cur.execute("""
SELECT _ermrest.model_version_bump();
INSERT INTO _ermrest.known_pseudo_keys (constraint_name, table_rid, comment)
VALUES (%(constraint_name)s, %(table_rid)s, %(comment)s)
RETURNING "RID";
""" % {
'constraint_name': sql_literal(name),
'constraint_name': sql_literal(self.constraint_name[1]),
'table_rid': sql_literal(self.table.rid),
'comment': sql_literal(self.comment),
})
Expand Down Expand Up @@ -580,6 +585,21 @@ def _keyref_has_right(self, aclname, roles=None):
return True
return self._has_right(aclname, roles, anon_mutation_ok=True)

def _keyref_fk_pk_cols_ordered(self):
"""Return fk_cols, pk_cols sorted by referenced key's stable column order"""
pk_cols = list(self.unique.columns)
pk_cols_ord = {
pk_cols[i]: i
for i in range(len(pk_cols))
}
fk_cols_ord = {
fkc: pk_cols_ord[pkc]
for fkc, pkc in self.reference_map.items()
}
fk_cols = [ p[0] for p in sorted(fk_cols_ord.items(), key=lambda p: p[1]) ]
pk_cols = [ self.reference_map[c] for c in fk_cols ]
return fk_cols, pk_cols

@annotatable
@hasdynacls({ "owner", "insert", "update" })
@hasacls(
Expand Down Expand Up @@ -648,20 +668,8 @@ def verbose(self):
return json.dumps(self.prejson(), indent=2)

def _fk_pk_cols_ordered(self):
"""Return fk_cols, pk_cols sorted by referenced key's stable column order"""
pk_cols = list(self.unique.columns)
pk_cols_ord = {
pk_cols[i]: i
for i in range(len(pk_cols))
}
fk_cols_ord = {
fkc: pk_cols_ord[pkc]
for fkc, pkc in self.reference_map.items()
}
fk_cols = [ p[0] for p in sorted(fk_cols_ord.items(), key=lambda p: p[1]) ]
pk_cols = [ self.reference_map[c] for c in fk_cols ]
return fk_cols, pk_cols

return _keyref_fk_pk_cols_ordered(self)

def sql_def(self):
"""Render SQL table constraint clause for DDL."""
# sort out column ordering to obey referenced key's stable column order
Expand All @@ -680,16 +688,15 @@ def sql_def(self):

def add(self, conn, cur):
if not self.constraint_name:
n = None
while True:
name = make_id(self.foreign_key.table.name, [c.name for c in self.foreign_key.columns], (('fkey%d' % n) if n else 'fkey'))
if not constraint_exists(cur, name):
break
if n is None:
n = 1
else:
n += 1
self.constraint_name = (self.foreign_key.table.schema.name, name)
self.constraint_name = (
self.foreign_key.table.schema.name,
make_constraint_name(
cur,
self.foreign_key.table.name,
[c.name for c in self.foreign_key.columns],
'fkey',
)
)
idx_name = '_'.join(self.constraint_name[1].split('_')[:-1]) + '_idx'
fk_cols, pk_cols = self._fk_pk_cols_ordered()
self.foreign_key.table.alter_table(
Expand Down Expand Up @@ -1032,7 +1039,10 @@ def _from_column_names(self):
def _to_column_names(self):
"""Canonicalized to-column names list."""
return _keyref_to_column_names(self)


def _fk_pk_cols_ordered(self):
return _keyref_fk_pk_cols_ordered(self)

def prejson(self):
return _keyref_prejson(self)

Expand Down Expand Up @@ -1109,8 +1119,32 @@ def update(self, conn, cur, refdoc, ermrest_config):

return newfkr

@staticmethod
def _pseudo_fkey_exists(cur, name):
cur.execute("""
SELECT True
FROM _ermrest.known_pseudo_fkeys
WHERE constraint_name = %(constraint_name)s;
""" % {
'constraint_name': sql_literal(name),
})
for row in cur:
return row[0]
return False

def add(self, conn, cur):
self.foreign_key.table.enforce_right('owner') # since we don't use alter_table which enforces for real keyrefs
if not self.constraint_name:
self.constraint_name = (
self.foreign_key.table.schema.name,
make_constraint_name(
cur,
self.foreign_key.table.name,
[c.name for c in self.foreign_key.columns],
'fkey',
probefunc=self._pseudo_fkey_exists,
)
)
fk_cols = list(self.foreign_key.columns)
cur.execute("""
SELECT _ermrest.model_version_bump();
Expand All @@ -1121,7 +1155,7 @@ def add(self, conn, cur):
'fk_table_rid': sql_literal(self.foreign_key.table.rid),
'pk_table_rid': sql_literal(self.unique.table.rid),
'comment': sql_literal(self.comment),
'constraint_name': sql_literal(self.constraint_name[1]) if self.constraint_name else 'NULL',
'constraint_name': sql_literal(self.constraint_name[1]),
})
self.rid = cur.fetchone()[0]

Expand Down
28 changes: 26 additions & 2 deletions ermrest/model/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from webauthn2.util import deriva_ctx

from .. import exception
from ..util import sql_identifier, sql_literal, table_exists
from ..util import sql_identifier, sql_literal, table_exists, constraint_exists
from .. import ermpath
from .type import _default_config
from .name import Name
Expand Down Expand Up @@ -61,7 +61,7 @@ def make_id(*components):
# accept lists at top-level for convenience (compound keys, etc.)
expanded = []
for e in components:
if isinstance(e, list):
if isinstance(e, (list, tuple)):
expanded.extend(e)
else:
expanded.append(e)
Expand Down Expand Up @@ -119,6 +119,30 @@ def truncate(s, maxlen):
# last-ditch (e.g. multibyte unicode suffix worst case)
return truncate(naive_result, 55) + naive_hash

def make_constraint_name(cur, *parts, probefunc=constraint_exists):
"""Build an identifier that won't collide with any existing constraint names
:param cur: Cursor to database to probe for name collisions
:param parts: Ordered list of strings or lists of strings to join with '_' to form name
:param probefunc: Function to probe the database cursor for an existing name
The final element parts[-1] must be of type str.
The probefunc signature is probefunc(cursor, namestr) -> bool.
"""
n = None
while True:
suffix = parts[-1]
if not isinstance(suffix, str):
raise NotImplementedError('make_constraint_name expects string suffix as final argument')
name = make_id(*parts[:-1], ('%s%d' % (suffix, n) if n else suffix))
if not probefunc(cur, name):
return name
if n is None:
n = 1
else:
n += 1

sufficient_rights = {
"owner": set(),
"create": {"owner"},
Expand Down

0 comments on commit 4adb20c

Please sign in to comment.