
Commit fa60fc9

[SPARK-29037] For static partition overwrite, Spark may give duplicate results.
1 parent dffd92e commit fa60fc9
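
The write this commit targets is an INSERT OVERWRITE into a partitioned table while spark.sql.sources.partitionOverwriteMode is STATIC (the default). A minimal sketch of that scenario, lifted from the SQLQuerySuite test added at the bottom of this commit; the table name and values are just the test's examples:

    // Static partition overwrite: the commit path this change reworks.
    spark.conf.set("spark.sql.sources.partitionOverwriteMode", "static")
    spark.sql("create table test(id int, p1 int, p2 int) using parquet partitioned by (p1, p2)")
    spark.sql("insert overwrite table test partition(p1=1, p2) select 1, 3")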

File tree

5 files changed (+70 −14 lines)

core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala

Lines changed: 8 additions & 4 deletions
@@ -147,7 +147,9 @@ object FileCommitProtocol extends Logging {
       className: String,
       jobId: String,
       outputPath: String,
-      dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = {
+      dynamicPartitionOverwrite: Boolean = false,
+      staticPartitionKVS: Seq[(String, String)] = Seq.empty[(String, String)]):
+    FileCommitProtocol = {
 
     logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" +
       s" dynamic=$dynamicPartitionOverwrite")
@@ -156,9 +158,11 @@ object FileCommitProtocol extends Logging {
     // dynamicPartitionOverwrite: Boolean).
     // If that doesn't exist, try the one with (jobId: string, outputPath: String).
     try {
-      val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean])
-      logDebug("Using (String, String, Boolean) constructor")
-      ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean])
+      val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean],
+        classOf[Seq[(String, String)]])
+      logDebug("Using (String, String, Boolean, Seq[(String, String)]) constructor")
+      ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean],
+        staticPartitionKVS)
     } catch {
       case _: NoSuchMethodException =>
         logDebug("Falling back to (String, String) constructor")
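
With the extra parameter above, instantiate now reflectively looks for a (String, String, Boolean, Seq[(String, String)]) constructor before falling back to the (String, String) one. A minimal sketch of a call site, assuming a hypothetical output path and using Spark's built-in HadoopMapReduceCommitProtocol as the class name:

    // Sketch: instantiating a committer with the new staticPartitionKVS argument.
    // The output path and partition values are hypothetical placeholders.
    import org.apache.spark.internal.io.FileCommitProtocol

    val committer = FileCommitProtocol.instantiate(
      className = "org.apache.spark.internal.io.HadoopMapReduceCommitProtocol",
      jobId = java.util.UUID.randomUUID().toString,
      outputPath = "/tmp/warehouse/test",
      dynamicPartitionOverwrite = false,
      staticPartitionKVS = Seq(("p1", "1")))  // static partition key/value pairs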

core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala

Lines changed: 25 additions & 4 deletions
@@ -17,7 +17,7 @@
 
 package org.apache.spark.internal.io
 
-import java.io.IOException
+import java.io.{File, IOException}
 import java.util.{Date, UUID}
 
 import scala.collection.mutable
@@ -26,7 +26,7 @@ import scala.util.Try
 import org.apache.hadoop.conf.Configurable
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat}
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 
 import org.apache.spark.internal.Logging
@@ -52,7 +52,8 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil
 class HadoopMapReduceCommitProtocol(
     jobId: String,
     path: String,
-    dynamicPartitionOverwrite: Boolean = false)
+    dynamicPartitionOverwrite: Boolean = false,
+    staticPartitionKVS: Seq[(String, String)] = Seq.empty[(String, String)])
   extends FileCommitProtocol with Serializable with Logging {
 
   import FileCommitProtocol._
@@ -89,9 +90,15 @@ class HadoopMapReduceCommitProtocol(
    * The staging directory of this write job. Spark uses it to deal with files with absolute output
    * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true.
    */
-  private def stagingDir = new Path(path, ".spark-staging-" + jobId)
+  protected def stagingDir = new Path(path, ".spark-staging-" + jobId)
+
+
+  private def getStaticPartitionPath(): String = {
+    staticPartitionKVS.map(kv => kv._1 + "=" + kv._2).mkString(File.separator)
+  }
 
   protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
+    context.getConfiguration.set(FileOutputFormat.OUTDIR, stagingDir.toString)
     val format = context.getOutputFormatClass.getConstructor().newInstance()
     // If OutputFormat is Configurable, we should set conf to it.
     format match {
@@ -200,6 +207,20 @@ class HadoopMapReduceCommitProtocol(
           }
           fs.rename(new Path(stagingDir, part), finalPartPath)
         }
+      } else if (!getStaticPartitionPath().isEmpty) {
+        val finalPartPath = new Path(path, getStaticPartitionPath)
+        assert(!fs.exists(finalPartPath))
+        fs.rename(new Path(stagingDir, getStaticPartitionPath), finalPartPath)
+      } else {
+        val parts = fs.listStatus(stagingDir)
+          .filter(_.isDirectory)
+          .map(_.getPath.getName)
+          .filter(name => !name.startsWith(".") && name.contains("="))
+        for (part <- parts) {
+          val finalPartPath = new Path(path, part)
+          assert(!fs.exists(finalPartPath))
+          fs.rename(new Path(stagingDir, part), finalPartPath)
+        }
       }
 
      fs.delete(stagingDir, true)
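
The net effect of these changes: task output is now always written under the per-job staging directory (setupCommitter points FileOutputFormat.OUTDIR at it), and commitJob renames the staged partition directories into their final location, asserting the destination does not already exist. For the static keys themselves, getStaticPartitionPath simply joins them into a relative path; an illustration with example values, not taken from the commit:

    // Illustration: the relative directory getStaticPartitionPath builds for
    // INSERT OVERWRITE ... PARTITION(p1=1, p2=5), assuming File.separator is "/".
    val staticPartitionKVS = Seq(("p1", "1"), ("p2", "5"))
    val staticPartitionPath =
      staticPartitionKVS.map(kv => kv._1 + "=" + kv._2).mkString(java.io.File.separator)
    // staticPartitionPath == "p1=1/p2=5"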

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala

Lines changed: 8 additions & 3 deletions
@@ -105,11 +105,16 @@ case class InsertIntoHadoopFsRelationCommand(
     val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite &&
       staticPartitions.size < partitionColumns.length
 
+    val staticPartitionKVs = partitionColumns
+      .filter(c => staticPartitions.contains(c.name))
+      .map(att => (att.name, staticPartitions.get(att.name).get))
+
     val committer = FileCommitProtocol.instantiate(
       sparkSession.sessionState.conf.fileCommitProtocolClass,
-      jobId = java.util.UUID.randomUUID().toString,
-      outputPath = outputPath.toString,
-      dynamicPartitionOverwrite = dynamicPartitionOverwrite)
+      java.util.UUID.randomUUID().toString,
+      outputPath.toString,
+      dynamicPartitionOverwrite,
+      staticPartitionKVs)
 
     val doInsertion = (mode, pathExists) match {
       case (SaveMode.ErrorIfExists, true) =>
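
Here staticPartitionKVs keeps only the partition columns that have a static value, in partition-column order, and is passed positionally as the new last argument to FileCommitProtocol.instantiate. An illustration of the value it produces, using stand-in inputs for partitionColumns and staticPartitions:

    // Illustration with stand-in data: for a table partitioned by (p1, p2) and
    // INSERT OVERWRITE ... PARTITION(p1=1, p2), only p1 is static.
    val partitionColumnNames = Seq("p1", "p2")  // stand-in for partitionColumns.map(_.name)
    val staticPartitions = Map("p1" -> "1")
    val staticPartitionKVs = partitionColumnNames
      .filter(staticPartitions.contains)
      .map(name => (name, staticPartitions(name)))
    // staticPartitionKVs == Seq(("p1", "1"))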

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala

Lines changed: 4 additions & 3 deletions
@@ -32,8 +32,9 @@ import org.apache.spark.sql.internal.SQLConf
 class SQLHadoopMapReduceCommitProtocol(
     jobId: String,
     path: String,
-    dynamicPartitionOverwrite: Boolean = false)
-  extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite)
+    dynamicPartitionOverwrite: Boolean = false,
+    staticPartitionKVS: Seq[(String, String)] = Seq.empty[(String, String)])
+  extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite, staticPartitionKVS)
     with Serializable with Logging {
 
   override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
@@ -55,7 +56,7 @@ class SQLHadoopMapReduceCommitProtocol(
         // The specified output committer is a FileOutputCommitter.
         // So, we will use the FileOutputCommitter-specified constructor.
         val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
-        committer = ctor.newInstance(new Path(path), context)
+        committer = ctor.newInstance(stagingDir, context)
       } else {
         // The specified output committer is just an OutputCommitter.
        // So, we will use the no-argument constructor.
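
The behavioral change in this file is that a user-configured FileOutputCommitter-style committer is now constructed against the per-job staging directory (inherited from HadoopMapReduceCommitProtocol above) rather than the final table path. A small sketch of the paths involved, with a hypothetical table location and job id:

    // Sketch: where a FileOutputCommitter-style committer now writes.
    // Table location and job id below are hypothetical placeholders.
    import org.apache.hadoop.fs.Path
    val path = "/tmp/warehouse/test"
    val jobId = "c1a2b3d4-example"
    val stagingDir = new Path(path, ".spark-staging-" + jobId)
    // before this commit: ctor.newInstance(new Path(path), context)
    // after this commit:  ctor.newInstance(stagingDir, context)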

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 25 additions & 0 deletions
@@ -22,6 +22,7 @@ import java.net.{MalformedURLException, URL}
 import java.sql.{Date, Timestamp}
 import java.util.concurrent.atomic.AtomicBoolean
 
+import com.google.common.io.Files
 import org.apache.spark.{AccumulatorSuite, SparkException}
 import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql.catalyst.util.StringUtils
@@ -33,6 +34,7 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode}
 import org.apache.spark.sql.test.{SharedSparkSession, TestSQLContext}
 import org.apache.spark.sql.test.SQLTestData._
 import org.apache.spark.sql.types._
@@ -3192,6 +3194,29 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession {
       checkAnswer(df3, Array(Row(new java.math.BigDecimal("0.100000000000000000000000100"))))
     }
   }
+
+  test("SPARK-29037: For non dynamic partition overwrite, set a unique staging dir") {
+    withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) {
+      withTable("test") {
+        sql("create table test(id int, p1 int, p2 int) using parquet partitioned by (p1, p2)")
+        sql("insert overwrite table test partition(p1=1,p2) select 1, 3")
+        val df1 = sql("select * from test order by p2")
+        checkAnswer(df1, Array(Row(1, 1, 3)))
+        sql("insert overwrite table test partition(p1=1,p2) select 1, 4")
+        val df2 = sql("select * from test order by p2")
+        checkAnswer(df2, Array(Row(1, 1, 4)))
+        sql("insert overwrite table test partition(p1=1,p2=5) select 1")
+        val df3 = sql("select * from test order by p2")
+        checkAnswer(df3, Array(Row(1, 1, 4), Row(1, 1, 5)))
+        sql("insert overwrite table test select 1, 2, 3")
+        val df4 = sql("select * from test order by p2")
+        checkAnswer(df4, Array(Row(1, 2, 3)))
+        sql("insert overwrite table test select 9, 9, 9")
+        val df5 = sql("select * from test order by p2")
+        checkAnswer(df5, Array(Row(9, 9, 9)))
+      }
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
