@@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.plans.physical

import scala.language.existentials

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, IntegerType}

@@ -30,7 +32,10 @@ import org.apache.spark.sql.types.{DataType, IntegerType}
* - Intra-partition ordering of data: In this case the distribution describes guarantees made
* about how tuples are distributed within a single partition.
*/
sealed trait Distribution
sealed trait Distribution {
/** If defined, specifies the number of partitions this distribution expects. */
def numPartitions: Option[Int] = None
}

/**
* Represents a distribution where no promises are made about co-location of data.
@@ -49,12 +54,20 @@ case object AllTuples extends Distribution
* can mean such tuples are either co-located in the same partition or they will be contiguous
* within a single partition.
*/
case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution {
case class ClusteredDistribution(
clustering: Seq[Expression],
numClusters: Option[Int] = None,
hashingFunctionClass: Option[Class[_ <: HashExpression[Int]]] = None)
extends Distribution {
require(
clustering != Nil,
"The clustering expressions of a ClusteredDistribution should not be Nil. " +
"An AllTuples should be used to represent a distribution that only has " +
"a single partition.")
require(numClusters.isEmpty || numClusters.get > 0,
"Number of cluster (if set) should only be a positive integer")

override def numPartitions: Option[Int] = numClusters
}

/**
@@ -234,7 +247,10 @@ case object SinglePartition extends Partitioning {
* of `expressions`. All rows where `expressions` evaluate to the same values are guaranteed to be
* in the same partition.
*/
case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case class HashPartitioning(
expressions: Seq[Expression],
numPartitions: Int,
hashingFunctionClass: Class[_ <: HashExpression[Int]] = classOf[Murmur3Hash])
extends Expression with Partitioning with Unevaluable {

override def children: Seq[Expression] = expressions
@@ -243,8 +259,10 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
case ClusteredDistribution(requiredClustering) =>
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case ClusteredDistribution(requiredClustering, numClusters, hashingFunctionClazz) =>
(numClusters.isEmpty || numClusters.get == numPartitions) &&
(hashingFunctionClazz.isEmpty || hashingFunctionClazz.get == hashingFunctionClass) &&
expressions.forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}

@@ -260,9 +278,16 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)

/**
* Returns an expression that will produce a valid partition ID (i.e. non-negative and less
* than numPartitions) based on hashing expressions.
* than numPartitions) based on hashing expression(s) and the hashing function.
*/
def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
def partitionIdExpression: Expression = {
val hashExpression = hashingFunctionClass match {
case m if m == classOf[Murmur3Hash] => new Murmur3Hash(expressions)
case h if h == classOf[HiveHash] => HiveHash(expressions)
case _ => throw new Exception(s"Unsupported hashingFunction: $hashingFunctionClass")
}
Pmod(hashExpression, Literal(numPartitions))
}
}

/**
@@ -289,8 +314,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case OrderedDistribution(requiredOrdering) =>
val minSize = Seq(requiredOrdering.size, ordering.size).min
requiredOrdering.take(minSize) == ordering.take(minSize)
case ClusteredDistribution(requiredClustering) =>
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
case ClusteredDistribution(requiredClustering, numClusters, hashingFunctionClass) =>
(numClusters.isEmpty || numClusters.get == numPartitions) && hashingFunctionClass.isEmpty &&
ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x)))
case _ => false
}

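To make the new contract concrete, here is a minimal sketch (not part of this PR) of how a HiveHash-based HashPartitioning interacts with the extended ClusteredDistribution, assuming the signatures introduced above; the attribute names and the partition count 8 are illustrative only.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{HiveHash, Murmur3Hash}
import org.apache.spark.sql.catalyst.plans.physical._

// A partitioning that hashes (a, b) into 8 partitions with HiveHash.
val hivePartitioning = HashPartitioning(Seq('a, 'b), 8, classOf[HiveHash])

// Satisfied: same clustering keys, same partition count, same hashing function.
hivePartitioning.satisfies(
  ClusteredDistribution(Seq('a, 'b), Some(8), Some(classOf[HiveHash])))     // true

// Not satisfied: the distribution insists on Murmur3Hash.
hivePartitioning.satisfies(
  ClusteredDistribution(Seq('a, 'b), Some(8), Some(classOf[Murmur3Hash])))  // false

// The partition id expression expands to pmod(HiveHash(a, b), 8).
val partitionId = hivePartitioning.partitionIdExpression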
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst
import org.apache.spark.SparkFunSuite
/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{HiveHash, Murmur3Hash}
import org.apache.spark.sql.catalyst.plans.physical._

class DistributionSuite extends SparkFunSuite {
@@ -79,6 +80,26 @@ class DistributionSuite extends SparkFunSuite {
ClusteredDistribution(Seq('d, 'e)),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
ClusteredDistribution(Seq('a, 'b, 'c), Some(10), Some(classOf[Murmur3Hash])),
true)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
ClusteredDistribution(Seq('a, 'b, 'c), Some(12), Some(classOf[Murmur3Hash])),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
ClusteredDistribution(Seq('d, 'e), Some(10), Some(classOf[Murmur3Hash])),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
ClusteredDistribution(Seq('a, 'b, 'c), Some(10), Some(classOf[HiveHash])),
false)

checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),
AllTuples,
@@ -127,19 +148,34 @@ class DistributionSuite extends SparkFunSuite {

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
ClusteredDistribution(Seq('a, 'b, 'c)),
ClusteredDistribution(Seq('a, 'b, 'c), Some(10), None),
true)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
ClusteredDistribution(Seq('c, 'b, 'a)),
ClusteredDistribution(Seq('c, 'b, 'a), Some(10), None),
true)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
ClusteredDistribution(Seq('b, 'c, 'a, 'd)),
ClusteredDistribution(Seq('b, 'c, 'a, 'd), Some(10), None),
true)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
ClusteredDistribution(Seq('b, 'c, 'a, 'd), Some(10), Some(classOf[Murmur3Hash])),
false)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
ClusteredDistribution(Seq('b, 'c, 'a, 'd), Some(12), Some(classOf[Murmur3Hash])),
false)

checkSatisfied(
RangePartitioning(Seq('a.asc, 'b.asc, 'c.asc), 10),
ClusteredDistribution(Seq('b, 'c, 'a, 'd), Some(10), Some(classOf[HiveHash])),
false)

// Cases which need an exchange between two data properties.
// TODO: We can have an optimization to first sort the dataset
// by a.asc and then sort b, and c in a partition. This optimization
@@ -18,12 +18,13 @@
package org.apache.spark.sql.catalyst

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, Literal}
import org.apache.spark.sql.catalyst.expressions.{HiveHash, InterpretedMutableProjection, Literal, Murmur3Hash}
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning}

class PartitioningSuite extends SparkFunSuite {
private val expressions = Seq(Literal(2), Literal(3))

test("HashPartitioning compatibility should be sensitive to expression ordering (SPARK-9785)") {
val expressions = Seq(Literal(2), Literal(3))
// Consider two HashPartitionings that have the same _set_ of hash expressions but which are
// created with different orderings of those expressions:
val partitioningA = HashPartitioning(expressions, 100)
@@ -34,11 +35,13 @@ class PartitioningSuite extends SparkFunSuite {
val distribution = ClusteredDistribution(expressions)
assert(partitioningA.satisfies(distribution))
assert(partitioningB.satisfies(distribution))

// These partitionings compute different hashcodes for the same input row:
def computeHashCode(partitioning: HashPartitioning): Int = {
val hashExprProj = new InterpretedMutableProjection(partitioning.expressions, Seq.empty)
hashExprProj.apply(InternalRow.empty).hashCode()
}

assert(computeHashCode(partitioningA) != computeHashCode(partitioningB))
// Thus, these partitionings are incompatible:
assert(!partitioningA.compatibleWith(partitioningB))
@@ -52,4 +55,18 @@
assert(partitioningA.guarantees(partitioningA))
assert(partitioningA.compatibleWith(partitioningA))
}

test("HashPartitioning compatibility should be sensitive to hashing function") {
val partitioningA = HashPartitioning(expressions, 100, classOf[Murmur3Hash])
val partitioningB = HashPartitioning(expressions, 100, classOf[HiveHash])
assert(partitioningA != partitioningB)
assert(!partitioningA.compatibleWith(partitioningB))
}

test("HashPartitioning compatibility should be sensitive to number of partitions") {
val partitioningA = HashPartitioning(expressions, 10, classOf[Murmur3Hash])
val partitioningB = HashPartitioning(expressions, 1212, classOf[Murmur3Hash])
assert(partitioningA != partitioningB)
assert(!partitioningA.compatibleWith(partitioningB))
}
}
@@ -23,9 +23,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, SortOrder}
import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.debug._
import org.apache.spark.sql.execution.metric.SQLMetric
@@ -43,6 +44,10 @@ trait RunnableCommand extends logical.Command {
// `ExecutedCommand` during query planning.
lazy val metrics: Map[String, SQLMetric] = Map.empty

def requiredDistribution: Seq[Distribution] = Seq.fill(children.size)(UnspecifiedDistribution)

def requiredOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)

def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = {
throw new NotImplementedError
}
@@ -94,6 +99,10 @@ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) e

override def executeTake(limit: Int): Array[InternalRow] = sideEffectResult.take(limit).toArray

override def requiredChildDistribution: Seq[Distribution] = cmd.requiredDistribution

override def requiredChildOrdering: Seq[Seq[SortOrder]] = cmd.requiredOrdering

protected override def doExecute(): RDD[InternalRow] = {
sqlContext.sparkContext.parallelize(sideEffectResult, 1)
}
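For illustration, a hypothetical command (not part of this PR) could override the new hooks so that the planner, rather than the command itself, arranges the required shuffle and sort; every name and constant below is an assumption.

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.command.RunnableCommand

// Hypothetical: cluster the child by its first output column into 8 partitions
// and sort each partition by all output columns before run() sees the data.
case class ExampleWriteCommand(query: LogicalPlan) extends RunnableCommand {
  override def children: Seq[LogicalPlan] = query :: Nil

  override def requiredDistribution: Seq[Distribution] =
    ClusteredDistribution(Seq(query.output.head), numClusters = Some(8)) :: Nil

  override def requiredOrdering: Seq[Seq[SortOrder]] =
    Seq(query.output.map(SortOrder(_, Ascending)))

  override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] =
    Seq.empty // illustration only
}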
@@ -34,12 +34,11 @@ import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.catalog.{BucketSpec, ExternalCatalogUtils}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.{SparkPlan, SQLExecution}
import org.apache.spark.sql.types.StringType
import org.apache.spark.util.{SerializableConfiguration, Utils}

@@ -107,7 +106,7 @@ object FileFormatWriter extends Logging {
outputSpec: OutputSpec,
hadoopConf: Configuration,
partitionColumns: Seq[Attribute],
bucketSpec: Option[BucketSpec],
bucketIdExpression: Option[Expression],
statsTrackers: Seq[WriteJobStatsTracker],
options: Map[String, String])
: Set[String] = {
@@ -121,17 +120,6 @@
val partitionSet = AttributeSet(partitionColumns)
val dataColumns = allColumns.filterNot(partitionSet.contains)

val bucketIdExpression = bucketSpec.map { spec =>
val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
// Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
// guarantee the data distribution is same between shuffle and bucketed data source, which
// enables us to only shuffle one side when join a bucketed table and a normal one.
HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
}
val sortColumns = bucketSpec.toSeq.flatMap {
spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
}

val caseInsensitiveOptions = CaseInsensitiveMap(options)

// Note: prepareWrite has side effect. It sets "job".
@@ -155,34 +143,14 @@
statsTrackers = statsTrackers
)

// We should first sort by partition columns, then bucket id, and finally sorting columns.
val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns
// the sort order doesn't matter
val actualOrdering = plan.outputOrdering.map(_.child)
val orderingMatched = if (requiredOrdering.length > actualOrdering.length) {
false
} else {
requiredOrdering.zip(actualOrdering).forall {
case (requiredOrder, childOutputOrder) =>
requiredOrder.semanticEquals(childOutputOrder)
}
}

SQLExecution.checkSQLExecutionId(sparkSession)

// This call shouldn't be put into the `try` block below because it only initializes and
// prepares the job, any exception thrown from here shouldn't cause abortJob() to be called.
committer.setupJob(job)

try {
val rdd = if (orderingMatched) {
plan.execute()
} else {
SortExec(
requiredOrdering.map(SortOrder(_, Ascending)),
global = false,
child = plan).execute()
}
val rdd = plan.execute()
val ret = new Array[WriteTaskResult](rdd.partitions.length)
sparkSession.sparkContext.runJob(
rdd,
@@ -195,7 +163,7 @@
committer,
iterator = iter)
},
0 until rdd.partitions.length,
rdd.partitions.indices,
(index, res: WriteTaskResult) => {
committer.onTaskCommit(res.commitMsg)
ret(index) = res
@@ -514,18 +482,18 @@ object FileFormatWriter extends Logging {
var recordsInFile: Long = 0L
var fileCounter = 0
val updatedPartitions = mutable.Set[String]()
var currentPartionValues: Option[UnsafeRow] = None
var currentPartitionValues: Option[UnsafeRow] = None
var currentBucketId: Option[Int] = None

for (row <- iter) {
val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(row)) else None
val nextBucketId = if (isBucketed) Some(getBucketId(row)) else None

if (currentPartionValues != nextPartitionValues || currentBucketId != nextBucketId) {
if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) {
// See a new partition or bucket - write to a new partition dir (or a new bucket file).
if (isPartitioned && currentPartionValues != nextPartitionValues) {
currentPartionValues = Some(nextPartitionValues.get.copy())
statsTrackers.foreach(_.newPartition(currentPartionValues.get))
if (isPartitioned && currentPartitionValues != nextPartitionValues) {
currentPartitionValues = Some(nextPartitionValues.get.copy())
statsTrackers.foreach(_.newPartition(currentPartitionValues.get))
}
if (isBucketed) {
currentBucketId = nextBucketId
Expand All @@ -536,7 +504,7 @@ object FileFormatWriter extends Logging {
fileCounter = 0

releaseResources()
newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
} else if (desc.maxRecordsPerFile > 0 &&
recordsInFile >= desc.maxRecordsPerFile) {
// Exceeded the threshold in terms of the number of records per file.
Expand All @@ -547,7 +515,7 @@ object FileFormatWriter extends Logging {
s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER")

releaseResources()
newOutputWriter(currentPartionValues, currentBucketId, fileCounter, updatedPartitions)
newOutputWriter(currentPartitionValues, currentBucketId, fileCounter, updatedPartitions)
}
val outputRow = getOutputRow(row)
currentWriter.write(outputRow)
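The bucket-id and sort-column derivation removed from write() above presumably moves to the caller before this method is invoked. A sketch of that derivation, reusing the deleted logic; bucketingFor is a hypothetical helper name, not something introduced by this PR.

import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning

// Hypothetical caller-side helper: derive the bucket id expression and sort columns
// that FileFormatWriter.write no longer computes itself.
def bucketingFor(
    bucketSpec: Option[BucketSpec],
    dataColumns: Seq[Attribute]): (Option[Expression], Seq[Attribute]) = {
  val bucketIdExpression = bucketSpec.map { spec =>
    val bucketColumns = spec.bucketColumnNames.map(c => dataColumns.find(_.name == c).get)
    // Reusing HashPartitioning.partitionIdExpression keeps the on-disk bucketing
    // consistent with shuffle partitioning, so a join with a bucketed table can
    // avoid shuffling that side.
    HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
  }
  val sortColumns = bucketSpec.toSeq.flatMap { spec =>
    spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get)
  }
  (bucketIdExpression, sortColumns)
}

The required ordering (partition columns, then bucket id, then sort columns) would then be declared through the new RunnableCommand.requiredOrdering hook instead of the ad-hoc SortExec that this method used to add.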