Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch update query ext #2228

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions exposed-core/src/main/kotlin/org/jetbrains/exposed/sql/Queries.kt
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,26 @@ fun <T : Table> T.updateReturning(
return ReturningStatement(this, returning, update)
}

fun <T : Table, E> T.batchUpdate(
data: Iterable<E>,
limit: Int? = null,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't expect that standard update statment supports limit buy I may be wrong here.

where: (SqlExpressionBuilder.() -> Op<Boolean>)? = null,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I right that every statement may have own where condition?

For example:

update t set v = "v1" where id = 1
update t set v = "v2" where id = 2
...

Will it be reflected with this API? I just wonder because where arg is one for the whole statement, and the parameters for it should be taken somehow from the data parameter.

body: BatchUpdateStatement.(E) -> Unit
): List<ResultRow> = batchUpdate(data.iterator(), limit, where, body)

private fun <T : Table, E> T.batchUpdate(
data: Iterator<E>,
limit: Int? = null,
where: (SqlExpressionBuilder.() -> Op<Boolean>)? = null,
body: BatchUpdateStatement.(E) -> Unit
): List<ResultRow> = executeBatch(data, body) {
BatchUpdateStatement(
this,
limit,
where = where?.let { SqlExpressionBuilder.it() },
)
}

/**
* Represents the SQL statement that either inserts a new row into a table, or updates the existing row if insertion would violate a unique constraint.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,61 +7,41 @@ import org.jetbrains.exposed.dao.id.IdTable
import org.jetbrains.exposed.sql.Column
import org.jetbrains.exposed.sql.Expression
import org.jetbrains.exposed.sql.IColumnType
import org.jetbrains.exposed.sql.Op
import org.jetbrains.exposed.sql.QueryBuilder
import org.jetbrains.exposed.sql.Table
import org.jetbrains.exposed.sql.Transaction
import org.jetbrains.exposed.sql.statements.api.PreparedStatementApi
import org.jetbrains.exposed.sql.vendors.MysqlFunctionProvider

/**
* Represents the SQL statement that batch updates rows of a table.
*
* @param table Identity table to update values from.
*/
open class BatchUpdateStatement(val table: IdTable<*>) : UpdateStatement(table, null) {
/** The mappings of columns to update with their updated values for each entity in the batch. */
val data = ArrayList<Pair<EntityID<*>, Map<Column<*>, Any?>>>()
override val firstDataSet: List<Pair<Column<*>, Any?>> get() = data.first().second.toList()

/**
* Adds the specified entity [id] to the current list of update statements, using the mapping of columns to update
* provided for this `BatchUpdateStatement`.
*/
fun addBatch(id: EntityID<*>) {
val lastBatch = data.lastOrNull()
val different by lazy {
val set1 = firstDataSet.map { it.first }.toSet()
val set2 = lastBatch!!.second.keys
(set1 - set2) + (set2 - set1)
}

if (data.size > 1 && different.isNotEmpty()) {
throw BatchDataInconsistentException("Some values missing for batch update. Different columns: $different")
}

if (data.isNotEmpty()) {
data[data.size - 1] = lastBatch!!.copy(second = values.toMap())
values.clear()
hasBatchedValues = true
}
data.add(id to values)
}
open class BatchUpdateStatement(
table: Table,
val limit: Int?,
val where: Op<Boolean>?
) : BaseBatchInsertStatement(table, ignore = false, shouldReturnGeneratedValues = false) {

override fun <T, S : T?> update(column: Column<T>, value: Expression<S>) = error("Expressions unsupported in batch update")

override fun prepareSQL(transaction: Transaction, prepared: Boolean): String {
val updateSql = super.prepareSQL(transaction, prepared)
val idEqCondition = if (table is CompositeIdTable) {
table.idColumns.joinToString(separator = " AND ") { "${transaction.identity(it)} = ?" }
} else {
"${transaction.identity(table.id)} = ?"
}
return "$updateSql WHERE $idEqCondition"
val dialect = transaction.db.dialect
val functionProvider = UpsertBuilder.getFunctionProvider(dialect)
val insertValues = arguments!!.first()
return functionProvider.update(table, insertValues, limit, where, transaction)
}

override fun PreparedStatementApi.executeInternal(transaction: Transaction): Int = if (data.size == 1) executeUpdate() else executeBatch().sum()

override fun arguments(): Iterable<Iterable<Pair<IColumnType<*>, Any?>>> = data.map { (id, row) ->
val idArgs = (id.value as? CompositeID)?.values?.map {
it.key.columnType to it.value
} ?: listOf(table.id.columnType to id)
firstDataSet.map { it.first.columnType to row[it.first] } + idArgs
override fun arguments(): List<Iterable<Pair<IColumnType<*>, Any?>>> {
val whereArgs = QueryBuilder(true).apply {
where?.toQueryBuilder(this)
}.args
return super.arguments().map {
it + whereArgs
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ class EntityBatchUpdate(private val klass: EntityClass<*, Entity<*>>) {
fun execute(transaction: Transaction): Int {
val updateSets = data.filterNot { it.second.isEmpty() }.groupBy { it.second.keys }
return updateSets.values.fold(0) { acc, set ->
acc + BatchUpdateStatement(klass.table).let {
it.data.addAll(set)
acc + BatchUpdateStatement(klass.table, TODO(), TODO()).let {
// it.data.addAll(set)
it.execute(transaction)!!
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,25 @@ class UpdateTests : DatabaseTestsBase() {
}
}
}

@Test
fun testBatchUpdateWithNoConflict() {
withTables(excludeSettings = TestDB.ALL_H2_V1, Words) {
val amountOfWords = 10
val allWords = List(amountOfWords) { i -> "Word ${'A' + i}" to amountOfWords * i + amountOfWords }

Words.batchUpdate(allWords) { (word, count) ->
this[Words.word] = word
this[Words.count] = count
}

// assertEquals(amountOfWords, generatedIds.size)
assertEquals(amountOfWords.toLong(), Words.selectAll().count())
}
}

private object Words : Table("words") {
val word = varchar("name", 64).uniqueIndex()
val count = integer("count").default(1)
}
}