@@ -21,11 +21,10 @@ import java.util.Locale
 
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Cast, EvalMode, Expression, GenericInternalRow, GetArrayItem, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Cast, ElementAt, EvalMode, GenericInternalRow}
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
-import org.apache.spark.sql.catalyst.util.{GenericArrayData, QuantileSummaries}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
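
Note on the import changes: the rewrite below drops the Catalyst-internal plumbing the old implementation needed (InternalRow, AttributeReference, GetArrayItem, Literal, GenericArrayData) in favor of the public org.apache.spark.sql.functions API, keeping only ElementAt for a by-key map lookup. One indexing detail worth keeping in mind when reading the new code: functions.get is 0-based on arrays, while element_at is 1-based on arrays and key-based on maps. A standalone sketch for illustration only (the DataFrame and names here are made up, not part of the PR):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object LookupSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val df = Seq((Array(10.0, 20.0), Map("count" -> "2"))).toDF("arr", "m")
    df.select(
      get($"arr", lit(0)).as("get_0"),           // 0-based array index -> 10.0
      element_at($"arr", 1).as("element_at_1"),  // 1-based array index -> 10.0
      element_at($"m", "count").as("by_key")     // key lookup on a map -> "2"
    ).show()

    spark.stop()
  }
}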
@@ -233,9 +232,11 @@ object StatFunctions extends Logging {
 
   /** Calculate selected summary statistics for a dataset */
   def summary(ds: Dataset[_], statistics: Seq[String]): DataFrame = {
-
-    val defaultStatistics = Seq("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
-    val selectedStatistics = if (statistics.nonEmpty) statistics else defaultStatistics
+    val selectedStatistics = if (statistics.nonEmpty) {
+      statistics.toArray
+    } else {
+      Array("count", "mean", "stddev", "min", "25%", "50%", "75%", "max")
+    }
 
     val percentiles = selectedStatistics.filter(a => a.endsWith("%")).map { p =>
       try {
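
The percentile parsing between these two hunks is unchanged, so the diff viewer collapses it; it turns a token such as "25%" into the fraction 0.25 that the require below validates. A rough standalone equivalent of that conversion, assumed from context rather than copied from the file:

// Assumed sketch: convert "25%"-style tokens into fractions in [0, 1].
def parsePercentile(p: String): Double =
  try {
    p.stripSuffix("%").toDouble / 100.0
  } catch {
    case e: NumberFormatException =>
      throw new IllegalArgumentException(s"Unable to parse $p as a percentile", e)
  }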
@@ -247,71 +248,66 @@
     }
     require(percentiles.forall(p => p >= 0 && p <= 1), "Percentiles must be in the range [0, 1]")
 
-    def castAsDoubleIfNecessary(e: Expression): Expression = if (e.dataType == StringType) {
-      Cast(e, DoubleType, evalMode = EvalMode.TRY)
-    } else {
-      e
-    }
-    var percentileIndex = 0
-    val statisticFns = selectedStatistics.map { stats =>
-      if (stats.endsWith("%")) {
-        val index = percentileIndex
-        percentileIndex += 1
-        (child: Expression) =>
-          GetArrayItem(
-            new ApproximatePercentile(castAsDoubleIfNecessary(child),
-              Literal(new GenericArrayData(percentiles), ArrayType(DoubleType, false)))
-              .toAggregateExpression(),
-            Literal(index))
-      } else {
-        stats.toLowerCase(Locale.ROOT) match {
-          case "count" => (child: Expression) => Count(child).toAggregateExpression()
-          case "count_distinct" => (child: Expression) =>
-            Count(child).toAggregateExpression(isDistinct = true)
-          case "approx_count_distinct" => (child: Expression) =>
-            HyperLogLogPlusPlus(child).toAggregateExpression()
-          case "mean" => (child: Expression) =>
-            Average(castAsDoubleIfNecessary(child)).toAggregateExpression()
-          case "stddev" => (child: Expression) =>
-            StddevSamp(castAsDoubleIfNecessary(child)).toAggregateExpression()
-          case "min" => (child: Expression) => Min(child).toAggregateExpression()
-          case "max" => (child: Expression) => Max(child).toAggregateExpression()
-          case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats)
-        }
-      }
-    }
-
-    val selectedCols = ds.logicalPlan.output
-      .filter(a => a.dataType.isInstanceOf[NumericType] || a.dataType.isInstanceOf[StringType])
-
-    val aggExprs = statisticFns.flatMap { func =>
-      selectedCols.map(c => Column(Cast(func(c), StringType)).as(c.name))
-    }
-
-    // If there is no selected columns, we don't need to run this aggregate, so make it a lazy val.
-    lazy val aggResult = ds.select(aggExprs: _*).queryExecution.toRdd.collect().head
-
-    // We will have one row for each selected statistic in the result.
-    val result = Array.fill[InternalRow](selectedStatistics.length) {
-      // each row has the statistic name, and statistic values of each selected column.
-      new GenericInternalRow(selectedCols.length + 1)
-    }
-
-    var rowIndex = 0
-    while (rowIndex < result.length) {
-      val statsName = selectedStatistics(rowIndex)
-      result(rowIndex).update(0, UTF8String.fromString(statsName))
-      for (colIndex <- selectedCols.indices) {
-        val statsValue = aggResult.getUTF8String(rowIndex * selectedCols.length + colIndex)
-        result(rowIndex).update(colIndex + 1, statsValue)
-      }
-      rowIndex += 1
-    }
-
-    // All columns are string type
-    val output = AttributeReference("summary", StringType)() +:
-      selectedCols.map(c => AttributeReference(c.name, StringType)())
-
-    Dataset.ofRows(ds.sparkSession, LocalRelation(output, result))
+    var mapColumns = Seq.empty[Column]
+    var columnNames = Seq.empty[String]
+
+    ds.schema.fields.foreach { field =>
+      if (field.dataType.isInstanceOf[NumericType] || field.dataType.isInstanceOf[StringType]) {
+        val column = col(field.name)
+        var casted = column
+        if (field.dataType.isInstanceOf[StringType]) {
+          casted = new Column(Cast(column.expr, DoubleType, evalMode = EvalMode.TRY))
+        }
+
+        val percentilesCol = if (percentiles.nonEmpty) {
+          percentile_approx(casted, lit(percentiles),
+            lit(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
+        } else null
+
+        var aggColumns = Seq.empty[Column]
+        var percentileIndex = 0
+        selectedStatistics.foreach { stats =>
+          aggColumns :+= lit(stats)
+
+          stats.toLowerCase(Locale.ROOT) match {
+            case "count" => aggColumns :+= count(column)
+
+            case "count_distinct" => aggColumns :+= count_distinct(column)
+
+            case "approx_count_distinct" => aggColumns :+= approx_count_distinct(column)
+
+            case "mean" => aggColumns :+= avg(casted)
+
+            case "stddev" => aggColumns :+= stddev(casted)
+
+            case "min" => aggColumns :+= min(column)
+
+            case "max" => aggColumns :+= max(column)
+
+            case percentile if percentile.endsWith("%") =>
+              aggColumns :+= get(percentilesCol, lit(percentileIndex))
+              percentileIndex += 1
+
+            case _ => throw QueryExecutionErrors.statisticNotRecognizedError(stats)
+          }
+        }
+
+        // map { "count" -> "1024", "min" -> "1.0", ... }
+        mapColumns :+= map(aggColumns.map(_.cast(StringType)): _*).as(field.name)
+        columnNames :+= field.name
+      }
+    }
+
+    if (mapColumns.isEmpty) {
+      ds.sparkSession.createDataFrame(selectedStatistics.map(Tuple1.apply))
+        .withColumnRenamed("_1", "summary")
+    } else {
+      val valueColumns = columnNames.map { columnName =>
+        new Column(ElementAt(col(columnName).expr, col("summary").expr)).as(columnName)
+      }
+      ds.select(mapColumns: _*)
+        .withColumn("summary", explode(lit(selectedStatistics)))
+        .select(Array(col("summary")) ++ valueColumns: _*)
+    }
   }
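
Net effect of the new code path: one distributed aggregation builds, per selected column, a map from statistic name to its stringified value; explode(lit(selectedStatistics)) then fans the single aggregated row out into one row per statistic, and ElementAt looks each value up by the summary key. The old path collected internal rows to the driver and hand-assembled a LocalRelation; the new one stays in the DataFrame API end to end. A minimal standalone sketch of the same map-then-explode pivot, with made-up data and column names rather than the PR's code:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object SummaryPivotSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    import spark.implicits._

    val ds = Seq(1.0, 2.0, 3.0).toDF("x")
    val stats = Array("count", "min", "max")

    // One aggregated row: a map per input column, keyed by statistic name.
    val mapped = ds.select(
      map(
        lit("count"), count($"x").cast("string"),
        lit("min"), min($"x").cast("string"),
        lit("max"), max($"x").cast("string")).as("x"))

    // Fan the row out into one row per statistic, then look each value up by key.
    mapped
      .withColumn("summary", explode(lit(stats)))
      .select($"summary", element_at($"x", $"summary").as("x"))
      .show()
    // summary -> x: count -> 3, min -> 1.0, max -> 3.0

    spark.stop()
  }
}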