diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 8577803743c8..d4b397b68c2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -51,7 +51,7 @@ private[libsvm] class LibSVMOutputWriter(
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
         val configuration = context.getConfiguration
-        val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
+        val uniqueWriteJobId = configuration.get(WriteOutput.DATASOURCE_WRITEJOBUUID)
         val taskAttemptId = context.getTaskAttemptID
         val split = taskAttemptId.getTaskID.getId
         new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
index 54d0f3bd6291..3b233cafdbb0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteOutput.scala
@@ -44,6 +44,8 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
 /** A helper object for writing data out to a location. */
 object WriteOutput extends Logging {
 
+  val DATASOURCE_WRITEJOBUUID = "spark.sql.sources.writeJobUUID"
+
   /** A shared job description for all the write tasks. */
   private class WriteJobDescription(
       val serializableHadoopConf: SerializableConfiguration,
@@ -287,25 +289,26 @@ object WriteOutput extends Logging {
       }
     }
 
-    private def getBucketIdFromKey(key: InternalRow): Option[Int] =
-      description.bucketSpec.map { _ => key.getInt(description.partitionColumns.length) }
-
     /**
     * Open and returns a new OutputWriter given a partition key and optional bucket id.
     * If bucket id is specified, we will append it to the end of the file name, but before the
     * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet
     */
-    private def newOutputWriter(
-        key: InternalRow,
-        getPartitionString: UnsafeProjection): OutputWriter = {
+    private def newOutputWriter(key: InternalRow, partString: UnsafeProjection): OutputWriter = {
       val path = if (description.partitionColumns.nonEmpty) {
-        val partitionPath = getPartitionString(key).getString(0)
+        val partitionPath = partString(key).getString(0)
         new Path(stagingPath, partitionPath).toString
       } else {
         stagingPath
       }
-      val bucketId = getBucketIdFromKey(key)
+
+      // If the bucket spec is defined, the bucket column is right after the partition columns
+      val bucketId = if (description.bucketSpec.isDefined) {
+        Some(key.getInt(description.partitionColumns.length))
+      } else {
+        None
+      }
 
       val newWriter = description.outputWriterFactory.newInstance(
         path = path,
@@ -319,7 +322,7 @@ object WriteOutput extends Logging {
     override def execute(iter: Iterator[InternalRow]): Unit = {
       // We should first sort by partition columns, then bucket id, and finally sorting columns.
       val sortingExpressions: Seq[Expression] =
-          description.partitionColumns ++ bucketIdExpression ++ sortColumns
+        description.partitionColumns ++ bucketIdExpression ++ sortColumns
       val getSortingKey = UnsafeProjection.create(sortingExpressions, description.allColumns)
 
       val sortingKeySchema = StructType(sortingExpressions.map {
@@ -333,8 +336,8 @@ object WriteOutput extends Logging {
         description.nonPartitionColumns, description.allColumns)
 
       // Returns the partition path given a partition key.
-      val getPartitionString =
-        UnsafeProjection.create(Seq(Concat(partitionStringExpression)), description.partitionColumns)
+      val getPartitionString = UnsafeProjection.create(
+        Seq(Concat(partitionStringExpression)), description.partitionColumns)
 
       // Sorts the data before write, so that we only need one writer at the same time.
       val sorter = new UnsafeKVExternalSorter(
@@ -414,7 +417,7 @@ object WriteOutput extends Logging {
     // `part-r-<task-id>-<job-uuid>.parquet`). The reason why this ID is used to identify a job
     // rather than a single task output file is that, speculative tasks must generate the same
     // output file name as the original task.
-    job.getConfiguration.set(WriterContainer.DATASOURCE_WRITEJOBUUID, UUID.randomUUID().toString)
+    job.getConfiguration.set(WriteOutput.DATASOURCE_WRITEJOBUUID, UUID.randomUUID().toString)
 
     val taskAttemptContext = new TaskAttemptContextImpl(job.getConfiguration, taskAttemptId)
     val outputCommitter = newOutputCommitter(
@@ -474,7 +477,3 @@ object WriteOutput extends Logging {
       }
     }
   }
-
-object WriterContainer {
-  val DATASOURCE_WRITEJOBUUID = "spark.sql.sources.writeJobUUID"
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 55cb26d6513a..ad42e37ebdd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile, WriterContainer}
+import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile, WriteOutput}
 import org.apache.spark.sql.types._
 
 object CSVRelation extends Logging {
@@ -200,7 +200,7 @@ private[csv] class CsvOutputWriter(
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
         val configuration = context.getConfiguration
-        val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
+        val uniqueWriteJobId = configuration.get(WriteOutput.DATASOURCE_WRITEJOBUUID)
         val taskAttemptId = context.getTaskAttemptID
         val split = taskAttemptId.getTaskID.getId
         new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.csv$extension")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 9fe38ccc9fdc..ad580eaaeaa3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -169,7 +169,7 @@ private[json] class JsonOutputWriter(
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
         val configuration = context.getConfiguration
-        val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
+        val uniqueWriteJobId = configuration.get(WriteOutput.DATASOURCE_WRITEJOBUUID)
         val taskAttemptId = context.getTaskAttemptID
         val split = taskAttemptId.getTaskID.getId
         val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
index f89ce05d82d9..c26659a0d7a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala
@@ -26,7 +26,7 @@ import org.apache.parquet.hadoop.util.ContextUtil
 
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.execution.datasources.{BucketingUtils, OutputWriter, OutputWriterFactory, WriterContainer}
+import org.apache.spark.sql.execution.datasources.{BucketingUtils, OutputWriter, OutputWriterFactory, WriteOutput}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.util.SerializableConfiguration
@@ -155,7 +155,7 @@ private[parquet] class ParquetOutputWriter(
     // partitions in the case of dynamic partitioning.
     override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
       val configuration = context.getConfiguration
-      val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
+      val uniqueWriteJobId = configuration.get(WriteOutput.DATASOURCE_WRITEJOBUUID)
       val taskAttemptId = context.getTaskAttemptID
       val split = taskAttemptId.getTaskID.getId
       val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index 9f9666731101..2de0e0f73cb4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -133,7 +133,7 @@ class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemp
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
         val configuration = context.getConfiguration
-        val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
+        val uniqueWriteJobId = configuration.get(WriteOutput.DATASOURCE_WRITEJOBUUID)
         val taskAttemptId = context.getTaskAttemptID
         val split = taskAttemptId.getTaskID.getId
         new Path(path, f"part-r-$split%05d-$uniqueWriteJobId.txt$extension")
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 1af3280e18a8..cff686f93d22 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -226,7 +226,7 @@ private[orc] class OrcOutputWriter(
   private lazy val recordWriter: RecordWriter[NullWritable, Writable] = {
     recordWriterInstantiated = true
 
-    val uniqueWriteJobId = conf.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
+    val uniqueWriteJobId = conf.get(WriteOutput.DATASOURCE_WRITEJOBUUID)
     val taskAttemptId = context.getTaskAttemptID
     val partition = taskAttemptId.getTaskID.getId
     val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
index 906de6bbcbee..f726ae725a56 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala
@@ -144,7 +144,7 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW
 
   override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
     val configuration = context.getConfiguration
-    val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
+    val uniqueWriteJobId = configuration.get(WriteOutput.DATASOURCE_WRITEJOBUUID)
     val taskAttemptId = context.getTaskAttemptID
     val split = taskAttemptId.getTaskID.getId
     val name = FileOutputFormat.getOutputName(context)
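
Note (not part of the patch): every hunk above follows the same naming pattern, so a minimal self-contained sketch of it is given below. The class name ExampleOutputFormat and the outputDir parameter are hypothetical; the config key "spark.sql.sources.writeJobUUID" is the value that WriteOutput now sets on the job configuration before any task runs.

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{NullWritable, Text}
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

// Hypothetical output format illustrating how a data source derives a unique,
// deterministic part-file name from the per-job UUID stored in the Hadoop conf.
class ExampleOutputFormat(outputDir: String) extends TextOutputFormat[NullWritable, Text] {
  override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
    // One UUID per write job, set by the driver; speculative attempts of the same task
    // therefore compute the same file name as the original attempt.
    val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID")
    // The task id distinguishes partitions within a job; the UUID distinguishes jobs
    // appending to the same directory.
    val split = context.getTaskAttemptID.getTaskID.getId
    new Path(outputDir, f"part-r-$split%05d-$uniqueWriteJobId$extension")
  }
}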