diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
index 5c79c6905801c..49b4f842c977f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala
@@ -30,6 +30,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
 import org.apache.spark.mllib.util.MLUtils
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
@@ -133,6 +134,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
     override def newInstance(
         path: String,
         bucketId: Option[Int],
+        bucketingInfoExtractor: BucketingInfoExtractor,
         dataSchema: StructType,
         context: TaskAttemptContext): OutputWriter = {
       if (bucketId.isDefined) { sys.error("LibSVM doesn't support bucketing") }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/BucketingInfoExtractor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/BucketingInfoExtractor.scala
new file mode 100644
index 0000000000000..c2582e7812f27
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/BucketingInfoExtractor.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+abstract class BucketingInfoExtractor extends Serializable {
+  /**
+   * Given an input `fileName`, computes the corresponding bucket id.
+   */
+  def getBucketId(fileName: String): Option[Int]
+
+  /**
+   * Given a bucket id, returns the string representation to be used in the output file name.
+   */
+  def bucketIdToString(id: Int): String
+
+  def getBucketedFilename(split: Int,
+      uniqueWriteJobId: String,
+      bucketId: Option[Int],
+      extension: String): String
+}
+
+class DefaultBucketingInfoExtractor extends BucketingInfoExtractor {
+  // The file name of bucketed data should have 3 parts:
+  //   1. some other information in the head of file name
+  //   2. bucket id part, some numbers, starts with "_"
+  //      * The other-information part may use `-` as separator and may have numbers at the end,
+  //        e.g. a normal parquet file without bucketing may have name:
+  //        part-r-00000-2dd664f9-d2c4-4ffe-878f-431234567891.gz.parquet, and we will mistakenly
+  //        treat `431234567891` as bucket id. So here we pick `_` as separator.
+  //   3. optional file extension part, in the tail of file name, starts with `.`
+  // An example of bucketed parquet file name with bucket id 3:
+  //   part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
+  private val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r
+
+  override def getBucketId(fileName: String): Option[Int] = fileName match {
+    case bucketedFileName(bucketId) => Some(bucketId.toInt)
+    case other => None
+  }
+
+  override def bucketIdToString(id: Int): String = f"_$id%05d"
+
+  override def getBucketedFilename(split: Int,
+      uniqueWriteJobId: String,
+      bucketId: Option[Int],
+      extension: String): String = {
+    val bucketString = bucketId.map(bucketIdToString).getOrElse("")
+    f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension"
+  }
+}
+
+object DefaultBucketingInfoExtractor {
+  val Instance = new DefaultBucketingInfoExtractor
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 83e01f95c06af..8e775eb8b2e94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -67,7 +67,6 @@ case class CatalogStorageFormat(
       serdePropsToString)
     output.filter(_.nonEmpty).mkString("Storage(", ", ", ")")
   }
-
 }
 
 object CatalogStorageFormat {
@@ -99,7 +98,8 @@ case class CatalogTablePartition(
 case class BucketSpec(
     numBuckets: Int,
     bucketColumnNames: Seq[String],
-    sortColumnNames: Seq[String]) {
+    sortColumnNames: Seq[String],
+    infoExtractor: BucketingInfoExtractor = DefaultBucketingInfoExtractor.Instance) {
   if (numBuckets <= 0) {
     throw new AnalysisException(s"Expected positive number of buckets, but got `$numBuckets`.")
   }
@@ -162,7 +162,7 @@ case class CatalogTable(
     val tableProperties = properties.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]")
     val partitionColumns = partitionColumnNames.map(quoteIdentifier).mkString("[", ", ", "]")
     val bucketStrings = bucketSpec match {
-      case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
+      case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames, _)) =>
         val bucketColumnsString = bucketColumnNames.map(quoteIdentifier).mkString("[", ", ", "]")
         val sortColumnsString = sortColumnNames.map(quoteIdentifier).mkString("[", ", ", "]")
         Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 1a8d0e310aec0..259166562f8a7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -382,7 +382,7 @@ case class FileSourceScanExec(
             PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts)
           }
         }.groupBy { f =>
-          BucketingUtils
+          bucketSpec.infoExtractor
             .getBucketId(new Path(f.filePath).getName)
             .getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index b4a15b8b2882e..55bb194d4b37b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -461,7 +461,7 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF
 
   private def describeBucketingInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = {
     metadata.bucketSpec match {
-      case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
+      case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames, _)) =>
         append(buffer, "Num Buckets:", numBuckets.toString, "")
         append(buffer, "Bucket Columns:", bucketColumnNames.mkString("[", ", ", "]"), "")
         append(buffer, "Sort Columns:", sortColumnNames.mkString("[", ", ", "]"), "")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
index ea4fe9c8ade5f..f6a010a1a7011 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BucketingUtils.scala
@@ -18,22 +18,5 @@
 package org.apache.spark.sql.execution.datasources
 
 object BucketingUtils {
-  // The file name of bucketed data should have 3 parts:
-  //   1. some other information in the head of file name
-  //   2. bucket id part, some numbers, starts with "_"
-  //      * The other-information part may use `-` as separator and may have numbers at the end,
-  //        e.g. a normal parquet file without bucketing may have name:
-  //        part-r-00000-2dd664f9-d2c4-4ffe-878f-431234567891.gz.parquet, and we will mistakenly
-  //        treat `431234567891` as bucket id. So here we pick `_` as separator.
-  //   3. optional file extension part, in the tail of file name, starts with `.`
-  // An example of bucketed parquet file name with bucket id 3:
-  //   part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
-  private val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r
-
-  def getBucketId(fileName: String): Option[Int] = fileName match {
-    case bucketedFileName(bucketId) => Some(bucketId.toInt)
-    case other => None
-  }
-
   def bucketIdToString(id: Int): String = f"_$id%05d"
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index 7880c7cfa16f8..fb370bab5e58a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -28,7 +28,7 @@ import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.mapred.SparkHadoopMapRedUtil
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.{BucketSpec, DefaultBucketingInfoExtractor}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
 import org.apache.spark.sql.catalyst.InternalRow
@@ -133,7 +133,18 @@ private[datasources] abstract class BaseWriterContainer(
 
   protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = {
     try {
-      outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext)
+      val bucketingInfoExtractor = if (relation.bucketSpec.isDefined) {
+        relation.bucketSpec.get.infoExtractor
+      } else {
+        DefaultBucketingInfoExtractor.Instance
+      }
+
+      outputWriterFactory.newInstance(
+        path,
+        bucketId,
+        bucketingInfoExtractor,
+        dataSchema,
+        taskAttemptContext)
     } catch {
       case e: org.apache.hadoop.fs.FileAlreadyExistsException =>
         if (outputCommitter.getClass.getName.contains("Direct")) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
index 33b170bc31f62..89e52a7f59a19 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
 import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile, WriterContainer}
@@ -172,6 +173,7 @@ private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWrit
   override def newInstance(
       path: String,
       bucketId: Option[Int],
+      bucketingInfoExtractor: BucketingInfoExtractor,
      dataSchema: StructType,
      context: TaskAttemptContext): OutputWriter = {
    if (bucketId.isDefined) sys.error("csv doesn't support bucketing")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index e03a2323c7493..55331a06f415c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.Experimental
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.catalog.BucketSpec
+import org.apache.spark.sql.catalyst.catalog.{BucketingInfoExtractor, BucketSpec}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.execution.FileRelation
@@ -63,6 +63,7 @@ abstract class OutputWriterFactory extends Serializable {
   def newInstance(
       path: String,
       bucketId: Option[Int], // TODO: This doesn't belong here...
+      bucketingInfoExtractor: BucketingInfoExtractor,
      dataSchema: StructType,
      context: TaskAttemptContext): OutputWriter
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
index 7421314df7aa5..a10ef7ce45c2d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
@@ -81,9 +82,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
     override def newInstance(
         path: String,
         bucketId: Option[Int],
+        bucketingInfoExtractor: BucketingInfoExtractor,
         dataSchema: StructType,
         context: TaskAttemptContext): OutputWriter = {
-      new JsonOutputWriter(path, parsedOptions, bucketId, dataSchema, context)
+      new JsonOutputWriter(
+        path, parsedOptions, bucketId, bucketingInfoExtractor, dataSchema, context
+      )
     }
   }
 }
@@ -151,6 +155,7 @@ private[json] class JsonOutputWriter(
     path: String,
     options: JSONOptions,
     bucketId: Option[Int],
+    bucketingInfoExtractor: BucketingInfoExtractor,
     dataSchema: StructType,
     context: TaskAttemptContext)
   extends OutputWriter with Logging {
@@ -163,12 +168,13 @@ private[json] class JsonOutputWriter(
   private val recordWriter: RecordWriter[NullWritable, Text] = {
     new TextOutputFormat[NullWritable, Text]() {
       override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-        val configuration = context.getConfiguration
-        val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-        val taskAttemptId = context.getTaskAttemptID
-        val split = taskAttemptId.getTaskID.getId
-        val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
-        new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString.json$extension")
+        val filename = bucketingInfoExtractor.getBucketedFilename(
+          context.getTaskAttemptID.getTaskID.getId,
+          context.getConfiguration.get(WriterContainer.DATASOURCE_WRITEJOBUUID),
+          bucketId,
+          s".json$extension"
+        )
+        new Path(path, filename)
       }
     }.getRecordWriter(context)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index 9208c82179d8d..c97616cede480 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -41,6 +41,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
 import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
@@ -135,9 +136,10 @@ class ParquetFileFormat
     override def newInstance(
         path: String,
         bucketId: Option[Int],
+        bucketingInfoExtractor: BucketingInfoExtractor,
         dataSchema: StructType,
         context: TaskAttemptContext): OutputWriter = {
-      new ParquetOutputWriter(path, bucketId, context)
+      new ParquetOutputWriter(path, bucketId, bucketingInfoExtractor, context)
     }
   }
 }
@@ -516,6 +518,7 @@ private[parquet] class ParquetOutputWriterFactory(
   def newInstance(
       path: String,
       bucketId: Option[Int],
+      bucketingInfoExtractor: BucketingInfoExtractor,
      dataSchema: StructType,
      context: TaskAttemptContext): OutputWriter = {
    throw new UnsupportedOperationException(
@@ -529,6 +532,7 @@ private[parquet] class ParquetOutputWriterFactory(
 private[parquet] class ParquetOutputWriter(
     path: String,
     bucketId: Option[Int],
+    bucketingInfoExtractor: BucketingInfoExtractor,
     context: TaskAttemptContext)
   extends OutputWriter {
 
@@ -545,15 +549,16 @@ private[parquet] class ParquetOutputWriter(
         //     `FileOutputCommitter.getWorkPath()`, which points to the base directory of all
         //     partitions in the case of dynamic partitioning.
         override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
-          val configuration = context.getConfiguration
-          val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-          val taskAttemptId = context.getTaskAttemptID
-          val split = taskAttemptId.getTaskID.getId
-          val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
           // It has the `.parquet` extension at the end because (de)compression tools
           // such as gunzip would not be able to decompress this as the compression
           // is not applied on this whole file but on each "page" in Parquet format.
-          new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension")
+          val filename = bucketingInfoExtractor.getBucketedFilename(
+            context.getTaskAttemptID.getTaskID.getId,
+            context.getConfiguration.get(WriterContainer.DATASOURCE_WRITEJOBUUID),
+            bucketId,
+            extension
+          )
+          new Path(path, filename)
         }
       }
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
index f14c63c19f905..38253030bd6ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala
@@ -132,7 +132,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] {
 
   private def checkBucketColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
     tableDesc.bucketSpec match {
-      case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
+      case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames, _)) =>
         val normalizedBucketCols = bucketColumnNames.map { colName =>
           normalizeColumnName(tableDesc.identifier, schema, colName, "bucket")
         }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
index a0c3fd53fb53b..14f1bba4d97e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextFileFormat.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat
 
 import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
 import org.apache.spark.sql.execution.datasources._
@@ -73,6 +74,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
     override def newInstance(
         path: String,
         bucketId: Option[Int],
+        bucketingInfoExtractor: BucketingInfoExtractor,
         dataSchema: StructType,
         context: TaskAttemptContext): OutputWriter = {
       if (bucketId.isDefined) {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
index 7f50e38d30c9a..e66a20a6405c9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala
@@ -214,7 +214,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
       }
 
       if (bucketSpec.isDefined) {
-        val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get
+        val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames, _) = bucketSpec.get
 
         tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETS, numBuckets.toString)
         tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETCOLS, bucketColumnNames.length.toString)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
index 286197b50e229..3ffda97583151 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala
@@ -33,6 +33,7 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit}
 
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.hive.{HiveInspectors, HiveShim}
@@ -84,9 +85,10 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable
     override def newInstance(
         path: String,
         bucketId: Option[Int],
+        bucketingInfoExtractor: BucketingInfoExtractor,
         dataSchema: StructType,
         context: TaskAttemptContext): OutputWriter = {
-      new OrcOutputWriter(path, bucketId, dataSchema, context)
+      new OrcOutputWriter(path, bucketId, bucketingInfoExtractor, dataSchema, context)
     }
   }
 }
@@ -207,6 +209,7 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration)
 private[orc] class OrcOutputWriter(
     path: String,
     bucketId: Option[Int],
+    bucketingInfoExtractor: BucketingInfoExtractor,
     dataSchema: StructType,
     context: TaskAttemptContext)
   extends OutputWriter {
@@ -221,10 +224,6 @@ private[orc] class OrcOutputWriter(
   private lazy val recordWriter: RecordWriter[NullWritable, Writable] = {
     recordWriterInstantiated = true
 
-    val uniqueWriteJobId = conf.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
-    val taskAttemptId = context.getTaskAttemptID
-    val partition = taskAttemptId.getTaskID.getId
-    val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
     val compressionExtension = {
       val name = conf.get(OrcRelation.ORC_COMPRESSION)
       OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "")
@@ -232,7 +231,12 @@ private[orc] class OrcOutputWriter(
     // It has the `.orc` extension at the end because (de)compression tools
     // such as gunzip would not be able to decompress this as the compression
     // is not applied on this whole file but on each "stream" in ORC format.
-    val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString$compressionExtension.orc"
+    val filename = bucketingInfoExtractor.getBucketedFilename(
+      context.getTaskAttemptID.getTaskID.getId,
+      conf.get(WriterContainer.DATASOURCE_WRITEJOBUUID),
+      bucketId,
+      s"$compressionExtension.orc"
+    )
 
     new OrcOutputFormat().getRecordWriter(
       new Path(path, filename).getFileSystem(conf),
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 997445114ba58..b3c97c46824e2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -22,9 +22,9 @@ import java.net.URI
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest}
+import org.apache.spark.sql.catalyst.catalog.DefaultBucketingInfoExtractor
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.datasources.BucketingUtils
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.internal.SQLConf
@@ -101,9 +101,11 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     )
 
     for (bucketFile <- allBucketFiles) {
-      val bucketId = BucketingUtils.getBucketId(bucketFile.getName).getOrElse {
-        fail(s"Unable to find the related bucket files.")
-      }
+      val bucketId = DefaultBucketingInfoExtractor.Instance
+        .getBucketId(bucketFile.getName)
+        .getOrElse {
+          fail(s"Unable to find the related bucket files.")
+        }
 
       // Remove the duplicate columns in bucketCols and sortCols;
      // Otherwise, we got analysis errors due to duplicate names
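Reviewer note: as a quick sanity check of the file-name convention documented in DefaultBucketingInfoExtractor above, the standalone snippet below (not part of the patch; the BucketNamingCheck object name is made up for illustration) exercises the regex and the writer-side formatting with the two example names from the code comment.

    import org.apache.spark.sql.catalyst.catalog.DefaultBucketingInfoExtractor

    object BucketNamingCheck {
      def main(args: Array[String]): Unit = {
        val extractor = DefaultBucketingInfoExtractor.Instance

        // Bucketed name: the "_00003" part right before the extensions is parsed as bucket id 3.
        assert(extractor.getBucketId(
          "part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet").contains(3))

        // Non-bucketed name: the trailing digits belong to the job UUID and are separated by `-`,
        // not `_`, so no bucket id is extracted.
        assert(extractor.getBucketId(
          "part-r-00000-2dd664f9-d2c4-4ffe-878f-431234567891.gz.parquet").isEmpty)

        // Writer side: split 0 plus bucket id 3 reproduces the bucketed name above.
        assert(extractor.getBucketedFilename(
          0, "2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb", Some(3), ".gz.parquet") ==
          "part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet")
      }
    }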
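The point of threading BucketingInfoExtractor through BucketSpec and OutputWriterFactory.newInstance is that a data source can plug in its own bucket-file naming scheme. A minimal sketch of such a plug-in is below; the HiveStyleBucketingInfoExtractor class, its regex, and the assumption that Hive-style bucket files are named like "000003_0" are illustrative only and not part of this patch.

    package org.apache.spark.sql.catalyst.catalog

    // Hypothetical extractor: reads and writes Hive-style bucket file names ("000003_0"),
    // falling back to the default "part-r-..." scheme for non-bucketed output.
    class HiveStyleBucketingInfoExtractor extends BucketingInfoExtractor {
      // Bucket id is the leading run of digits before the first underscore.
      private val hiveBucketedFileName = """^(\d+)_\d+.*$""".r

      override def getBucketId(fileName: String): Option[Int] = fileName match {
        case hiveBucketedFileName(bucketId) => Some(bucketId.toInt)
        case _ => None
      }

      override def bucketIdToString(id: Int): String = f"$id%06d"

      override def getBucketedFilename(split: Int,
          uniqueWriteJobId: String,
          bucketId: Option[Int],
          extension: String): String = bucketId match {
        case Some(id) => s"${bucketIdToString(id)}_0$extension"
        case None => f"part-r-$split%05d-$uniqueWriteJobId$extension"
      }
    }

It would be attached through the new BucketSpec field, e.g. BucketSpec(8, Seq("user_id"), Nil, new HiveStyleBucketingInfoExtractor). Note that file naming alone does not make the buckets Hive-compatible (the hash function still differs); the sketch only illustrates the naming hook this patch introduces.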