@@ -31,8 +31,11 @@ case class Average(child: Expression) extends AlgebraicAggregate {
override def dataType: DataType = resultType

// Expected input data type.
// TODO: Once we remove the old code path, we can use our analyzer to cast NullType
// to the default data type of the NumericType.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
// new version at planning time (after the analysis phase). For now, NullType is added here
// so that cases like `select avg(null)` can be resolved.
// Once we remove the old aggregate functions, we can use our analyzer to cast NullType to the
// default data type of the NumericType, and NullType will no longer be needed here.
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(NumericType, NullType))

private val resultType = child.dataType match {
@@ -256,12 +259,19 @@ case class Sum(child: Expression) extends AlgebraicAggregate {
override def dataType: DataType = resultType

// Expected input data type.
// TODO: Right now, we replace old aggregate functions (based on AggregateExpression1) with the
// new version at planning time (after the analysis phase). For now, NullType is added here
// so that cases like `select sum(null)` can be resolved.
// Once we remove the old aggregate functions, we can use our analyzer to cast NullType to the
// default data type of the NumericType, and NullType will no longer be needed here.
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(LongType, DoubleType, DecimalType, NullType))

private val resultType = child.dataType match {
case DecimalType.Fixed(precision, scale) =>
DecimalType.bounded(precision + 10, scale)
// TODO: Remove this line once we remove the NullType from inputTypes.
case NullType => IntegerType
case _ => child.dataType
}

@@ -191,8 +191,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// aggregate function to the corresponding attribute of the function.
val aggregateFunctionMap = aggregateExpressions.map { agg =>
val aggregateFunction = agg.aggregateFunction
val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
(aggregateFunction, agg.isDistinct) ->
Alias(aggregateFunction, aggregateFunction.toString)().toAttribute
(aggregateFunction -> attribute)
}.toMap

val (functionsWithDistinct, functionsWithoutDistinct) =
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution

import java.io.{DataInputStream, DataOutputStream, OutputStream, InputStream}
import java.io._
import java.nio.ByteBuffer

import scala.reflect.ClassTag
@@ -58,11 +58,26 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst
*/
override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream {
private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096)
// When `out` is backed by ChainedBufferOutputStream, calling dOut.writeInt throws an
// UnsupportedOperationException because it internally calls ChainedBufferOutputStream's
// write(b: Int), which is not supported.
// To work around this issue, we create a small array for storing the int value.
// To reproduce the problem, use dOut.writeInt(row.getSizeInBytes) and
// run SparkSqlSerializer2SortMergeShuffleSuite.
private[this] var intBuffer: Array[Byte] = new Array[Byte](4)
private[this] val dOut: DataOutputStream = new DataOutputStream(out)

override def writeValue[T: ClassTag](value: T): SerializationStream = {
val row = value.asInstanceOf[UnsafeRow]
dOut.writeInt(row.getSizeInBytes)
val size = row.getSizeInBytes
// Write the row size as a big-endian int, producing the same bytes as
// dOut.writeInt(row.getSizeInBytes) (this mirrors DataOutputStream's writeInt).
intBuffer(0) = ((size >>> 24) & 0xFF).toByte
intBuffer(1) = ((size >>> 16) & 0xFF).toByte
intBuffer(2) = ((size >>> 8) & 0xFF).toByte
intBuffer(3) = ((size >>> 0) & 0xFF).toByte
dOut.write(intBuffer, 0, 4)
Contributor Author: @JoshRosen I made this change to work around ChainedBufferOutputStream's unsupported write(b: Int).

Contributor Author: Also, we need to double-check whether we need to wrap the input stream with a buffered input stream when we read the data back.

Contributor: cc @JoshRosen do you think this is fine? It seems inefficient to me, but maybe there is no better way.

row.writeToStream(out, writeBuffer)
this
}
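As a quick sanity check on the manual encoding above (a standalone sketch, not part of the patch): DataOutputStream.writeInt emits the int most significant byte first, which is exactly the layout the shifted bytes produce.

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// Standalone check: the four shifted bytes match DataOutputStream.writeInt's output.
val size = 0x12345678
val manual = Array(
  ((size >>> 24) & 0xFF).toByte,
  ((size >>> 16) & 0xFF).toByte,
  ((size >>> 8) & 0xFF).toByte,
  ((size >>> 0) & 0xFF).toByte)

val bos = new ByteArrayOutputStream()
new DataOutputStream(bos).writeInt(size)
assert(java.util.Arrays.equals(manual, bos.toByteArray))
```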
@@ -90,6 +105,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst

override def close(): Unit = {
writeBuffer = null
intBuffer = null
dOut.writeInt(EOF)
dOut.close()
}

This file was deleted.

@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.aggregate

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan, UnaryNode}
import org.apache.spark.sql.types.StructType

case class SortBasedAggregate(
    requiredChildDistributionExpressions: Option[Seq[Expression]],
    groupingExpressions: Seq[NamedExpression],
    nonCompleteAggregateExpressions: Seq[AggregateExpression2],
    nonCompleteAggregateAttributes: Seq[Attribute],
    completeAggregateExpressions: Seq[AggregateExpression2],
    completeAggregateAttributes: Seq[Attribute],
    initialInputBufferOffset: Int,
    resultExpressions: Seq[NamedExpression],
    child: SparkPlan)
  extends UnaryNode {

  override def outputsUnsafeRows: Boolean = false

  override def canProcessUnsafeRows: Boolean = false

  override def canProcessSafeRows: Boolean = true

  override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)

  override def requiredChildDistribution: List[Distribution] = {
    requiredChildDistributionExpressions match {
      case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
      case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
      case None => UnspecifiedDistribution :: Nil
    }
  }

  override def requiredChildOrdering: Seq[Seq[SortOrder]] = {
    groupingExpressions.map(SortOrder(_, Ascending)) :: Nil
  }

  override def outputOrdering: Seq[SortOrder] = {
    groupingExpressions.map(SortOrder(_, Ascending))
  }

  protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
    child.execute().mapPartitions { iter =>
      // Because the constructor of an aggregation iterator will read at least the first row,
      // we need to get the value of iter.hasNext first.
      val hasInput = iter.hasNext
      if (!hasInput && groupingExpressions.nonEmpty) {
        // This is a grouped aggregate and the input iterator is empty,
        // so return an empty iterator.
        Iterator[InternalRow]()
      } else {
        val outputIter = SortBasedAggregationIterator.createFromInputIterator(
          groupingExpressions,
          nonCompleteAggregateExpressions,
          nonCompleteAggregateAttributes,
          completeAggregateExpressions,
          completeAggregateAttributes,
          initialInputBufferOffset,
          resultExpressions,
          newMutableProjection _,
          newProjection _,
          child.output,
          iter,
          outputsUnsafeRows)
        if (!hasInput && groupingExpressions.isEmpty) {
          // There is no input and there are no grouping expressions,
          // so we need to output a single row.
          Iterator[InternalRow](outputIter.outputForEmptyGroupingKeyWithoutInput())
        } else {
          outputIter
        }
      }
    }
  }

  override def simpleString: String = {
    val allAggregateExpressions = nonCompleteAggregateExpressions ++ completeAggregateExpressions
    s"""SortBasedAggregate ${groupingExpressions} ${allAggregateExpressions}"""
  }
}
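To make the two empty-input branches in doExecute concrete, a hedged usage sketch (assuming a SQLContext bound to `sqlContext` and an empty registered table `t`; both names are illustrative, and whether the sort-based path is chosen depends on the planner):

```scala
// Global aggregate, no grouping expressions: even with no input rows, a single
// result row is still produced via outputForEmptyGroupingKeyWithoutInput(),
// e.g. count(*) returns 0.
sqlContext.sql("SELECT count(*) FROM t").collect()                    // Array([0])

// Grouped aggregate on empty input: the first branch returns an empty iterator,
// so no result rows are produced.
sqlContext.sql("SELECT key, count(*) FROM t GROUP BY key").collect()  // Array()
```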