
Commit 9f0bf51

dtenedor authored and gengliangwang committed
[SPARK-43018][SQL] Fix bug for INSERT commands with timestamp literals
### What changes were proposed in this pull request?

This PR fixes a correctness bug for INSERT commands with timestamp literals. The bug manifests when:

* An INSERT command includes a user-specified column list of fewer columns than the target table.
* The provided values include timestamp literals.

The bug was that the long integer values stored in the rows to represent these timestamp literals were getting assigned back to `UnresolvedInlineTable` rows without the timestamp type. The analyzer then inserted an implicit cast from `LongType` to `TimestampType`, which incorrectly caused the value to change during execution. This PR fixes the bug by propagating the timestamp type directly to the output table instead.

### Why are the changes needed?

This PR fixes a correctness bug.

### Does this PR introduce _any_ user-facing change?

Yes, this PR fixes a correctness bug.

### How was this patch tested?

This PR adds a new unit test suite.

Closes #40652 from dtenedor/assign-correct-insert-types.

Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
1 parent 026dafd commit 9f0bf51
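
For context, here is a minimal repro sketch of the affected scenario, mirroring the end-to-end test added in this commit (the table name `t` is illustrative; run in a spark-shell):

```scala
// The INSERT names fewer columns than the target table and supplies a timestamp literal.
// Before this fix, the literal's internal Long value could be re-cast from LongType to
// TimestampType during analysis and silently change at execution time.
spark.sql("CREATE TABLE t(id INT, ts TIMESTAMP) USING PARQUET")
spark.sql("INSERT INTO t (ts) VALUES (TIMESTAMP'2020-12-31')")
// Expected after the fix: id is NULL (its implicit default) and ts is 2020-12-31 00:00:00.
spark.table("t").show()
```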

File tree: 3 files changed (+160, -44 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumns.scala

Lines changed: 54 additions & 33 deletions
@@ -107,10 +107,10 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan]
       insertTableSchemaWithoutPartitionColumns.map { schema: StructType =>
         val regenerated: InsertIntoStatement =
           regenerateUserSpecifiedCols(i, schema)
-        val expanded: LogicalPlan =
+        val (expanded: LogicalPlan, addedDefaults: Boolean) =
           addMissingDefaultValuesForInsertFromInlineTable(node, schema, i.userSpecifiedCols.size)
         val replaced: Option[LogicalPlan] =
-          replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
+          replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults)
         replaced.map { r: LogicalPlan =>
           node = r
           for (child <- children.reverse) {
@@ -131,10 +131,10 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan]
       insertTableSchemaWithoutPartitionColumns.map { schema =>
         val regenerated: InsertIntoStatement = regenerateUserSpecifiedCols(i, schema)
         val project: Project = i.query.asInstanceOf[Project]
-        val expanded: Project =
+        val (expanded: Project, addedDefaults: Boolean) =
           addMissingDefaultValuesForInsertFromProject(project, schema, i.userSpecifiedCols.size)
         val replaced: Option[LogicalPlan] =
-          replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded)
+          replaceExplicitDefaultValuesForInputOfInsertInto(schema, expanded, addedDefaults)
         replaced.map { r =>
           regenerated.copy(query = r)
         }.getOrElse(i)
@@ -270,67 +270,83 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan]
 
   /**
    * Updates an inline table to generate missing default column values.
+   * Returns the resulting plan plus a boolean indicating whether such values were added.
    */
-  private def addMissingDefaultValuesForInsertFromInlineTable(
+  def addMissingDefaultValuesForInsertFromInlineTable(
       node: LogicalPlan,
       insertTableSchemaWithoutPartitionColumns: StructType,
-      numUserSpecifiedColumns: Int): LogicalPlan = {
+      numUserSpecifiedColumns: Int): (LogicalPlan, Boolean) = {
     val schema = insertTableSchemaWithoutPartitionColumns
-    val newDefaultExpressions: Seq[Expression] =
-      getDefaultExpressionsForInsert(schema, numUserSpecifiedColumns)
-    val newNames: Seq[String] = if (numUserSpecifiedColumns > 0) {
-      schema.fields.drop(numUserSpecifiedColumns).map(_.name)
-    } else {
-      schema.fields.map(_.name)
-    }
-    node match {
-      case _ if newDefaultExpressions.isEmpty => node
+    val newDefaultExpressions: Seq[UnresolvedAttribute] =
+      getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, node.output.size)
+    val newNames: Seq[String] = schema.fields.map(_.name)
+    val resultPlan: LogicalPlan = node match {
+      case _ if newDefaultExpressions.isEmpty =>
+        node
       case table: UnresolvedInlineTable =>
         table.copy(
-          names = table.names ++ newNames,
+          names = newNames,
           rows = table.rows.map { row => row ++ newDefaultExpressions })
       case local: LocalRelation =>
-        // Note that we have consumed a LocalRelation but return an UnresolvedInlineTable, because
-        // addMissingDefaultValuesForInsertFromProject must replace unresolved DEFAULT references.
-        UnresolvedInlineTable(
-          local.output.map(_.name) ++ newNames,
-          local.data.map { row =>
-            val colTypes = StructType(local.output.map(col => StructField(col.name, col.dataType)))
-            row.toSeq(colTypes).map(Literal(_)) ++ newDefaultExpressions
+        val newDefaultExpressionsRow = new GenericInternalRow(
+          // Note that this code path only runs when there is a user-specified column list of fewer
+          // columns than the target table; otherwise, the above 'newDefaultExpressions' is empty
+          // and we match the first case in this list instead.
+          schema.fields.drop(local.output.size).map {
+            case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) =>
+              analyze(f, "INSERT") match {
+                case lit: Literal => lit.value
+                case _ => null
+              }
+            case _ => null
           })
-      case _ => node
+        LocalRelation(
+          output = schema.toAttributes,
+          data = local.data.map { row =>
+            new JoinedRow(row, newDefaultExpressionsRow)
+          })
+      case _ =>
+        node
     }
+    (resultPlan, newDefaultExpressions.nonEmpty)
   }
 
   /**
    * Adds new expressions to a projection to generate missing default column values.
+   * Returns the logical plan plus a boolean indicating if such defaults were added.
    */
   private def addMissingDefaultValuesForInsertFromProject(
       project: Project,
       insertTableSchemaWithoutPartitionColumns: StructType,
-      numUserSpecifiedColumns: Int): Project = {
+      numUserSpecifiedColumns: Int): (Project, Boolean) = {
     val schema = insertTableSchemaWithoutPartitionColumns
     val newDefaultExpressions: Seq[Expression] =
-      getDefaultExpressionsForInsert(schema, numUserSpecifiedColumns)
+      getNewDefaultExpressionsForInsert(schema, numUserSpecifiedColumns, project.projectList.size)
     val newAliases: Seq[NamedExpression] =
       newDefaultExpressions.zip(schema.fields).map {
         case (expr, field) => Alias(expr, field.name)()
       }
-    project.copy(projectList = project.projectList ++ newAliases)
+    (project.copy(projectList = project.projectList ++ newAliases),
+      newDefaultExpressions.nonEmpty)
   }
 
   /**
    * This is a helper for the addMissingDefaultValuesForInsertFromInlineTable methods above.
    */
-  private def getDefaultExpressionsForInsert(
-      schema: StructType,
-      numUserSpecifiedColumns: Int): Seq[Expression] = {
+  private def getNewDefaultExpressionsForInsert(
+      insertTableSchemaWithoutPartitionColumns: StructType,
+      numUserSpecifiedColumns: Int,
+      numProvidedValues: Int): Seq[UnresolvedAttribute] = {
     val remainingFields: Seq[StructField] = if (numUserSpecifiedColumns > 0) {
-      schema.fields.drop(numUserSpecifiedColumns)
+      insertTableSchemaWithoutPartitionColumns.fields.drop(numUserSpecifiedColumns)
     } else {
       Seq.empty
     }
     val numDefaultExpressionsToAdd = getStructFieldsForDefaultExpressions(remainingFields).size
+      // Limit the number of new DEFAULT expressions to the difference of the number of columns in
+      // the target table and the number of provided values in the source relation. This clamps the
+      // total final number of provided values to the number of columns in the target table.
+      .min(insertTableSchemaWithoutPartitionColumns.size - numProvidedValues)
     Seq.fill(numDefaultExpressionsToAdd)(UnresolvedAttribute(CURRENT_DEFAULT_COLUMN_NAME))
   }

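To make the new clamp in getNewDefaultExpressionsForInsert concrete, here is a small worked example with illustrative numbers (not taken from the patch):

```scala
// Target table: 3 columns. The user specified 1 column, and the source relation already
// provides 3 values, so no room remains for extra DEFAULT expressions.
val numTargetColumns = 3
val numUserSpecifiedColumns = 1
val numProvidedValues = 3
// Assume both remaining columns carry DEFAULT metadata.
val numRemainingWithDefaults = numTargetColumns - numUserSpecifiedColumns // 2
// The .min(...) clamp keeps the final row width at the target table's width:
val numDefaultExpressionsToAdd =
  numRemainingWithDefaults.min(numTargetColumns - numProvidedValues) // min(2, 0) == 0
assert(numDefaultExpressionsToAdd == 0)
```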
@@ -351,7 +367,8 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan]
    */
   private def replaceExplicitDefaultValuesForInputOfInsertInto(
       insertTableSchemaWithoutPartitionColumns: StructType,
-      input: LogicalPlan): Option[LogicalPlan] = {
+      input: LogicalPlan,
+      addedDefaults: Boolean): Option[LogicalPlan] = {
     val schema = insertTableSchemaWithoutPartitionColumns
     val defaultExpressions: Seq[Expression] = schema.fields.map {
       case f if f.metadata.contains(CURRENT_DEFAULT_COLUMN_METADATA_KEY) => analyze(f, "INSERT")
@@ -371,7 +388,11 @@ case class ResolveDefaultColumns(catalog: SessionCatalog) extends Rule[LogicalPlan]
       case project: Project =>
         replaceExplicitDefaultValuesForProject(defaultExpressions, project)
       case local: LocalRelation =>
-        Some(local)
+        if (addedDefaults) {
+          Some(local)
+        } else {
+          None
+        }
     }
   }

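The LocalRelation branch above evaluates each missing column's DEFAULT down to a plain value and appends it to every existing row, so the declared timestamp type flows directly into the output relation and no LongType-to-TimestampType cast is ever inserted. A minimal standalone sketch of that row-joining step (the values are illustrative; the real code derives them from column metadata via analyze):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow}

// One provided value: a timestamp stored internally as microseconds since the epoch.
val provided: InternalRow = new GenericInternalRow(Array[Any](1609401600000000L))
// One appended default (NULL here), attached without re-parsing or casting.
val defaults: InternalRow = new GenericInternalRow(Array[Any](null))
val combined: InternalRow = new JoinedRow(provided, defaults)
assert(combined.numFields == 2)
assert(combined.getLong(0) == 1609401600000000L && combined.isNullAt(1))
```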
sql/core/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveDefaultColumnsSuite.scala

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{StructField, StructType, TimestampType}
+
+class ResolveDefaultColumnsSuite extends QueryTest with SharedSparkSession {
+  val rule = ResolveDefaultColumns(catalog = null)
+  // This is the internal storage for the timestamp 2020-12-31 00:00:00.0.
+  val literal = Literal(1609401600000000L, TimestampType)
+  val table = UnresolvedInlineTable(
+    names = Seq("attr1"),
+    rows = Seq(Seq(literal)))
+  val localRelation = ResolveInlineTables(table).asInstanceOf[LocalRelation]
+
+  def asLocalRelation(result: LogicalPlan): LocalRelation = result match {
+    case r: LocalRelation => r
+    case _ => fail(s"invalid result operator type: $result")
+  }
+
+  test("SPARK-43018: Add DEFAULTs for INSERT from VALUES list with user-defined columns") {
+    // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with one user-specified
+    // column. We add a default value of NULL to the row as a result.
+    val insertTableSchemaWithoutPartitionColumns = StructType(Seq(
+      StructField("c1", TimestampType),
+      StructField("c2", TimestampType)))
+    val (result: LogicalPlan, _: Boolean) =
+      rule.addMissingDefaultValuesForInsertFromInlineTable(
+        localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 1)
+    val relation = asLocalRelation(result)
+    assert(relation.output.map(_.name) == Seq("c1", "c2"))
+    val data: Seq[Seq[Any]] = relation.data.map { row =>
+      row.toSeq(StructType(relation.output.map(col => StructField(col.name, col.dataType))))
+    }
+    assert(data == Seq(Seq(literal.value, null)))
+  }
+
+  test("SPARK-43018: Add no DEFAULTs for INSERT from VALUES list with no user-defined columns") {
+    // Call the 'addMissingDefaultValuesForInsertFromInlineTable' method with zero user-specified
+    // columns. The table is unchanged because there are no default columns to add in this case.
+    val insertTableSchemaWithoutPartitionColumns = StructType(Seq(
+      StructField("c1", TimestampType),
+      StructField("c2", TimestampType)))
+    val (result: LogicalPlan, _: Boolean) =
+      rule.addMissingDefaultValuesForInsertFromInlineTable(
+        localRelation, insertTableSchemaWithoutPartitionColumns, numUserSpecifiedColumns = 0)
+    assert(asLocalRelation(result) == localRelation)
+  }
+
+  test("SPARK-43018: INSERT timestamp values into a table with column DEFAULTs") {
+    withTable("t") {
+      sql("create table t(id int, ts timestamp) using parquet")
+      sql("insert into t (ts) values (timestamp'2020-12-31')")
+      checkAnswer(spark.table("t"),
+        sql("select null, timestamp'2020-12-31'").collect().head)
+    }
+  }
+}

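As a sanity check on the literal used by the suite above: Spark's TimestampType stores microseconds since the Unix epoch, so 1609401600000000L corresponds to 2020-12-31 00:00:00 in the America/Los_Angeles time zone (that zone is an assumption about the test environment's session time zone):

```scala
import java.time.{LocalDateTime, ZoneId}

// Microseconds since the epoch for 2020-12-31 00:00:00 in America/Los_Angeles.
val micros = LocalDateTime.parse("2020-12-31T00:00:00")
  .atZone(ZoneId.of("America/Los_Angeles"))
  .toInstant.toEpochMilli * 1000L
assert(micros == 1609401600000000L)
```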
sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala

Lines changed: 29 additions & 11 deletions
@@ -1100,9 +1100,15 @@ class DataSourceV2SQLSuiteV1Filter
       exception = intercept[AnalysisException] {
         sql(s"INSERT INTO $t1(data, data) VALUES(5)")
       },
-      errorClass = "COLUMN_ALREADY_EXISTS",
-      parameters = Map("columnName" -> "`data`")
-    )
+      errorClass = "_LEGACY_ERROR_TEMP_2305",
+      parameters = Map(
+        "numCols" -> "3",
+        "rowSize" -> "2",
+        "ri" -> "0"),
+      context = ExpectedContext(
+        fragment = s"INSERT INTO $t1(data, data)",
+        start = 0,
+        stop = 26))
   }
 }

@@ -1123,14 +1129,20 @@ class DataSourceV2SQLSuiteV1Filter
     assert(intercept[AnalysisException] {
       sql(s"INSERT OVERWRITE $t1 VALUES(4)")
     }.getMessage.contains("not enough data columns"))
-      // Duplicate columns
+    // Duplicate columns
     checkError(
       exception = intercept[AnalysisException] {
         sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
       },
-      errorClass = "COLUMN_ALREADY_EXISTS",
-      parameters = Map("columnName" -> "`data`")
-    )
+      errorClass = "_LEGACY_ERROR_TEMP_2305",
+      parameters = Map(
+        "numCols" -> "3",
+        "rowSize" -> "2",
+        "ri" -> "0"),
+      context = ExpectedContext(
+        fragment = s"INSERT OVERWRITE $t1(data, data)",
+        start = 0,
+        stop = 31))
   }
 }

@@ -1152,14 +1164,20 @@ class DataSourceV2SQLSuiteV1Filter
     assert(intercept[AnalysisException] {
       sql(s"INSERT OVERWRITE $t1 VALUES('a', 4)")
     }.getMessage.contains("not enough data columns"))
-      // Duplicate columns
+    // Duplicate columns
     checkError(
       exception = intercept[AnalysisException] {
         sql(s"INSERT OVERWRITE $t1(data, data) VALUES(5)")
       },
-      errorClass = "COLUMN_ALREADY_EXISTS",
-      parameters = Map("columnName" -> "`data`")
-    )
+      errorClass = "_LEGACY_ERROR_TEMP_2305",
+      parameters = Map(
+        "numCols" -> "4",
+        "rowSize" -> "3",
+        "ri" -> "0"),
+      context = ExpectedContext(
+        fragment = s"INSERT OVERWRITE $t1(data, data)",
+        start = 0,
+        stop = 31))
   }
 }
