diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index c9d3b99990830..484be76b99156 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -30,7 +30,6 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.collection.Utils
 
 object StatFunctions extends Logging {
 
@@ -188,47 +187,14 @@ object StatFunctions extends Logging {
 
   /** Generate a table of frequencies for the elements of two columns. */
   def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
-    val tableName = s"${col1}_$col2"
-    val counts = df.groupBy(col1, col2).agg(count("*")).take(1e6.toInt)
-    if (counts.length == 1e6.toInt) {
-      logWarning("The maximum limit of 1e6 pairs have been collected, which may not be all of " +
-        "the pairs. Please try reducing the amount of distinct items in your columns.")
-    }
-    def cleanElement(element: Any): String = {
-      if (element == null) "null" else element.toString
-    }
-    // get the distinct sorted values of column 2, so that we can make them the column names
-    val distinctCol2: Map[Any, Int] =
-      Utils.toMapWithIndex(counts.map(e => cleanElement(e.get(1))).distinct.sorted)
-    val columnSize = distinctCol2.size
-    require(columnSize < 1e4, s"The number of distinct values for $col2, can't " +
-      s"exceed 1e4. Currently $columnSize")
-    val table = counts.groupBy(_.get(0)).map { case (col1Item, rows) =>
-      val countsRow = new GenericInternalRow(columnSize + 1)
-      rows.foreach { (row: Row) =>
-        // row.get(0) is column 1
-        // row.get(1) is column 2
-        // row.get(2) is the frequency
-        val columnIndex = distinctCol2(cleanElement(row.get(1)))
-        countsRow.setLong(columnIndex + 1, row.getLong(2))
-      }
-      // the value of col1 is the first value, the rest are the counts
-      countsRow.update(0, UTF8String.fromString(cleanElement(col1Item)))
-      countsRow
-    }.toSeq
-    // Back ticks can't exist in DataFrame column names, therefore drop them. To be able to accept
-    // special keywords and `.`, wrap the column names in ``.
-    def cleanColumnName(name: String): String = {
-      name.replace("`", "")
-    }
-    // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in
-    // SPARK-8681. We need to explicitly sort by the column index and assign the column names.
-    val headerNames = distinctCol2.toSeq.sortBy(_._2).map { r =>
-      StructField(cleanColumnName(r._1.toString), LongType)
-    }
-    val schema = StructType(StructField(tableName, StringType) +: headerNames)
-
-    Dataset.ofRows(df.sparkSession, LocalRelation(schema.toAttributes, table)).na.fill(0.0)
+    df.groupBy(
+      when(isnull(col(col1)), "null")
+        .otherwise(col(col1).cast("string"))
+        .as(s"${col1}_$col2")
+    ).pivot(
+      when(isnull(col(col2)), "null")
+        .otherwise(regexp_replace(col(col2).cast("string"), "`", ""))
+    ).count().na.fill(0L)
   }
 
   /** Calculate selected summary statistics for a dataset */
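
For reviewers: below is a minimal, self-contained sketch (not part of the patch) of what the rewritten crossTabulate computes, driven through the public df.stat.crosstab entry point and through the equivalent DataFrame operations the new code path uses. The session setup, column names, and sample data are illustrative assumptions, not taken from the patch.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object CrosstabSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("crosstab-sketch").getOrCreate()
    import spark.implicits._

    // Sample data; one null key to exercise the isnull -> "null" rendering.
    val df = Seq[(String, Int)](("a", 1), ("a", 2), ("b", 1), (null, 2)).toDF("k", "v")

    // Public API that delegates to StatFunctions.crossTabulate.
    df.stat.crosstab("k", "v").show()

    // The same result expressed directly with the operations the new
    // implementation composes: group on col1 (nulls rendered as "null"),
    // pivot on col2 (backticks stripped from the rendered values since they
    // cannot appear in column names), count the pairs, and fill absent
    // (col1, col2) combinations with 0.
    df.groupBy(
        when(isnull(col("k")), "null").otherwise(col("k").cast("string")).as("k_v"))
      .pivot(
        when(isnull(col("v")), "null")
          .otherwise(regexp_replace(col("v").cast("string"), "`", "")))
      .count()
      .na.fill(0L)
      .show()
    // Expected cells for the sample data (row order may vary):
    //   k_v=a    -> 1:1, 2:1
    //   k_v=b    -> 1:1, 2:0
    //   k_v=null -> 1:0, 2:1

    spark.stop()
  }
}

A note on the design: the old code collected up to 1e6 grouped rows to the driver and assembled a LocalRelation by hand, which is why it needed the explicit 1e4 distinct-value guard and the SPARK-8681 column-ordering workaround; expressing the crosstab as groupBy/pivot/count keeps the computation in the engine and makes that bookkeeping unnecessary.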