diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 231f204b12b4..c46e2dd727ec 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -30,6 +30,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.hive._
@@ -165,6 +166,59 @@ case class HiveTableScanExec(
 
   override def output: Seq[Attribute] = attributes
 
+  override val (outputPartitioning, outputOrdering): (Partitioning, Seq[SortOrder]) = {
+    val bucketSpec = relation.catalogTable.bucketSpec
+
+    bucketSpec match {
+      case Some(spec) =>
+        // For bucketed columns:
+        // -----------------------
+        // `HashPartitioning` would be used only when:
+        // 1. ALL the bucketing columns are being read from the table
+        //
+        // For sorted columns:
+        // ---------------------
+        // Sort ordering should be used when ALL of these criteria match:
+        // 1. `HashPartitioning` is being used
+        // 2. A prefix (or all) of the sort columns is being read from the table.
+        //
+        // Sort ordering would be over the prefix subset of `sort columns` being read
+        // from the table.
+        // e.g.
+        // Assume (col0, col2, col3) are the columns read from the table.
+        // If sort columns are (col0, col1), then the sort ordering would be (col0).
+        // If sort columns are (col1, col0), then the sort ordering would be empty as per
+        // rule #2 above.
+
+        def toAttribute(colName: String): Option[Attribute] =
+          relation.attributes.find(_.name == colName)
+
+        val bucketColumns = spec.bucketColumnNames.flatMap(n => toAttribute(n))
+        if (bucketColumns.size == spec.bucketColumnNames.size) {
+          val partitioning = HashPartitioning(bucketColumns, spec.numBuckets)
+          val sortColumns =
+            spec.sortColumnNames.map(x => toAttribute(x)).takeWhile(x => x.isDefined).map(_.get)
+
+          val sortOrder = if (sortColumns.nonEmpty) {
+            // In the case of bucketing, it is possible to have multiple files belonging
+            // to the same bucket in a given relation. Each of these files is locally
+            // sorted, but those files combined together are not globally sorted. Given that,
+            // the RDD partition will not be sorted even if the relation has sort columns set.
+            // The current solution is to check whether every bucket contains a single file.
+
+            sortColumns.map(attribute => SortOrder(attribute, Ascending))
+          } else {
+            Nil
+          }
+          (partitioning, sortOrder)
+        } else {
+          (UnknownPartitioning(0), Nil)
+        }
+      case _ =>
+        (UnknownPartitioning(0), Nil)
+    }
+  }
+
   override def sameResult(plan: SparkPlan): Boolean = plan match {
     case other: HiveTableScanExec =>
       val thisPredicates = partitionPruningPred.map(cleanExpression)
@@ -172,10 +226,10 @@ case class HiveTableScanExec(
 
       val result = relation.sameResult(other.relation) &&
         output.length == other.output.length &&
-        output.zip(other.output)
-          .forall(p => p._1.name == p._2.name && p._1.dataType == p._2.dataType) &&
-        thisPredicates.length == otherPredicates.length &&
-        thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2))
+          output.zip(other.output)
+            .forall(p => p._1.name == p._2.name && p._1.dataType == p._2.dataType) &&
+          thisPredicates.length == otherPredicates.length &&
+          thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2))
       result
     case _ => false
   }
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 53bb3b93db73..d6d34924a76a 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -20,11 +20,8 @@ package org.apache.spark.sql.hive.execution
 import java.io.IOException
 import java.net.URI
 import java.text.SimpleDateFormat
-import java.util
 import java.util.{Date, Random}
 
-import scala.collection.JavaConverters._
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.hive.common.FileUtils
@@ -35,14 +32,14 @@ import org.apache.hadoop.mapred.{FileOutputFormat, JobConf}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
 import org.apache.spark.sql.hive._
 import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
 import org.apache.spark.SparkException
 import org.apache.spark.util.SerializableJobConf
 
-
 case class InsertIntoHiveTable(
     table: MetastoreRelation,
     partition: Map[String, Option[String]],
@@ -293,6 +290,51 @@ case class InsertIntoHiveTable(
 
   override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray
 
+  override val (requiredChildDistribution, requiredChildOrdering):
+    (Seq[Distribution], Seq[Seq[SortOrder]]) = {
+
+    val (requiredDistribution, requiredOrdering) = table.catalogTable.bucketSpec match {
+      case Some(bucketSpec) =>
+        val numBuckets = bucketSpec.numBuckets
+        if (numBuckets < 1) {
+          (UnspecifiedDistribution, Nil)
+        } else {
+          def toAttribute(colName: String, colType: String): Attribute =
+            table.attributes.find(_.name == colName).getOrElse {
+              throw new AnalysisException(
+                s"Could not find $colType column $colName for output table " +
+                  s"${table.catalogTable.qualifiedName} in its known columns: " +
+                  s"(${child.output.map(_.name).mkString(", ")})")
+            }
+
+          val bucketColumns = bucketSpec.bucketColumnNames.map(toAttribute(_, "bucket"))
+
+          if (bucketColumns.size == bucketSpec.bucketColumnNames.size) {
+            val hashExpression = HashPartitioning(bucketColumns, numBuckets).partitionIdExpression
+
+            // TODO: ClusteredDistribution does NOT guarantee the number of clusters, so this
+            // may not produce the desired number of buckets in all cases.
+            val childDistribution = ClusteredDistribution(Seq(hashExpression))
+
+            val sortColumnNames = bucketSpec.sortColumnNames
+            val childOrdering = if (sortColumnNames.nonEmpty) {
+              sortColumnNames.map(col => SortOrder(toAttribute(col, "sort"), Ascending))
+            } else {
+              Nil
+            }
+            (childDistribution, childOrdering)
+          } else {
+            (UnspecifiedDistribution, Nil)
+          }
+        }
+
+      case None =>
+        (UnspecifiedDistribution, Nil)
+    }
+
+    (Seq.fill(children.size)(requiredDistribution), Seq.fill(children.size)(requiredOrdering))
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     sqlContext.sparkContext.parallelize(sideEffectResult.asInstanceOf[Seq[InternalRow]], 1)
   }
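
A note on the scan-side hunk above: HashPartitioning is reported only when every bucket column is read, and the reported sort ordering is the longest prefix of the table's sort columns that is being read. The snippet below is a minimal, self-contained sketch of that rule in plain Scala with no Spark dependencies; the object name, reportedMetadata, and the column names are hypothetical and only mirror the example given in the comment.

// Toy illustration of the scan-side rule: report hash partitioning only when every
// bucket column is read, and report a sort ordering equal to the longest prefix of
// the sort columns that is read. Names here are illustrative, not Spark API.
object BucketedScanMetadataSketch {
  def reportedMetadata(
      readColumns: Set[String],
      bucketColumns: Seq[String],
      sortColumns: Seq[String]): (Boolean, Seq[String]) = {
    val hashPartitioned = bucketColumns.forall(readColumns.contains)
    val sortPrefix =
      if (hashPartitioned) sortColumns.takeWhile(readColumns.contains) else Seq.empty
    (hashPartitioned, sortPrefix)
  }

  def main(args: Array[String]): Unit = {
    val read = Set("col0", "col2", "col3")
    println(reportedMetadata(read, Seq("col2"), Seq("col0", "col1")))  // (true,List(col0))
    println(reportedMetadata(read, Seq("col2"), Seq("col1", "col0")))  // (true,List())
    println(reportedMetadata(read, Seq("col5"), Seq("col0")))          // (false,List())
  }
}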
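
On the write side, the required child distribution and ordering above are intended to hand each writer all rows of a bucket, clustered on the bucket-id expression taken from HashPartitioning.partitionIdExpression, with rows already sorted on the table's sort columns. The toy model below is plain Scala; Row, bucketId, and the hashCode-based hash are illustrative stand-ins rather than Spark's hashing. It only sketches the intended end state: rows grouped by hash(bucketKey) modulo numBuckets and locally sorted within each bucket.

// Toy model of the intended layout for a bucketed, sorted insert. The hash is a
// stand-in; Spark derives the bucket id from HashPartitioning.partitionIdExpression.
object BucketedInsertSketch {
  final case class Row(bucketKey: Int, sortKey: String)

  // Non-negative "hash modulo numBuckets" used as the bucket id.
  def bucketId(bucketKey: Any, numBuckets: Int): Int = {
    val raw = bucketKey.hashCode % numBuckets
    if (raw < 0) raw + numBuckets else raw
  }

  def main(args: Array[String]): Unit = {
    val numBuckets = 4
    val rows = Seq(Row(10, "d"), Row(3, "b"), Row(7, "c"), Row(10, "a"), Row(3, "e"))
    val buckets: Map[Int, Seq[Row]] = rows
      .groupBy(r => bucketId(r.bucketKey, numBuckets))
      .map { case (id, rs) => id -> rs.sortBy(_.sortKey) }  // locally sorted per bucket
    buckets.toSeq.sortBy(_._1).foreach { case (id, rs) =>
      println(s"bucket $id -> ${rs.mkString(", ")}")
    }
  }
}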