@@ -30,6 +30,7 @@ import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
@@ -133,6 +134,7 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour
override def newInstance(
path: String,
bucketId: Option[Int],
bucketingInfoExtractor: BucketingInfoExtractor,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
if (bucketId.isDefined) { sys.error("LibSVM doesn't support bucketing") }
@@ -0,0 +1,68 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.catalog

abstract class BucketingInfoExtractor extends Serializable {
/**
* Given an input `fileName`, computes the corresponding bucket id.
*/
def getBucketId(fileName: String): Option[Int]

/**
* Given a bucket id, returns the string representation to be used in the output file name.
*/
def bucketIdToString(id: Int): String

def getBucketedFilename(split: Int,
uniqueWriteJobId: String,
bucketId: Option[Int],
extension: String): String
}

class DefaultBucketingInfoExtractor extends BucketingInfoExtractor {
// The file name of bucketed data should have 3 parts:
// 1. some other information at the head of the file name
// 2. the bucket id part: some digits, starting with "_"
// * the other-information part may use `-` as a separator and may have numbers at the end,
// e.g. a normal parquet file without bucketing may be named
// part-r-00000-2dd664f9-d2c4-4ffe-878f-431234567891.gz.parquet, and we would mistakenly
// treat `431234567891` as a bucket id, so here we pick `_` as the separator
// 3. an optional file extension part at the tail of the file name, starting with `.`
// An example of a bucketed parquet file name with bucket id 3:
// part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
private val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r

override def getBucketId(fileName: String): Option[Int] = fileName match {
case bucketedFileName(bucketId) => Some(bucketId.toInt)
case other => None
}

override def bucketIdToString(id: Int): String = f"_$id%05d"

override def getBucketedFilename(split: Int,
uniqueWriteJobId: String,
bucketId: Option[Int],
extension: String): String = {
val bucketString = bucketId.map(bucketIdToString).getOrElse("")
f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension"
}
}

object DefaultBucketingInfoExtractor {
val Instance = new DefaultBucketingInfoExtractor
}
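
For reference, a minimal sketch of how the default extractor defined above behaves. The example object and the concrete file names are illustrative only; the methods and the naming pattern are the ones introduced in this file.

import org.apache.spark.sql.catalyst.catalog.DefaultBucketingInfoExtractor

object DefaultExtractorExample {
  def main(args: Array[String]): Unit = {
    val extractor = DefaultBucketingInfoExtractor.Instance

    // Bucketed file: the trailing "_00003" before the extension is the bucket id.
    assert(extractor.getBucketId(
      "part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet") == Some(3))

    // Non-bucketed file: no "_<digits>" segment, so no bucket id is extracted.
    assert(extractor.getBucketId(
      "part-r-00000-2dd664f9-d2c4-4ffe-878f-431234567891.gz.parquet").isEmpty)

    // Bucket ids are zero-padded to five digits.
    assert(extractor.bucketIdToString(3) == "_00003")

    // Full output file name for split 0 of a write job, bucket 3.
    assert(extractor.getBucketedFilename(
      0, "2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb", Some(3), ".gz.parquet") ==
      "part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet")
  }
}
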
@@ -67,7 +67,6 @@ case class CatalogStorageFormat(
serdePropsToString)
output.filter(_.nonEmpty).mkString("Storage(", ", ", ")")
}

}

object CatalogStorageFormat {
@@ -99,7 +98,8 @@ case class CatalogTablePartition(
case class BucketSpec(
numBuckets: Int,
bucketColumnNames: Seq[String],
sortColumnNames: Seq[String]) {
sortColumnNames: Seq[String],
infoExtractor: BucketingInfoExtractor = DefaultBucketingInfoExtractor.Instance) {
if (numBuckets <= 0) {
throw new AnalysisException(s"Expected positive number of buckets, but got `$numBuckets`.")
}
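
Since `infoExtractor` now travels with `BucketSpec`, a caller could in principle supply its own file-name scheme. Below is a hypothetical sketch, assuming Hive-style bucket file names such as `000003_0`; neither the class nor the wiring shown is part of this patch.

import org.apache.spark.sql.catalyst.catalog.{BucketingInfoExtractor, BucketSpec}

// Hypothetical extractor for Hive-style bucket file names such as "000003_0".
class HiveStyleBucketingInfoExtractor extends BucketingInfoExtractor {
  private val hiveBucketedFileName = """^(\d+)_\d+.*$""".r

  override def getBucketId(fileName: String): Option[Int] = fileName match {
    case hiveBucketedFileName(bucketId) => Some(bucketId.toInt)
    case _ => None
  }

  // Hive-style names zero-pad the bucket id to six digits at the head of the file name.
  override def bucketIdToString(id: Int): String = f"$id%06d"

  override def getBucketedFilename(
      split: Int,
      uniqueWriteJobId: String,
      bucketId: Option[Int],
      extension: String): String = {
    // Bucket id becomes the prefix; fall back to the split number when unbucketed.
    val prefix = bucketId.map(bucketIdToString).getOrElse(f"$split%06d")
    s"${prefix}_0$extension"
  }
}

object HiveStyleBucketingExample {
  // The extractor rides along with the rest of the bucketing metadata.
  val spec: BucketSpec = BucketSpec(
    numBuckets = 8,
    bucketColumnNames = Seq("user_id"),
    sortColumnNames = Seq("ts"),
    infoExtractor = new HiveStyleBucketingInfoExtractor)
}
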
@@ -162,7 +162,7 @@ case class CatalogTable(
val tableProperties = properties.map(p => p._1 + "=" + p._2).mkString("[", ", ", "]")
val partitionColumns = partitionColumnNames.map(quoteIdentifier).mkString("[", ", ", "]")
val bucketStrings = bucketSpec match {
case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames, _)) =>
val bucketColumnsString = bucketColumnNames.map(quoteIdentifier).mkString("[", ", ", "]")
val sortColumnsString = sortColumnNames.map(quoteIdentifier).mkString("[", ", ", "]")
Seq(
@@ -382,7 +382,7 @@ case class FileSourceScanExec(
PartitionedFile(p.values, f.getPath.toUri.toString, 0, f.getLen, hosts)
}
}.groupBy { f =>
BucketingUtils
bucketSpec.infoExtractor
.getBucketId(new Path(f.filePath).getName)
.getOrElse(sys.error(s"Invalid bucket file ${f.filePath}"))
}
@@ -461,7 +461,7 @@ case class DescribeTableCommand(table: TableIdentifier, isExtended: Boolean, isF

private def describeBucketingInfo(metadata: CatalogTable, buffer: ArrayBuffer[Row]): Unit = {
metadata.bucketSpec match {
case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames, _)) =>
append(buffer, "Num Buckets:", numBuckets.toString, "")
append(buffer, "Bucket Columns:", bucketColumnNames.mkString("[", ", ", "]"), "")
append(buffer, "Sort Columns:", sortColumnNames.mkString("[", ", ", "]"), "")
@@ -18,22 +18,5 @@
package org.apache.spark.sql.execution.datasources

object BucketingUtils {
// The file name of bucketed data should have 3 parts:
// 1. some other information in the head of file name
// 2. bucket id part, some numbers, starts with "_"
// * The other-information part may use `-` as separator and may have numbers at the end,
// e.g. a normal parquet file without bucketing may have name:
// part-r-00000-2dd664f9-d2c4-4ffe-878f-431234567891.gz.parquet, and we will mistakenly
// treat `431234567891` as bucket id. So here we pick `_` as separator.
// 3. optional file extension part, in the tail of file name, starts with `.`
// An example of bucketed parquet file name with bucket id 3:
// part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet
private val bucketedFileName = """.*_(\d+)(?:\..*)?$""".r

def getBucketId(fileName: String): Option[Int] = fileName match {
case bucketedFileName(bucketId) => Some(bucketId.toInt)
case other => None
}

def bucketIdToString(id: Int): String = f"_$id%05d"
}
@@ -28,7 +28,7 @@ import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, DefaultBucketingInfoExtractor}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
@@ -133,7 +133,18 @@ private[datasources] abstract class BaseWriterContainer(

protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = {
try {
outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext)
val bucketingInfoExtractor = if (relation.bucketSpec.isDefined) {
relation.bucketSpec.get.infoExtractor
} else {
DefaultBucketingInfoExtractor.Instance
}

outputWriterFactory.newInstance(
path,
bucketId,
bucketingInfoExtractor,
dataSchema,
taskAttemptContext)
} catch {
case e: org.apache.hadoop.fs.FileAlreadyExistsException =>
if (outputCommitter.getClass.getName.contains("Direct")) {
@@ -29,6 +29,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile, WriterContainer}
@@ -172,6 +173,7 @@ private[csv] class CSVOutputWriterFactory(params: CSVOptions) extends OutputWrit
override def newInstance(
path: String,
bucketId: Option[Int],
bucketingInfoExtractor: BucketingInfoExtractor,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
if (bucketId.isDefined) sys.error("csv doesn't support bucketing")
@@ -29,7 +29,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.catalog.{BucketingInfoExtractor, BucketSpec}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.FileRelation
@@ -63,6 +63,7 @@ abstract class OutputWriterFactory extends Serializable {
def newInstance(
path: String,
bucketId: Option[Int], // TODO: This doesn't belong here...
bucketingInfoExtractor: BucketingInfoExtractor,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter

@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -81,9 +82,12 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def newInstance(
path: String,
bucketId: Option[Int],
bucketingInfoExtractor: BucketingInfoExtractor,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new JsonOutputWriter(path, parsedOptions, bucketId, dataSchema, context)
new JsonOutputWriter(
path, parsedOptions, bucketId, bucketingInfoExtractor, dataSchema, context
)
}
}
}
@@ -151,6 +155,7 @@ private[json] class JsonOutputWriter(
path: String,
options: JSONOptions,
bucketId: Option[Int],
bucketingInfoExtractor: BucketingInfoExtractor,
dataSchema: StructType,
context: TaskAttemptContext)
extends OutputWriter with Logging {
@@ -163,12 +168,13 @@ private[json] class JsonOutputWriter(
private val recordWriter: RecordWriter[NullWritable, Text] = {
new TextOutputFormat[NullWritable, Text]() {
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = context.getConfiguration
val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString.json$extension")
val filename = bucketingInfoExtractor.getBucketedFilename(
context.getTaskAttemptID.getTaskID.getId,
context.getConfiguration.get(WriterContainer.DATASOURCE_WRITEJOBUUID),
bucketId,
s".json$extension"
)
new Path(path, filename)
}
}.getRecordWriter(context)
}
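
Note how the extension is composed above: the `.json` suffix is prepended to whatever extension the output format passes in, so a compression suffix such as `.gz` still ends up last. A small sketch of the resulting names (the job UUID below is made up):

import org.apache.spark.sql.catalyst.catalog.DefaultBucketingInfoExtractor

object JsonFileNameExample {
  val extractor = DefaultBucketingInfoExtractor.Instance

  // Uncompressed, non-bucketed output: part-r-00000-c6c70c1f-b0cb.json
  val plain = extractor.getBucketedFilename(0, "c6c70c1f-b0cb", None, ".json")

  // Gzip-compressed output for bucket 7: part-r-00002-c6c70c1f-b0cb_00007.json.gz
  val gzipped = extractor.getBucketedFilename(2, "c6c70c1f-b0cb", Some(7), ".json.gz")
}
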
@@ -41,6 +41,7 @@ import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
@@ -135,9 +136,10 @@ class ParquetFileFormat
override def newInstance(
path: String,
bucketId: Option[Int],
bucketingInfoExtractor: BucketingInfoExtractor,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new ParquetOutputWriter(path, bucketId, context)
new ParquetOutputWriter(path, bucketId, bucketingInfoExtractor, context)
}
}
}
@@ -516,6 +518,7 @@ private[parquet] class ParquetOutputWriterFactory(
def newInstance(
path: String,
bucketId: Option[Int],
bucketingInfoExtractor: BucketingInfoExtractor,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
throw new UnsupportedOperationException(
@@ -529,6 +532,7 @@ private[parquet] class ParquetOutputWriter(
private[parquet] class ParquetOutputWriter(
path: String,
bucketId: Option[Int],
bucketingInfoExtractor: BucketingInfoExtractor,
context: TaskAttemptContext)
extends OutputWriter {

@@ -545,15 +549,16 @@ private[parquet] class ParquetOutputWriter(
// `FileOutputCommitter.getWorkPath()`, which points to the base directory of all
// partitions in the case of dynamic partitioning.
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
val configuration = context.getConfiguration
val uniqueWriteJobId = configuration.get(WriterContainer.DATASOURCE_WRITEJOBUUID)
val taskAttemptId = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
val bucketString = bucketId.map(BucketingUtils.bucketIdToString).getOrElse("")
// It has the `.parquet` extension at the end because (de)compression tools
// such as gunzip would not be able to decompress this as the compression
// is not applied on this whole file but on each "page" in Parquet format.
new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension")
val filename = bucketingInfoExtractor.getBucketedFilename(
context.getTaskAttemptID.getTaskID.getId,
context.getConfiguration.get(WriterContainer.DATASOURCE_WRITEJOBUUID),
bucketId,
extension
)
new Path(path, filename)
}
}
}
@@ -132,7 +132,7 @@ case class PreprocessDDL(conf: SQLConf) extends Rule[LogicalPlan] {

private def checkBucketColumns(schema: StructType, tableDesc: CatalogTable): CatalogTable = {
tableDesc.bucketSpec match {
case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames)) =>
case Some(BucketSpec(numBuckets, bucketColumnNames, sortColumnNames, _)) =>
val normalizedBucketCols = bucketColumnNames.map { colName =>
normalizeColumnName(tableDesc.identifier, schema, colName, "bucket")
}
@@ -25,6 +25,7 @@ import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.BucketingInfoExtractor
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
import org.apache.spark.sql.execution.datasources._
@@ -73,6 +74,7 @@ class TextFileFormat extends TextBasedFileFormat with DataSourceRegister {
override def newInstance(
path: String,
bucketId: Option[Int],
bucketingInfoExtractor: BucketingInfoExtractor,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
if (bucketId.isDefined) {
@@ -214,7 +214,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat
}

if (bucketSpec.isDefined) {
val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get
val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames, _) = bucketSpec.get

tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETS, numBuckets.toString)
tableProperties.put(DATASOURCE_SCHEMA_NUMBUCKETCOLS, bucketColumnNames.length.toString)