Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add composite foreign key #1385

Merged
merged 10 commits into from
Dec 4, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ private val comparator: Comparator<Column<*>> = compareBy({ it.table.tableName }
* Represents a column.
*/
class Column<T>(
/** Table where the columns is declared. */
/** Table where the columns are declared. */
val table: Table,
/** Name of the column. */
val name: String,
Expand All @@ -25,7 +25,7 @@ class Column<T>(

/** Returns the column that this column references. */
val referee: Column<*>?
get() = foreignKey?.target
get() = foreignKey?.targetOf(this)

/** Returns the column that this column references, casted as a column of type [S], or `null` if the cast fails. */
@Suppress("UNCHECKED_CAST")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,30 +51,45 @@ enum class ReferenceOption {
* Represents a foreign key constraint.
*/
data class ForeignKeyConstraint(
val target: Column<*>,
val from: Column<*>,
val references: Map<Column<*>, Column<*>>,
private val onUpdate: ReferenceOption?,
private val onDelete: ReferenceOption?,
private val name: String?
) : DdlAware {
constructor(
target: Column<*>,
from: Column<*>,
onUpdate: ReferenceOption?,
onDelete: ReferenceOption?,
name: String?
) : this(mapOf(from to target), onUpdate, onDelete, name)

naftalmm marked this conversation as resolved.
Show resolved Hide resolved
private val tx: Transaction
get() = TransactionManager.current()

val target: LinkedHashSet<Column<*>> = LinkedHashSet(references.values)

val targetTable: Table = target.first().table

/** Name of the child table. */
val targetTable: String
get() = tx.identity(target.table)
val targetTableName: String
get() = tx.identity(targetTable)

/** Names of the foreign key columns. */
private val targetColumns: String
get() = target.joinToString { tx.identity(it) }

val from: LinkedHashSet<Column<*>> = LinkedHashSet(references.keys)

/** Name of the foreign key column. */
val targetColumn: String
get() = tx.identity(target)
val fromTable: Table = from.first().table

/** Name of the parent table. */
val fromTable: String
get() = tx.identity(from.table)
val fromTableName: String
get() = tx.identity(fromTable)

/** Name of the key column from the parent table. */
val fromColumn
get() = tx.identity(from)
/** Names of the key columns from the parent table. */
private val fromColumns: String
get() = from.joinToString { tx.identity(it) }

/** Reference option when performing update operations. */
val updateRule: ReferenceOption?
Expand All @@ -91,27 +106,27 @@ data class ForeignKeyConstraint(
/** Name of this constraint. */
val fkName: String
get() = tx.db.identifierManager.cutIfNecessaryAndQuote(
name ?: "fk_${from.table.tableNameWithoutScheme}_${from.name}_${target.name}"
name ?: "fk_${fromTable.tableNameWithoutScheme}_${from.joinToString("_") { it.name }}__${target.joinToString("_") { it.name }}"
).inProperCase()
internal val foreignKeyPart: String
get() = buildString {
if (fkName.isNotBlank()) {
append("CONSTRAINT $fkName ")
}
append("FOREIGN KEY ($fromColumn) REFERENCES $targetTable($targetColumn)")
append("FOREIGN KEY ($fromColumns) REFERENCES $targetTableName($targetColumns)")
if (deleteRule != ReferenceOption.NO_ACTION) {
append(" ON DELETE $deleteRule")
}
if (updateRule != ReferenceOption.NO_ACTION) {
if (currentDialect is OracleDialect) {
exposedLogger.warn("Oracle doesn't support FOREIGN KEY with ON UPDATE clause. Please check your $fromTable table.")
exposedLogger.warn("Oracle doesn't support FOREIGN KEY with ON UPDATE clause. Please check your $fromTableName table.")
} else {
append(" ON UPDATE $updateRule")
}
}
}

override fun createStatement(): List<String> = listOf("ALTER TABLE $fromTable ADD $foreignKeyPart")
override fun createStatement(): List<String> = listOf("ALTER TABLE $fromTableName ADD $foreignKeyPart")

override fun modifyStatement(): List<String> = dropStatement() + createStatement()

Expand All @@ -120,8 +135,16 @@ data class ForeignKeyConstraint(
is MysqlDialect -> "FOREIGN KEY"
else -> "CONSTRAINT"
}
return listOf("ALTER TABLE $fromTable DROP $constraintType $fkName")
return listOf("ALTER TABLE $fromTableName DROP $constraintType $fkName")
}

fun targetOf(from: Column<*>): Column<*>? = references[from]

operator fun plus(other: ForeignKeyConstraint): ForeignKeyConstraint {
return copy(references = references + other.references)
}

override fun toString() = "ForeignKeyConstraint(fkName='$fkName')"
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,13 @@ object SchemaUtils {
}
}

fun createFKey(reference: Column<*>): List<String> {
val foreignKey = reference.foreignKey
require(foreignKey != null && (foreignKey.deleteRule != null || foreignKey.updateRule != null)) { "$reference does not reference anything" }
fun createFKey(foreignKey: ForeignKeyConstraint): List<String> {
naftalmm marked this conversation as resolved.
Show resolved Hide resolved
val allFromColumnsBelongsToTheSameTable = foreignKey.from.all { it.table == foreignKey.fromTable }
require(allFromColumnsBelongsToTheSameTable) { "not all referencing columns of $foreignKey belong to the same table " }
val allTargetColumnsBelongToTheSameTable = foreignKey.target.all { it.table == foreignKey.targetTable }
require(allTargetColumnsBelongToTheSameTable) { "not all referenced columns of $foreignKey belong to the same table " }
require(foreignKey.from.size == foreignKey.target.size) { "$foreignKey referencing columns are not in accordance with referenced" }
require(foreignKey.deleteRule != null || foreignKey.updateRule != null) { "$foreignKey has no reference constraint actions" }
return foreignKey.createStatement()
}

Expand Down Expand Up @@ -204,19 +208,16 @@ object SchemaUtils {
}

for (table in tables) {
for (column in table.columns) {
val foreignKey = column.foreignKey
if (foreignKey != null) {
val existingConstraint = existingColumnConstraint[table to column]?.firstOrNull()
if (existingConstraint == null) {
statements.addAll(createFKey(column))
} else if (existingConstraint.target.table != foreignKey.target.table ||
foreignKey.deleteRule != existingConstraint.deleteRule ||
foreignKey.updateRule != existingConstraint.updateRule
) {
statements.addAll(existingConstraint.dropStatement())
statements.addAll(createFKey(column))
}
for (foreignKey in table.foreignKeys) {
val existingConstraint = existingColumnConstraint[table to foreignKey.from]?.firstOrNull()
if (existingConstraint == null) {
statements.addAll(createFKey(foreignKey))
} else if (existingConstraint.targetTable != foreignKey.targetTable ||
foreignKey.deleteRule != existingConstraint.deleteRule ||
foreignKey.updateRule != existingConstraint.updateRule
) {
statements.addAll(existingConstraint.dropStatement())
statements.addAll(createFKey(foreignKey))
}
}
}
Expand Down Expand Up @@ -352,7 +353,7 @@ object SchemaUtils {
val constraint = fk.first()
val fkPartToLog = fk.joinToString(", ") { it.fkName }
exposedLogger.warn(
"\t\t\t'${pair.first}'.'${pair.second}' -> '${constraint.fromTable}'.'${constraint.fromColumn}':\t$fkPartToLog"
"\t\t\t'${pair.first}'.'${pair.second}' -> '${constraint.fromTableName}':\t$fkPartToLog"
)
}

Expand Down Expand Up @@ -397,7 +398,7 @@ object SchemaUtils {
val fKeyConstraints = currentDialect.columnConstraints(*tables).keys
val existingIndices = currentDialect.existingIndices(*tables)
fun List<Index>.filterFKeys() = if (isMysql) {
filterNot { it.table to it.columns.singleOrNull() in fKeyConstraints }
filterNot { it.table to LinkedHashSet(it.columns) in fKeyConstraints }
} else {
this
}
Expand Down
55 changes: 48 additions & 7 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Table.kt
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,11 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
/** Returns all indices declared on the table. */
val indices: List<Index> get() = _indices

private val _foreignKeys = mutableListOf<ForeignKeyConstraint>()

/** Returns all foreignKeys declared on the table. */
val foreignKeys: List<ForeignKeyConstraint> get() = columns.mapNotNull { it.foreignKey } + _foreignKeys

private val checkConstraints = mutableListOf<Pair<String, Op<Boolean>>>()

/**
Expand Down Expand Up @@ -818,7 +823,7 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
onDelete: ReferenceOption? = null,
onUpdate: ReferenceOption? = null,
fkName: String? = null
): Column<T?> = Column<T>(this, name, refColumn.columnType.cloneAsBaseType()).references(refColumn, onDelete, onUpdate, fkName).nullable()
): Column<T?> = reference(name, refColumn, onDelete, onUpdate, fkName).nullable()

/**
* Creates a column with the specified [name] with an optional reference to the [refColumn] column with [onDelete], [onUpdate], and [fkName] options.
Expand All @@ -841,10 +846,7 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
onDelete: ReferenceOption? = null,
onUpdate: ReferenceOption? = null,
fkName: String? = null
): Column<E?> {
val entityIdColumn = entityId(name, (refColumn.columnType as EntityIDColumnType<T>).idColumn) as Column<E>
return entityIdColumn.references(refColumn, onDelete, onUpdate, fkName).nullable()
}
): Column<E?> = reference(name, refColumn, onDelete, onUpdate, fkName).nullable()

/**
* Creates a column with the specified [name] with an optional reference to the `id` column in [foreign] table with [onDelete], [onUpdate], and [fkName] options.
Expand All @@ -865,7 +867,7 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
onDelete: ReferenceOption? = null,
onUpdate: ReferenceOption? = null,
fkName: String? = null
): Column<EntityID<T>?> = entityId(name, foreign).references(foreign.id, onDelete, onUpdate, fkName).nullable()
): Column<EntityID<T>?> = reference(name, foreign, onDelete, onUpdate, fkName).nullable()

// Miscellaneous

Expand Down Expand Up @@ -939,6 +941,45 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {
*/
fun uniqueIndex(customIndexName: String? = null, vararg columns: Column<*>): Unit = index(customIndexName, true, *columns)

/**
* Creates a composite foreign key.
*
* @param from Columns that compose the foreign key. Their order should match the order of columns in referenced primary key.
* @param target Primary key of the referenced table.
* @param onUpdate Reference option when performing update operations.
* @param onUpdate Reference option when performing delete operations.
* @param name Custom foreign key name
*/
fun foreignKey(
vararg from: Column<*>,
target: PrimaryKey,
onUpdate: ReferenceOption? = null,
onDelete: ReferenceOption? = null,
name: String? = null
) {
_foreignKeys.add(ForeignKeyConstraint(from.zip(target.columns).toMap(), onUpdate, onDelete, name))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that check that size of from and target.columns is the same. Also, maybe it's worth to check column types?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added check for collections sizes. But columns type check is rather tricky - naive type equality check (from.columnType == target.columnType) will be too restrictive because SQL (at least postgreSQL) allows:

  • referencing from INT to SERIAL
  • referencing from INT NOT NULL to INT NULL (and vice versa!)
  • referencing from VARCHAR(50) to VARCHAR(100) (and vice versa!)
  • referencing from INT/SMALLINT to BIGINT (and vice versa!)

And definitely, this list is not exhaustive - it's just my experiments in sqlfiddle

So, the new method IColumnType.isCompatibleWith(other: IColumnType) : Boolean needs to be added first to make this check possible. But it looks like a big research task, at least I couldn't quickly find any docs about types compatibility (which type could be referenced to which in foreign key constraint) for any of SQL dialects.

}

/**
* Creates a composite foreign key.
*
* @param references Pairs of columns that compose the foreign key.
* First value of pair is a column of referencing table, second value - a column of a referenced one.
* All referencing columns must belong to this table.
* All referenced columns must belong to the same table.
* @param onUpdate Reference option when performing update operations.
* @param onUpdate Reference option when performing delete operations.
* @param name Custom foreign key name
*/
fun foreignKey(
vararg references: Pair<Column<*>, Column<*>>,
onUpdate: ReferenceOption? = null,
onDelete: ReferenceOption? = null,
name: String? = null
) {
_foreignKeys.add(ForeignKeyConstraint(references.toMap(), onUpdate, onDelete, name))
}

// Check constraints

/**
Expand Down Expand Up @@ -1019,7 +1060,7 @@ open class Table(name: String = "") : ColumnSet(), DdlAware {

val addForeignKeysInAlterPart = SchemaUtils.checkCycle(this) && currentDialect !is SQLiteDialect

val foreignKeyConstraints = columns.mapNotNull { it.foreignKey }
val foreignKeyConstraints = foreignKeys

val createTable = buildString {
append("CREATE TABLE ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import org.jetbrains.exposed.sql.transactions.TransactionManager
import java.nio.ByteBuffer
import java.util.*
import java.util.concurrent.ConcurrentHashMap
import kotlin.collections.HashMap
import kotlin.collections.LinkedHashSet

/**
* Provides definitions for all the supported SQL data types.
Expand Down Expand Up @@ -578,8 +580,8 @@ interface DatabaseDialect {
/** Returns a map with the column metadata of all the defined columns in each of the specified [tables]. */
fun tableColumns(vararg tables: Table): Map<Table, List<ColumnMetadata>> = emptyMap()

/** Returns a map with the foreign key constraints of all the defined columns in each of the specified [tables]. */
fun columnConstraints(vararg tables: Table): Map<Pair<Table, Column<*>>, List<ForeignKeyConstraint>> = emptyMap()
/** Returns a map with the foreign key constraints of all the defined columns sets in each of the specified [tables]. */
fun columnConstraints(vararg tables: Table): Map<Pair<Table, LinkedHashSet<Column<*>>>, List<ForeignKeyConstraint>> = emptyMap()

/** Returns a map with all the defined indices in each of the specified [tables]. */
fun existingIndices(vararg tables: Table): Map<Table, List<Index>> = emptyMap()
Expand Down Expand Up @@ -701,15 +703,15 @@ abstract class VendorDialect(
override fun tableColumns(vararg tables: Table): Map<Table, List<ColumnMetadata>> =
TransactionManager.current().connection.metadata { columns(*tables) }

override fun columnConstraints(vararg tables: Table): Map<Pair<Table, Column<*>>, List<ForeignKeyConstraint>> {
val constraints = HashMap<Pair<Table, Column<*>>, MutableList<ForeignKeyConstraint>>()
override fun columnConstraints(vararg tables: Table): Map<Pair<Table, LinkedHashSet<Column<*>>>, List<ForeignKeyConstraint>> {
val constraints = HashMap<Pair<Table, LinkedHashSet<Column<*>>>, MutableList<ForeignKeyConstraint>>()

val tablesToLoad = tables.filter { !columnConstraintsCache.containsKey(it.nameInDatabaseCase()) }

fillConstraintCacheForTables(tablesToLoad)
tables.forEach { table ->
columnConstraintsCache[table.nameInDatabaseCase()].orEmpty().forEach {
constraints.getOrPut(it.from.table to it.from) { arrayListOf() }.add(it)
constraints.getOrPut(table to it.from) { arrayListOf() }.add(it)
}
}
return constraints
Expand All @@ -725,7 +727,7 @@ abstract class VendorDialect(
protected fun String.quoteIdentifierWhenWrongCaseOrNecessary(tr: Transaction): String =
tr.db.identifierManager.quoteIdentifierWhenWrongCaseOrNecessary(this)

protected val columnConstraintsCache: MutableMap<String, List<ForeignKeyConstraint>> = ConcurrentHashMap()
protected val columnConstraintsCache: MutableMap<String, Collection<ForeignKeyConstraint>> = ConcurrentHashMap()

protected open fun fillConstraintCacheForTables(tables: List<Table>): Unit =
columnConstraintsCache.putAll(TransactionManager.current().db.metadata { tableConstraints(tables) })
Expand Down
Loading