diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowExpression.scala new file mode 100644 index 0000000000000..7f352365eaf24 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/WindowExpression.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.types.DataType + + +/** + * @param child the computation being performed + * @param windowSpec the window spec definition + */ +case class WindowExpression(child: Expression, windowSpec: WindowSpec) extends UnaryExpression { + + override type EvaluatedType = Any + + override def eval(input: Row): Any = child.eval(input) + + override def dataType: DataType = child.dataType + override def foldable: Boolean = child.foldable + override def nullable: Boolean = child.nullable + + override def toString: String = s"$child $windowSpec" +} + +case class WindowSpec(windowPartition: WindowPartition, windowFrame: Option[WindowFrame]) + +case class WindowPartition(partitionBy: Seq[Expression], sortBy: Seq[SortOrder]) + +sealed trait FrameType + +case object RowFrame extends FrameType +case object RangeFrame extends FrameType + +case class WindowFrame(frameType: FrameType, preceding: Int, following: Int) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 17522976dc2c9..d253b91775516 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -189,6 +189,16 @@ case class Aggregate( override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute) } +case class WindowAggregate( + partitionExpressions: Seq[Expression], + windowExpressions: Seq[Alias], + otherExpressions: Seq[NamedExpression], + child: LogicalPlan) + extends UnaryNode { + + override def output: Seq[Attribute] = (windowExpressions ++ otherExpressions).map(_.toAttribute) +} + /** * Apply the all of the GroupExpressions to every input row, hence we will get * multiple output rows for a input row. 
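
[Annotation, not part of the patch: a sketch of how the new catalyst nodes above compose.
For a clause such as `sum(x) over (partition by k order by t rows between 2 preceding and
current row)`, the HiveQl parser changes below build, roughly:

    import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
    import org.apache.spark.sql.catalyst.expressions._

    // hypothetical construction; column names k, t, x are placeholders
    val spec = WindowSpec(
      WindowPartition(
        partitionBy = Seq(UnresolvedAttribute("k")),
        sortBy = Seq(SortOrder(UnresolvedAttribute("t"), Ascending))),
      Some(WindowFrame(RowFrame, preceding = 2, following = 0)))
    val expr = WindowExpression(UnresolvedFunction("sum", Seq(UnresolvedAttribute("x"))), spec)

The aliased result is then grouped under the WindowAggregate logical node above by
windowToPlan in HiveQl.scala.]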
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 81ee48ef4152f..5e411c2fdba9d 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -185,7 +185,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Hive does not support buckets. ".*bucket.*", - // No window support yet + // We have our own tests based on these query files. ".*window.*", // Fails in hive with authorization errors. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7c6a7df2bd01e..e48f5746c9221 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -250,6 +250,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.PreInsertionCasts :: ExtractPythonUdfs :: ResolveUdtfsAlias :: + ResolveWindowUdaf :: sources.PreInsertCastAndRename :: Nil } @@ -371,6 +372,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { InMemoryScans, ParquetConversion, // Must be before HiveTableScans HiveTableScans, + WindowFunction, DataSinks, Scripts, HashAggregation, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index fd305eb480e63..956a76ec7f7d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import java.sql.Date +import java.util.concurrent.atomic.AtomicInteger import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} @@ -423,16 +424,16 @@ private[hive] object HiveQl { } /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to + * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) + * is equivalent to * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 * Check the following link for details. - * + * https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup * * The bitmask denotes the grouping expressions validity for a grouping set, * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of + * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. 
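   * (A bit is set when the corresponding key of the superset appears in the grouping set:
   * (k1, k2) sets bits 0 and 1, i.e. 1 + 2 = 3, while (k2) sets only bit 1, i.e. 2.)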
*/ protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { @@ -446,7 +447,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val bitmasks: Seq[Int] = setASTs.map(set => set match { case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => + case Token("TOK_GROUPING_SETS_EXPRESSION", children) => children.foldLeft(0)((bitmap, col) => { val colString = col.asInstanceOf[ASTNode].toStringTree() require(keyMap.contains(colString), s"$colString doens't show up in the GROUP BY list") @@ -615,7 +616,8 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C clusterByClause :: distributeByClause :: limitClause :: - lateralViewClause :: Nil) = { + lateralViewClause :: + windowClause :: Nil) = { getClauses( Seq( "TOK_INSERT_INTO", @@ -633,15 +635,18 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C "TOK_CLUSTERBY", "TOK_DISTRIBUTEBY", "TOK_LIMIT", - "TOK_LATERAL_VIEW"), + "TOK_LATERAL_VIEW", + "WINDOW"), singleInsert) } - + val relations = fromClause match { case Some(f) => nodeToRelation(f) case None => OneRowRelation } - + + collectWindowDefs(windowClause) + val withWhere = whereClause.map { whereNode => val Seq(whereExpr) = whereNode.getChildren.toSeq Filter(nodeToExpr(whereExpr), relations) @@ -693,7 +698,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val serdeProps = propsClause.map { case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => (name, value) - } + } (Nil, serdeClass, serdeProps) case Nil => (Nil, "", Nil) @@ -736,32 +741,36 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // The projection of the query can either be a normal projection, an aggregation // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { - val selectExpressions = + val selectExpressions = nameExpressions(select.getChildren.flatMap(selExprNodeToExpr).toSeq) - Seq( - groupByClause.map(e => e match { - case Token("TOK_GROUPBY", children) => - // Not a transformation so must be either project or aggregation. - Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView) - case _ => sys.error("Expect GROUP BY") - }), - groupingSetsClause.map(e => e match { - case Token("TOK_GROUPING_SETS", children) => - val(groupByExprs, masks) = extractGroupingSet(children) - GroupingSets(masks, groupByExprs, withLateralView, selectExpressions) - case _ => sys.error("Expect GROUPING SETS") - }), - rollupGroupByClause.map(e => e match { - case Token("TOK_ROLLUP_GROUPBY", children) => - Rollup(children.map(nodeToExpr), withLateralView, selectExpressions) - case _ => sys.error("Expect WITH ROLLUP") - }), - cubeGroupByClause.map(e => e match { - case Token("TOK_CUBE_GROUPBY", children) => - Cube(children.map(nodeToExpr), withLateralView, selectExpressions) - case _ => sys.error("Expect WITH CUBE") - }), - Some(Project(selectExpressions, withLateralView))).flatten.head + + val groupPlan = (selectExprs: Seq[NamedExpression]) => + Seq( + groupByClause.map(e => e match { + case Token("TOK_GROUPBY", children) => + // Not a transformation so must be either project or aggregation. 
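+              // (groupPlan is a closure: windowToPlan below re-invokes it with the window
+              // expressions stripped from the select list whenever any are present.)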
+ Aggregate(children.map(nodeToExpr), selectExprs, withLateralView) + case _ => sys.error("Expect GROUP BY") + }), + groupingSetsClause.map(e => e match { + case Token("TOK_GROUPING_SETS", children) => + val(groupByExprs, masks) = extractGroupingSet(children) + GroupingSets(masks, groupByExprs, withLateralView, selectExprs) + case _ => sys.error("Expect GROUPING SETS") + }), + rollupGroupByClause.map(e => e match { + case Token("TOK_ROLLUP_GROUPBY", children) => + Rollup(children.map(nodeToExpr), withLateralView, selectExprs) + case _ => sys.error("Expect WITH ROLLUP") + }), + cubeGroupByClause.map(e => e match { + case Token("TOK_CUBE_GROUPBY", children) => + Cube(children.map(nodeToExpr), withLateralView, selectExprs) + case _ => sys.error("Expect WITH CUBE") + }), + Some(Project(selectExprs, withLateralView))).flatten.head + + windowToPlan(selectExpressions, groupPlan) } val withDistinct = @@ -1051,6 +1060,168 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") } + protected val windowDefs = new ThreadLocal[Map[String, Seq[ASTNode]]] { + override def initialValue() = Map.empty[String, Seq[ASTNode]] + } + + protected val nextWindowSpecId: AtomicInteger = new AtomicInteger(0) + + protected def collectWindowDefs(windowClause: Option[Node]) = { + val definitions = windowClause.toSeq.flatMap(_.getChildren.toSeq).collect { + case Token("TOK_WINDOWDEF", Token(alias, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => + alias -> spec + }.toMap + + windowDefs.set(definitions) + } + + protected def substituteWindowSpec(windowSpec: Seq[ASTNode]): Seq[ASTNode] = { + windowSpec match { + case Token(alias, Nil) :: Nil => + substituteWindowSpec(getWindowSpec(alias)) + + case Token(alias, Nil) :: frame => + val (partitionClause :: _ /* range frame */ :: _ /* value frame */ :: Nil) = getClauses( + Seq( + "TOK_PARTITIONINGSPEC", + "TOK_WINDOWRANGE", + "TOK_WINDOWVALUES"), + substituteWindowSpec(getWindowSpec(alias))) + + partitionClause + .map(_.asInstanceOf[ASTNode] :: frame) + .getOrElse(frame) + + case e => e + } + } + + protected def getWindowSpec(alias: String): Seq[ASTNode] = { + windowDefs.get().getOrElse(alias, sys.error(s"No window named $alias found.")) + } + + protected def windowToPlan( + selectExpressions: Seq[NamedExpression], + groupPlan: Seq[NamedExpression] => LogicalPlan): LogicalPlan = { + + val windowExpressions = + selectExpressions.flatMap(_.collect { case a @ Alias(WindowExpression(_, _), _) => a }) + + if (windowExpressions.isEmpty) groupPlan(selectExpressions) + else { + val subSelectExprs = + selectExpressions.filter( + _.collect { case a @ Alias(WindowExpression(_, _), _) => a }.isEmpty) + + val childPlan = groupPlan(subSelectExprs).transform { + case Project(_, child) => child + } + + val attributes = + ( + windowExpressions.flatMap(_.collect { + case a: UnresolvedAttribute => a + }) ++ subSelectExprs.map(_.toAttribute) + ).distinct + + val windowPartitions = windowExpressions.collect { + case Alias(WindowExpression(_, spec), _) => spec.windowPartition + }.distinct + + val (restWindowExprs, _, withWindow) = + windowPartitions.foldLeft((windowExpressions, attributes, childPlan)) { + case ((expressions, attributes, plan), part @ WindowPartition(partitionBy, sortBy)) => + val (computeExprs, restWindowExprs) = + expressions.partition( + _.child.asInstanceOf[WindowExpression].windowSpec.windowPartition == part) + + val withWindowPartition = (partitionBy, sortBy) match { + 
case (Nil, Nil) => plan + case (Nil, s) => Sort(s, false, plan) + case (p, Nil) => Repartition(p, plan) + case (p, s) => SortPartitions(s, Repartition(p, plan)) + } + + val otherExpressions = (attributes ++ (partitionBy ++ sortBy.map(_.child)).collect { + case a: UnresolvedAttribute => a + }).distinct + + (restWindowExprs, attributes ++ computeExprs.map(_.toAttribute), + WindowAggregate(partitionBy, computeExprs, otherExpressions, withWindowPartition)) + } + + assert(restWindowExprs.isEmpty) + + val finalExprs = selectExpressions.map { expr => + expr transform { + case u: NamedExpression + if windowExpressions.contains(u) || subSelectExprs.contains(u) => u.toAttribute + } + } + + Project(finalExprs.asInstanceOf[Seq[NamedExpression]], withWindow) + } + } + + protected def parseWindowSpec(windowSpec: Seq[ASTNode]): WindowSpec = { + val (partitionClause :: rowFrame :: rangeFrame :: Nil) = getClauses( + Seq( + "TOK_PARTITIONINGSPEC", + "TOK_WINDOWRANGE", + "TOK_WINDOWVALUES"), + substituteWindowSpec(windowSpec)) + + val windowPartition = partitionClause.map { partition => + val (orderByClause :: sortByClause :: distributeByClause :: clusterByClause :: Nil) = + getClauses( + Seq( + "TOK_ORDERBY", + "TOK_SORTBY", + "TOK_DISTRIBUTEBY", + "TOK_CLUSTERBY"), + partition.getChildren.toSeq.asInstanceOf[Seq[ASTNode]]) + + val partitionBy = distributeByClause.orElse(clusterByClause).toSeq + val sortBy = clusterByClause.orElse(orderByClause).orElse(sortByClause).toSeq + + WindowPartition( + partitionBy.flatMap(_.getChildren.map(nodeToExpr)), + sortBy.flatMap(_.getChildren.map(nodeToSortOrder))) + }.getOrElse(WindowPartition(Nil, Nil)) + + val maybeWindowFrame = rowFrame.orElse(rangeFrame).flatMap { frame => + val ranges = frame.getChildren.toList + val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) + + def nodeToBound(node: Node) = node match { + case Token("preceding" | "following", Token(count, Nil) :: Nil) => + if (count == "unbounded") Int.MaxValue else count.toInt + case _ => 0 + } + + ranges match { + case precedingNode :: followingNode :: _ => + Some(WindowFrame(frameType, nodeToBound(precedingNode), nodeToBound(followingNode))) + case precedingNode :: Nil => + Some(WindowFrame(frameType, nodeToBound(precedingNode), 0)) + case Nil => + None + } + } + + WindowSpec(windowPartition, maybeWindowFrame) + } + + protected def windowToExpr(node: Node): Option[Expression] = node match { + case Token(_, (tn @ Token(name, Nil)) :: tail) => + val (specNodes, argNodes) = tail.partition(_.getText == "TOK_WINDOWSPEC") + val maybeWindowSpec = specNodes.collectFirst { case Token(_, spec) => parseWindowSpec(spec) } + val newToken = node.asInstanceOf[ASTNode].withChildren((tn :: argNodes).toSeq) + maybeWindowSpec + .map(s => Alias(WindowExpression(nodeToExpr(newToken), s), + s"w_${nextWindowSpecId.getAndIncrement}")()) + case _ => sys.error(s"Failed to parse node with TOK_WINDOWSPEC") + } protected val escapedIdentifier = "`([^`]+)`".r /** Strips backticks from ident if present */ @@ -1248,10 +1419,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) /* UDFs - Must be last otherwise will preempt built in functions */ - case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr)) - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil) + case node @ Token("TOK_FUNCTION", 
Token(name, Nil) :: args) =>
+      windowToExpr(node).getOrElse(UnresolvedFunction(name, args.map(nodeToExpr)))
+
+    case node @ Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) =>
+      windowToExpr(node).getOrElse(UnresolvedFunction(name, UnresolvedStar(None) :: Nil))
 
     /* Literals */
     case Token("TOK_NULL", Nil) => Literal.create(null, NullType)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index a6f4fbe8aba06..0933a2361e87d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -193,6 +193,14 @@ private[hive] trait HiveStrategies {
     }
   }
 
+  object WindowFunction extends Strategy {
+    def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+      case logical.WindowAggregate(partition, window, other, child) =>
+        execution.WindowAggregate(partition, window, other, planLater(child)) :: Nil
+      case _ => Nil
+    }
+  }
+
   /**
    * Retrieves data using a HiveTableScan.  Partition pruning predicates are also detected and
    * applied.
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/WindowAggregate.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/WindowAggregate.scala
new file mode 100644
index 0000000000000..f5349c0b3d2a2
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/WindowAggregate.scala
@@ -0,0 +1,338 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.util.HashMap
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.util.collection.CompactBuffer
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, Sort}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.hive.HiveGenericUdaf
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * :: DeveloperApi ::
+ * Groups input data by `partitionExpressions` and computes the `windowExpressions` for each
+ * group.
+ * @param partitionExpressions expressions that are evaluated to determine the partition.
+ * @param windowExpressions the aliased window expressions computed for each partition.
+ * @param otherExpressions the remaining (non-window) expressions, which are passed through.
+ * @param child the input data source.
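+ *
+ * Within each partition a window expression is evaluated either once for the whole
+ * partition or once per row, depending on its frame and ordering; see `computeFunctions`.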
+ */
+@DeveloperApi
+case class WindowAggregate(
+    partitionExpressions: Seq[Expression],
+    windowExpressions: Seq[Alias],
+    otherExpressions: Seq[NamedExpression],
+    child: SparkPlan)
+  extends UnaryNode {
+
+  override def requiredChildDistribution: List[Distribution] =
+    if (partitionExpressions == Nil) {
+      AllTuples :: Nil
+    } else {
+      ClusteredDistribution(partitionExpressions) :: Nil
+    }
+
+  // HACK: Generators don't correctly preserve their output through serializations so we grab
+  // our child's output attributes statically here.
+  private[this] val childOutput = child.output
+
+  override def output: Seq[Attribute] = (windowExpressions ++ otherExpressions).map(_.toAttribute)
+
+  case class ComputedWindow(
+    unboundFunction: Expression,
+    pivotResult: Boolean,
+    windowSpec: WindowSpec,
+    boundedFunction: Expression,
+    computedAttribute: AttributeReference)
+
+  case class WindowFunctionInfo(
+    supportsWindow: Boolean, pivotResult: Boolean, impliesOrder: Boolean)
+
+  private[this] val computedWindows = windowExpressions.collect {
+    case Alias(expr @ WindowExpression(func, spec), _) =>
+      val ipr = func match {
+        case HiveGenericUdaf(_, wfi, _) => wfi.isPivotResult
+        case _ => false
+      }
+      ComputedWindow(
+        func,
+        ipr,
+        spec,
+        BindReferences.bindReference(func, child.output),
+        AttributeReference(s"funcResult:$func", func.dataType, func.nullable)())
+  }
+
+  private[this] val otherAttributes = otherExpressions.map(_.toAttribute)
+
+  /** The schema of the result of all evaluations */
+  private[this] val resultAttributes = otherAttributes ++ computedWindows.map(_.computedAttribute)
+
+  private[this] val resultMap =
+    (otherExpressions.map { other => other -> other.toAttribute } ++
+      computedWindows.map { window => window.unboundFunction -> window.computedAttribute }).toMap
+
+  private[this] val resultExpressions = (windowExpressions ++ otherExpressions).map { sel =>
+    sel.transform {
+      case e: Expression if resultMap.contains(e) => resultMap(e)
+    }
+  }
+
+  private[this] val sortExpressions = child match {
+    case Sort(sortOrder, _, _) => sortOrder
+    case _ => Seq[SortOrder]()
+  }
+
+  // whether rows within a single partition are additionally sorted by non-partition keys
+  private[this] val ifSortInOnePartition =
+    !sortExpressions.isEmpty &&
+      !sortExpressions.map(_.child).diff(partitionExpressions).isEmpty
+
+  private[this] val sortReference =
+    if (sortExpressions.isEmpty) None
+    else {
+      // used when computing a range frame; only a single ordering expression is supported
+      Some(BindReferences.bindReference(sortExpressions.head, childOutput))
+    }
+
+  private[this] def computeFunctions(rows: CompactBuffer[Row]): Seq[Iterator[Any]] =
+    computedWindows.map { window =>
+      val baseExpr = window.boundedFunction.asInstanceOf[AggregateExpression]
+      window.windowSpec.windowFrame.map { frame =>
+        frame.frameType match {
+          case RowFrame => rowFrameFunction(baseExpr, frame, rows).iterator
+          case RangeFrame => rangeFrameFunction(baseExpr, frame, rows).iterator
+        }
+      }.getOrElse {
+        val function = baseExpr.newInstance()
+        if (window.pivotResult) {
+          rows.foreach(function.update)
+          function.eval(EmptyRow).asInstanceOf[Seq[Any]].iterator
+        } else if (ifSortInOnePartition) {
+          rows.map { row =>
+            function.update(row)
+            function.eval(EmptyRow)
+          }.iterator
+        } else {
+          rows.foreach(function.update)
+          val result = function.eval(EmptyRow)
+          (0 to rows.size - 1).map(r => result).iterator
+        }
+      }
+    }
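+
+  // For a ROW frame the window slides over physical row positions: with
+  // `rows between 2 preceding and 2 following` over 6 buffered rows, row 0 aggregates
+  // rows [0, 2] and row 3 aggregates rows [1, 5]. `unbounded` bounds arrive here as
+  // Int.MaxValue (see `nodeToBound` in HiveQl.scala).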
+  private[this] def rowFrameFunction(base: AggregateExpression, frame: WindowFrame,
+      rows: CompactBuffer[Row]): CompactBuffer[Any] = {
+
+    val frameResults = new CompactBuffer[Any]()
+    var rowIndex = 0
+    while (rowIndex < rows.size) {
+      var start =
+        if (frame.preceding == Int.MaxValue) 0
+        else rowIndex - frame.preceding
+      if (start < 0) start = 0
+      var end =
+        if (frame.following == Int.MaxValue) {
+          rows.size - 1
+        } else {
+          rowIndex + frame.following
+        }
+      if (end > rows.size - 1) end = rows.size - 1
+
+      // a fresh aggregate buffer for each output row
+      val aggr = base.newInstance()
+      (start to end).foreach(i => aggr.update(rows(i)))
+
+      frameResults += aggr.eval(EmptyRow)
+      rowIndex += 1
+    }
+    frameResults
+  }
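+
+  // A RANGE frame is resolved against the values of the (single) sort expression instead
+  // of row positions: for each row the frame is widened while the other row's sort key
+  // stays within `preceding`/`following` of the current row's, honouring sort direction.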
+  private[this] def rangeFrameFunction(base: AggregateExpression, frame: WindowFrame,
+      rows: CompactBuffer[Row]): CompactBuffer[Any] = {
+
+    val (preceding, following) = sortReference.map { sortExpression =>
+      sortExpression.child.dataType match {
+        case IntegerType => (Literal(frame.preceding), Literal(frame.following))
+        case LongType => (Literal(frame.preceding.toLong), Literal(frame.following.toLong))
+        case DoubleType => (Literal(frame.preceding.toDouble), Literal(frame.following.toDouble))
+        case FloatType => (Literal(frame.preceding.toFloat), Literal(frame.following.toFloat))
+        case ShortType => (Literal(frame.preceding.toShort), Literal(frame.following.toShort))
+        case DecimalType() =>
+          (Literal(BigDecimal(frame.preceding)), Literal(BigDecimal(frame.following)))
+        // TODO: need to support StringType comparison
+        case StringType => throw new Exception("StringType comparison is not supported yet")
+        case dt => throw new Exception(s"$dt comparison is not supported")
+      }
+    }.getOrElse {
+      throw new Exception("range frames without a sort expression are not supported")
+    }
+
+    val frameResults = new CompactBuffer[Any]()
+    var rowIndex = 0
+    while (rowIndex < rows.size) {
+      var precedingIndex = 0
+      var followingIndex = rows.size - 1
+
+      sortReference.foreach { sortExpression =>
+        val currentRow = rows(rowIndex)
+        val eval = sortExpression.child.eval(currentRow)
+        val precedingExpr =
+          if (sortExpression.direction == Ascending) {
+            Literal(eval) - sortExpression.child <= preceding
+          } else {
+            sortExpression.child - Literal(eval) <= preceding
+          }
+
+        val followingExpr =
+          if (sortExpression.direction == Ascending) {
+            sortExpression.child - Literal(eval) <= following
+          } else {
+            Literal(eval) - sortExpression.child <= following
+          }
+
+        if (frame.preceding != Int.MaxValue) precedingIndex = rowIndex
+        while (precedingIndex > 0 &&
+            precedingExpr.eval(rows(precedingIndex - 1)).asInstanceOf[Boolean]) {
+          precedingIndex -= 1
+        }
+
+        if (frame.following != Int.MaxValue) followingIndex = rowIndex
+        while (followingIndex < rows.size - 1 &&
+            followingExpr.eval(rows(followingIndex + 1)).asInstanceOf[Boolean]) {
+          followingIndex += 1
+        }
+      }
+      // a fresh aggregate buffer for each output row
+      val aggr = base.newInstance()
+      (precedingIndex to followingIndex).foreach(i => aggr.update(rows(i)))
+      frameResults += aggr.eval(EmptyRow)
+      rowIndex += 1
+    }
+    frameResults
+  }
+
+  private[this] def getNextFunctionsRow(
+      functionsResult: Seq[Iterator[Any]]): GenericMutableRow = {
+    val result = new GenericMutableRow(functionsResult.length)
+    var i = 0
+    while (i < functionsResult.length) {
+      result(i) = functionsResult(i).next()
+      i += 1
+    }
+    result
+  }
+
+  override def execute(): RDD[Row] = attachTree(this, "execute") {
+    if (partitionExpressions.isEmpty) {
+      child.execute().mapPartitions { iter =>
+        val resultProjection = new InterpretedProjection(resultExpressions, resultAttributes)
+        val otherProjection = new InterpretedMutableProjection(otherAttributes, childOutput)
+        val joinedRow = new JoinedRow
+
+        val rows = new CompactBuffer[Row]()
+        while (iter.hasNext) {
+          rows += iter.next().copy()
+        }
+        new Iterator[Row] {
+          private[this] val functionsResult = computeFunctions(rows)
+          private[this] var currentRowIndex: Int = 0
+
+          override final def hasNext: Boolean = currentRowIndex < rows.size
+
+          override final def next(): Row = {
+            val otherResults = otherProjection(rows(currentRowIndex)).copy()
+            currentRowIndex += 1
+            resultProjection(joinedRow(otherResults, getNextFunctionsRow(functionsResult)))
+          }
+        }
+      }
+    } else {
+      child.execute().mapPartitions { iter =>
+        val partitionTable = new HashMap[Row, CompactBuffer[Row]]
+        val partitionProjection =
+          new InterpretedMutableProjection(partitionExpressions, childOutput)
+
+        var currentRow: Row = null
+        while (iter.hasNext) {
+          currentRow = iter.next()
+          val partitionKey = partitionProjection(currentRow).copy()
+          val existingMatchList = partitionTable.get(partitionKey)
+          val matchList = if (existingMatchList == null) {
+            val newMatchList = new CompactBuffer[Row]()
+            partitionTable.put(partitionKey, newMatchList)
+            newMatchList
+          } else {
+            existingMatchList
+          }
+          matchList += currentRow.copy()
+        }
+
+        new Iterator[Row] {
+          private[this] val partitionTableIter = partitionTable.entrySet().iterator()
+          private[this] var currentPartition: CompactBuffer[Row] = _
+          private[this] var functionsResult: Seq[Iterator[Any]] = _
+          private[this] var currentRowIndex: Int = -1
+
+          val resultProjection = new InterpretedProjection(resultExpressions, resultAttributes)
+          val otherProjection = new InterpretedMutableProjection(otherAttributes, childOutput)
+          val joinedRow = new JoinedRow
+
+          override final def hasNext: Boolean =
+            (currentRowIndex != -1 && currentRowIndex < currentPartition.size) ||
+              (partitionTableIter.hasNext && fetchNext())
+
+          override final def next(): Row = {
+            val otherResults = otherProjection(currentPartition(currentRowIndex)).copy()
+            currentRowIndex += 1
+            resultProjection(joinedRow(otherResults, getNextFunctionsRow(functionsResult)))
+          }
+
+          private final def fetchNext(): Boolean = {
+            currentRowIndex = 0
+            if (partitionTableIter.hasNext) {
+              currentPartition = partitionTableIter.next().getValue
+              functionsResult = computeFunctions(currentPartition)
+              true
+            } else false
+          }
+        }
+      }
+    }
+  }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 47305571e579e..d12f5f411df64 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -24,8 +24,7 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector}
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory
-import org.apache.hadoop.hive.ql.exec.{UDF, UDAF}
-import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
+import org.apache.hadoop.hive.ql.exec._
 import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
@@ -33,7 +32,7 @@ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
 import org.apache.spark.Logging
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Generate, Project, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.catalyst.analysis.MultiAlias
@@ -59,10 +58,12 @@ private[hive] abstract class HiveFunctionRegistry
     if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
       HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children)
+      HiveGenericUdf(name, new HiveFunctionWrapper(functionClassName), children)
     } else if (
       classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children)
+      val windowFunctionInfo: WindowFunctionInfo =
+        FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)
+      HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), windowFunctionInfo, children)
     } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
       HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
     } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
@@ -136,7 +137,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
   override def get(): AnyRef = wrap(func(), oi)
 }
 
-private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
+private[hive] case class HiveGenericUdf(
+  name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
   extends Expression with HiveInspectors with Logging {
   type UDFType = GenericUDF
   type EvaluatedType = Any
@@ -189,8 +191,42 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr
   }
 }
 
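+/**
+ * Resolves window-specific properties of Hive UDAFs inside a [[WindowAggregate]]: rejects
+ * window frames on functions that do not support them, feeds the sort expressions to
+ * functions that imply an ordering (e.g. rank, dense_rank), and swaps generic UDFs that
+ * are registered as window functions (e.g. lead, lag) for their UDAF counterparts.
+ */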
+private[spark] object ResolveWindowUdaf extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case q @ WindowAggregate(_, _, _, child) =>
+      q transformExpressions {
+        case WindowExpression(HiveGenericUdaf(wrapper, wfi, _), WindowSpec(_, Some(frame)))
+          if (!wfi.isSupportsWindow) =>
+          sys.error(s"udaf ${wrapper.functionClassName} does not support a window frame")
+        // if `isImpliesOrder` is true, the sort expressions are passed as parameters,
+        // as for rank and dense_rank
+        case HiveGenericUdaf(wrapper, wfi, children)
+          if (wfi.isImpliesOrder && children.isEmpty) => child match {
+            case SortPartitions(sortExpr, _) =>
+              HiveGenericUdaf(wrapper, wfi, children ++ sortExpr.map(_.child))
+            case Sort(sortExpr, _, _) =>
+              HiveGenericUdaf(wrapper, wfi, children ++ sortExpr.map(_.child))
+            case _ =>
+              sys.error(s"udaf ${wrapper.functionClassName} needs sort expressions")
+          }
+        // if the function computed over the window is a `HiveGenericUdf`, check whether a
+        // window UDAF is registered under the same name (e.g. lead, lag) and substitute it
+        case HiveGenericUdf(name, wrapper, children) =>
+          val windowFunctionInfo: WindowFunctionInfo =
+            Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse(
+              sys.error(s"Couldn't find udaf function $name"))
+          HiveGenericUdaf(
+            new HiveFunctionWrapper(windowFunctionInfo.getfInfo().getFunctionClass.getName),
+            windowFunctionInfo, children)
+      }
+  }
+}
+
 private[hive] case class HiveGenericUdaf(
     funcWrapper: HiveFunctionWrapper,
+    @transient
+    windowFunctionInfo: WindowFunctionInfo,
     children: Seq[Expression]) extends AggregateExpression
   with HiveInspectors {
@@ -209,7 +245,17 @@ private[hive] case class HiveGenericUdaf(
 
   @transient
   protected lazy val inspectors = children.map(toInspector)
 
-  def dataType: DataType = inspectorToDataType(objectInspector)
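+  // A "pivot result" UDAF (per Hive's WindowFunctionInfo) evaluates to an array holding
+  // one value per input row rather than a single scalar, so unwrap its element type below.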
+  protected val pivotResult = windowFunctionInfo.isPivotResult
+
+  def dataType: DataType =
+    if (!pivotResult) inspectorToDataType(objectInspector)
+    else {
+      inspectorToDataType(objectInspector) match {
+        case ArrayType(dt, _) => dt
+        case _ =>
+          sys.error(s"failed to resolve the data type of udaf ${funcWrapper.functionClassName}")
+      }
+    }
+
 
   def nullable: Boolean = true
 
@@ -375,8 +421,7 @@ private[hive] case class HiveUdafFunction(
 
   private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
 
-  private val buffer =
-    function.getNewAggregationBuffer
+  private val buffer = function.getNewAggregationBuffer
 
   override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector)
 
diff --git a/sql/hive/src/test/resources/data/files/part_tiny2.txt b/sql/hive/src/test/resources/data/files/part_tiny2.txt
new file mode 100644
index 0000000000000..0445ae087dc5d
--- /dev/null
+++ b/sql/hive/src/test/resources/data/files/part_tiny2.txt
@@ -0,0 +1,25 @@
+121152almond antique burnished rose metallicManufacturer#1Brand#14PROMO PLATED TIN2JUMBO BOX1173.15e pinto beans h
+85768almond antique chartreuse lavender yellowManufacturer#1Brand#12LARGE BRUSHED STEEL34SM BAG1753.76refull
+110592almond antique salmon chartreuse burlywoodManufacturer#1Brand#15PROMO BURNISHED NICKEL6JUMBO PKG1602.59 to the furiously
+86428almond aquamarine burnished black steelManufacturer#1Brand#12STANDARD ANODIZED STEEL28WRAP BAG1414.42arefully
+65667almond aquamarine pink moccasin thistleManufacturer#1Brand#12LARGE BURNISHED STEEL42JUMBO CASE1632.66e across the expr
+105685almond antique violet chocolate turquoiseManufacturer#2Brand#22MEDIUM ANODIZED COPPER14MED CAN1690.68ly pending requ
+191709almond antique violet turquoise frostedManufacturer#2Brand#22ECONOMY POLISHED STEEL40MED BOX1800.7 haggle
+146985almond aquamarine midnight light salmonManufacturer#2Brand#23MEDIUM BURNISHED COPPER2SM CASE2031.98s cajole caref
+132666almond aquamarine rose maroon antiqueManufacturer#2Brand#24SMALL POLISHED NICKEL25MED BOX1698.66even
+195606almond aquamarine sandy cyan gainsboroManufacturer#2Brand#25STANDARD PLATED TIN18SM PKG1701.6ic de
+90681almond antique chartreuse khaki whiteManufacturer#3Brand#31MEDIUM BURNISHED TIN17SM CASE1671.68are slyly after the sl
+17273almond antique forest lavender goldenrodManufacturer#3Brand#35PROMO ANODIZED TIN14JUMBO CASE1190.27along the
+112398almond antique metallic orange dimManufacturer#3Brand#32MEDIUM BURNISHED BRASS19JUMBO JAR1410.39ole car
+40982almond antique misty red oliveManufacturer#3Brand#32ECONOMY PLATED COPPER1LG PKG1922.98c foxes can s
+144293almond antique olive coral navajoManufacturer#3Brand#34STANDARD POLISHED STEEL45JUMBO CAN1337.29ag furiously about
+49671almond antique gainsboro frosted violetManufacturer#4Brand#41SMALL BRUSHED BRASS10SM BOX1620.67ccounts run quick
+48427almond antique violet mint lemonManufacturer#4Brand#42PROMO POLISHED STEEL39SM CASE1375.42hely ironic i
+45261almond aquamarine floral ivory bisqueManufacturer#4Brand#42SMALL PLATED STEEL27WRAP CASE1206.26careful
+17927almond aquamarine yellow dodger mintManufacturer#4Brand#41ECONOMY BRUSHED COPPER7SM PKG1844.92ites.
eve +33357almond azure aquamarine papaya violetManufacturer#4Brand#41STANDARD ANODIZED TIN12WRAP CASE1290.35reful +192697almond antique blue firebrick mintManufacturer#5Brand#52MEDIUM BURNISHED TIN31LG DRUM1789.69ickly ir +42669almond antique medium spring khakiManufacturer#5Brand#51STANDARD BURNISHED TIN6MED CAN1611.66sits haggl +155733almond antique sky peru orangeManufacturer#5Brand#53SMALL PLATED BRASS2WRAP DRUM1788.73furiously. bra +15103almond aquamarine dodger light gainsboroManufacturer#5Brand#53ECONOMY BURNISHED STEEL46LG PACK1018.1packages hinder carefu +78486almond azure blanched chiffon midnightManufacturer#5Brand#52LARGE BRUSHED BRASS23MED BAG1464.48hely blith \ No newline at end of file diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 1. testWindowing-0-327a8cd39fe30255ff492ee86f660522 b/sql/hive/src/test/resources/golden/windowing.q -- 1. testWindowing-0-327a8cd39fe30255ff492ee86f660522 new file mode 100644 index 0000000000000..850c41c8115d6 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 1. testWindowing-0-327a8cd39fe30255ff492ee86f660522 @@ -0,0 +1,26 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1 1 1173.15 +Manufacturer#1 almond antique burnished rose metallic 2 1 1 2346.3 +Manufacturer#1 almond antique chartreuse lavender yellow 34 3 2 4100.06 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 4 3 5702.650000000001 +Manufacturer#1 almond aquamarine burnished black steel 28 5 4 7117.070000000001 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 6 5 8749.730000000001 +Manufacturer#2 almond antique violet chocolate turquoise 14 1 1 1690.68 +Manufacturer#2 almond antique violet turquoise frosted 40 2 2 3491.38 +Manufacturer#2 almond aquamarine midnight light salmon 2 3 3 5523.360000000001 +Manufacturer#2 almond aquamarine rose maroon antique 25 4 4 7222.02 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 5 5 8923.62 +Manufacturer#3 almond antique chartreuse khaki white 17 1 1 1671.68 +Manufacturer#3 almond antique forest lavender goldenrod 14 2 2 2861.95 +Manufacturer#3 almond antique metallic orange dim 19 3 3 4272.34 +Manufacturer#3 almond antique misty red olive 1 4 4 6195.32 +Manufacturer#3 almond antique olive coral navajo 45 5 5 7532.61 +Manufacturer#4 almond antique gainsboro frosted violet 10 1 1 1620.67 +Manufacturer#4 almond antique violet mint lemon 39 2 2 2996.09 +Manufacturer#4 almond aquamarine floral ivory bisque 27 3 3 4202.35 +Manufacturer#4 almond aquamarine yellow dodger mint 7 4 4 6047.27 +Manufacturer#4 almond azure aquamarine papaya violet 12 5 5 7337.620000000001 +Manufacturer#5 almond antique blue firebrick mint 31 1 1 1789.69 +Manufacturer#5 almond antique medium spring khaki 6 2 2 3401.3500000000004 +Manufacturer#5 almond antique sky peru orange 2 3 3 5190.08 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 4 4 6208.18 +Manufacturer#5 almond azure blanched chiffon midnight 23 5 5 7672.66 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 19. testUDAFsWithGBY-0-f4673060f65c6de8a1209579912f9489 b/sql/hive/src/test/resources/golden/windowing.q -- 19. testUDAFsWithGBY-0-f4673060f65c6de8a1209579912f9489 new file mode 100644 index 0000000000000..6461642d34a21 --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 19. 
testUDAFsWithGBY-0-f4673060f65c6de8a1209579912f9489 @@ -0,0 +1,25 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1173.15 4529.5 1173.15 1173.15 1509.8333333333333 +Manufacturer#1 almond antique chartreuse lavender yellow 34 1753.76 5943.92 1753.76 1753.76 1485.98 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 1602.59 7576.58 1602.59 1602.59 1515.316 +Manufacturer#1 almond aquamarine burnished black steel 28 1414.42 6403.43 1414.42 1414.42 1600.8575 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 4649.67 1632.66 1632.66 1549.89 +Manufacturer#2 almond antique violet chocolate turquoise 14 1690.68 5523.360000000001 1690.68 1690.68 1841.1200000000001 +Manufacturer#2 almond antique violet turquoise frosted 40 1800.7 7222.02 1800.7 1800.7 1805.505 +Manufacturer#2 almond aquamarine midnight light salmon 2 2031.98 8923.62 2031.98 2031.98 1784.7240000000002 +Manufacturer#2 almond aquamarine rose maroon antique 25 1698.66 7232.9400000000005 1698.66 1698.66 1808.2350000000001 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 1701.6 5432.24 1701.6 1701.6 1810.7466666666667 +Manufacturer#3 almond antique chartreuse khaki white 17 1671.68 4272.34 1671.68 1671.68 1424.1133333333335 +Manufacturer#3 almond antique forest lavender goldenrod 14 1190.27 6195.32 1190.27 1190.27 1548.83 +Manufacturer#3 almond antique metallic orange dim 19 1410.39 7532.61 1410.39 1410.39 1506.522 +Manufacturer#3 almond antique misty red olive 1 1922.98 5860.929999999999 1922.98 1922.98 1465.2324999999998 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 4670.66 1337.29 1337.29 1556.8866666666665 +Manufacturer#4 almond antique gainsboro frosted violet 10 1620.67 4202.35 1620.67 1620.67 1400.7833333333335 +Manufacturer#4 almond antique violet mint lemon 39 1375.42 6047.27 1375.42 1375.42 1511.8175 +Manufacturer#4 almond aquamarine floral ivory bisque 27 1206.26 7337.620000000001 1206.26 1206.26 1467.5240000000001 +Manufacturer#4 almond aquamarine yellow dodger mint 7 1844.92 5716.950000000001 1844.92 1844.92 1429.2375000000002 +Manufacturer#4 almond azure aquamarine papaya violet 12 1290.35 4341.530000000001 1290.35 1290.35 1447.176666666667 +Manufacturer#5 almond antique blue firebrick mint 31 1789.69 5190.08 1789.69 1789.69 1730.0266666666666 +Manufacturer#5 almond antique medium spring khaki 6 1611.66 6208.18 1611.66 1611.66 1552.045 +Manufacturer#5 almond antique sky peru orange 2 1788.73 7672.66 1788.73 1788.73 1534.532 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 1018.1 5882.970000000001 1018.1 1018.1 1470.7425000000003 +Manufacturer#5 almond azure blanched chiffon midnight 23 1464.48 4271.3099999999995 1464.48 1464.48 1423.7699999999998 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 2. testGroupByWithPartitioning-0-b4c80a243b3ef15de15d9128b0bc51a6 b/sql/hive/src/test/resources/golden/windowing.q -- 2. testGroupByWithPartitioning-0-b4c80a243b3ef15de15d9128b0bc51a6 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 2. testGroupByWithPartitioning-0-b4c80a243b3ef15de15d9128b0bc51a6 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 2. testGroupByWithPartitioning-0-cb5618b1e626f3a9d4a030b508b5d251 b/sql/hive/src/test/resources/golden/windowing.q -- 2. 
testGroupByWithPartitioning-0-cb5618b1e626f3a9d4a030b508b5d251 new file mode 100644 index 0000000000000..2c30e652aa26d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 2. testGroupByWithPartitioning-0-cb5618b1e626f3a9d4a030b508b5d251 @@ -0,0 +1,25 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1173.15 1 1 2 0 +Manufacturer#1 almond antique chartreuse lavender yellow 34 1753.76 2 2 34 32 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 1602.59 3 3 6 -28 +Manufacturer#1 almond aquamarine burnished black steel 28 1414.42 4 4 28 22 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 5 5 42 14 +Manufacturer#2 almond antique violet chocolate turquoise 14 1690.68 1 1 14 0 +Manufacturer#2 almond antique violet turquoise frosted 40 1800.7 2 2 40 26 +Manufacturer#2 almond aquamarine midnight light salmon 2 2031.98 3 3 2 -38 +Manufacturer#2 almond aquamarine rose maroon antique 25 1698.66 4 4 25 23 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 1701.6 5 5 18 -7 +Manufacturer#3 almond antique chartreuse khaki white 17 1671.68 1 1 17 0 +Manufacturer#3 almond antique forest lavender goldenrod 14 1190.27 2 2 14 -3 +Manufacturer#3 almond antique metallic orange dim 19 1410.39 3 3 19 5 +Manufacturer#3 almond antique misty red olive 1 1922.98 4 4 1 -18 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 5 5 45 44 +Manufacturer#4 almond antique gainsboro frosted violet 10 1620.67 1 1 10 0 +Manufacturer#4 almond antique violet mint lemon 39 1375.42 2 2 39 29 +Manufacturer#4 almond aquamarine floral ivory bisque 27 1206.26 3 3 27 -12 +Manufacturer#4 almond aquamarine yellow dodger mint 7 1844.92 4 4 7 -20 +Manufacturer#4 almond azure aquamarine papaya violet 12 1290.35 5 5 12 5 +Manufacturer#5 almond antique blue firebrick mint 31 1789.69 1 1 31 0 +Manufacturer#5 almond antique medium spring khaki 6 1611.66 2 2 6 -25 +Manufacturer#5 almond antique sky peru orange 2 1788.73 3 3 2 -4 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 1018.1 4 4 46 44 +Manufacturer#5 almond azure blanched chiffon midnight 23 1464.48 5 5 23 -23 diff --git a/sql/hive/src/test/resources/golden/windowing.q -- 26. testGroupByHavingWithSWQAndAlias-0-b996a664b06e5741c08079d5c38241bc b/sql/hive/src/test/resources/golden/windowing.q -- 26. testGroupByHavingWithSWQAndAlias-0-b996a664b06e5741c08079d5c38241bc new file mode 100644 index 0000000000000..2c30e652aa26d --- /dev/null +++ b/sql/hive/src/test/resources/golden/windowing.q -- 26. 
testGroupByHavingWithSWQAndAlias-0-b996a664b06e5741c08079d5c38241bc @@ -0,0 +1,25 @@ +Manufacturer#1 almond antique burnished rose metallic 2 1173.15 1 1 2 0 +Manufacturer#1 almond antique chartreuse lavender yellow 34 1753.76 2 2 34 32 +Manufacturer#1 almond antique salmon chartreuse burlywood 6 1602.59 3 3 6 -28 +Manufacturer#1 almond aquamarine burnished black steel 28 1414.42 4 4 28 22 +Manufacturer#1 almond aquamarine pink moccasin thistle 42 1632.66 5 5 42 14 +Manufacturer#2 almond antique violet chocolate turquoise 14 1690.68 1 1 14 0 +Manufacturer#2 almond antique violet turquoise frosted 40 1800.7 2 2 40 26 +Manufacturer#2 almond aquamarine midnight light salmon 2 2031.98 3 3 2 -38 +Manufacturer#2 almond aquamarine rose maroon antique 25 1698.66 4 4 25 23 +Manufacturer#2 almond aquamarine sandy cyan gainsboro 18 1701.6 5 5 18 -7 +Manufacturer#3 almond antique chartreuse khaki white 17 1671.68 1 1 17 0 +Manufacturer#3 almond antique forest lavender goldenrod 14 1190.27 2 2 14 -3 +Manufacturer#3 almond antique metallic orange dim 19 1410.39 3 3 19 5 +Manufacturer#3 almond antique misty red olive 1 1922.98 4 4 1 -18 +Manufacturer#3 almond antique olive coral navajo 45 1337.29 5 5 45 44 +Manufacturer#4 almond antique gainsboro frosted violet 10 1620.67 1 1 10 0 +Manufacturer#4 almond antique violet mint lemon 39 1375.42 2 2 39 29 +Manufacturer#4 almond aquamarine floral ivory bisque 27 1206.26 3 3 27 -12 +Manufacturer#4 almond aquamarine yellow dodger mint 7 1844.92 4 4 7 -20 +Manufacturer#4 almond azure aquamarine papaya violet 12 1290.35 5 5 12 5 +Manufacturer#5 almond antique blue firebrick mint 31 1789.69 1 1 31 0 +Manufacturer#5 almond antique medium spring khaki 6 1611.66 2 2 6 -25 +Manufacturer#5 almond antique sky peru orange 2 1788.73 3 3 2 -4 +Manufacturer#5 almond aquamarine dodger light gainsboro 46 1018.1 4 4 46 44 +Manufacturer#5 almond azure blanched chiffon midnight 23 1464.48 5 5 23 -23 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 027056d4b865f..fbffecd3b8ae0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.hive.execution import java.io._ +import org.apache.spark.sql.AnalysisException import org.scalatest.{BeforeAndAfterAll, FunSuite, GivenWhenThen} import org.apache.spark.Logging @@ -300,7 +301,8 @@ abstract class HiveComparisonTest val hiveQueries = queryList.map(new TestHive.HiveQLQueryExecution(_)) // Make sure we can at least parse everything before attempting hive execution. - hiveQueries.foreach(_.analyzed) + hiveQueries.foreach(_.logical) + val computedResults = (queryList.zipWithIndex, hiveQueries, hiveCacheFiles).zipped.map { case ((queryString, i), hiveQuery, cachedAnswerFile)=> try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala new file mode 100644 index 0000000000000..50c096804f6de --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.execution
+
+import java.util.{Locale, TimeZone}
+
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.util.Utils
+import org.scalatest.BeforeAndAfter
+
+/**
+ * The test suite for window functions. To actually compare results with Hive,
+ * every test should be created by `createQueryTest`. Because we are reusing tables
+ * for different tests and there are a few properties needed to let Hive generate golden
+ * files, every `createQueryTest` call should explicitly set `reset` to `false`.
+ */
+class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfter {
+  private val originalTimeZone = TimeZone.getDefault
+  private val originalLocale = Locale.getDefault
+  private val testTempDir = Utils.createTempDir()
+  import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+  override def beforeAll() {
+    TestHive.cacheTables = true
+    // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*)
+    TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles"))
+    // Add Locale setting
+    Locale.setDefault(Locale.US)
+
+    // Create the table used in windowing.q
+    sql("DROP TABLE IF EXISTS part")
+    sql(
+      """
+        |CREATE TABLE part(
+        |  p_partkey INT,
+        |  p_name STRING,
+        |  p_mfgr STRING,
+        |  p_brand STRING,
+        |  p_type STRING,
+        |  p_size INT,
+        |  p_container STRING,
+        |  p_retailprice DOUBLE,
+        |  p_comment STRING)
+      """.stripMargin)
+    val testData = TestHive.getHiveFile("data/files/part_tiny.txt").getCanonicalPath
+    sql(
+      s"""
+        |LOAD DATA LOCAL INPATH '$testData' overwrite into table part
+      """.stripMargin)
+    // The following settings are used for generating golden files with Hive.
+    // We have to use kryo to correctly let Hive serialize plans with window functions.
+    sql("set hive.plan.serialization.format=kryo")
+    // Explicitly set fs to local fs.
+    sql(s"set fs.default.name=file://$testTempDir/")
+    // Ask Hive to run jobs in-process as a single map and reduce task.
+    sql("set mapred.job.tracker=local")
+  }
+
+  override def afterAll() {
+    TestHive.cacheTables = false
+    TimeZone.setDefault(originalTimeZone)
+    Locale.setDefault(originalLocale)
+    TestHive.reset()
+  }
+
+  /////////////////////////////////////////////////////////////////////////////
+  // Tests from windowing.q
+  /////////////////////////////////////////////////////////////////////////////
+  createQueryTest("windowing.q -- 1.
testWindowing", + s""" + |select p_mfgr, p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |sum(p_retailprice) over + |(distribute by p_mfgr sort by p_name rows between unbounded preceding and current row) as s1 + |from part + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 2. testGroupByWithPartitioning", + s""" + |select p_mfgr, p_name, p_size, + |min(p_retailprice), + |rank() over(distribute by p_mfgr sort by p_name)as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |p_size, p_size - lag(p_size,1,p_size) over(distribute by p_mfgr sort by p_name) as deltaSz + |from part + |group by p_mfgr, p_name, p_size + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 19. testUDAFsWithGBY", + """ + | + |select p_mfgr,p_name, p_size, p_retailprice, + |sum(p_retailprice) over w1 as s, + |min(p_retailprice) as mi , + |max(p_retailprice) as ma , + |avg(p_retailprice) over w1 as ag + |from part + |group by p_mfgr,p_name, p_size, p_retailprice + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name rows between 2 preceding and 2 following); + | + """.stripMargin, reset = false) + + createQueryTest("windowing.q -- 26. testGroupByHavingWithSWQAndAlias", + """ + |select p_mfgr, p_name, p_size, min(p_retailprice) as mi, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |p_size, p_size - lag(p_size,1,p_size) over(distribute by p_mfgr sort by p_name) as deltaSz + |from part + |group by p_mfgr, p_name, p_size + |having p_size > 0 + """.stripMargin, reset = false) +}