Skip to content

Commit b098576

Browse files
zhengruifeng authored and holdenk committed
[SPARK-14352][SQL] approxQuantile should support multi columns
## What changes were proposed in this pull request?

1. Add multi-column support based on the current private API.
2. Add multi-column support to PySpark.

## How was this patch tested?

Unit tests.

Author: Zheng RuiFeng <ruifengz@foxmail.com>
Author: Ruifeng Zheng <ruifengz@foxmail.com>

Closes #12135 from zhengruifeng/quantile4multicols.
1 parent c5fcb7f commit b098576

File tree

4 files changed

+101
-11
lines changed

4 files changed

+101
-11
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
#
1717

1818
import sys
19-
import warnings
2019
import random
2120

2221
if sys.version >= '3':
@@ -1348,7 +1347,7 @@ def replace(self, to_replace, value, subset=None):
13481347
@since(2.0)
13491348
def approxQuantile(self, col, probabilities, relativeError):
13501349
"""
1351-
Calculates the approximate quantiles of a numerical column of a
1350+
Calculates the approximate quantiles of numerical columns of a
13521351
DataFrame.
13531352
13541353
The result of this algorithm has the following deterministic bound:
@@ -1365,18 +1364,41 @@ def approxQuantile(self, col, probabilities, relativeError):
13651364
Space-efficient Online Computation of Quantile Summaries]]
13661365
by Greenwald and Khanna.
13671366
1368-
:param col: the name of the numerical column
1367+
Note that rows containing any null values will be removed before calculation.
1368+
1369+
:param col: str, list.
1370+
Can be a single column name, or a list of names for multiple columns.
13691371
:param probabilities: a list of quantile probabilities
13701372
Each number must belong to [0, 1].
13711373
For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
13721374
:param relativeError: The relative target precision to achieve
13731375
(>= 0). If set to zero, the exact quantiles are computed, which
13741376
could be very expensive. Note that values greater than 1 are
13751377
accepted but give the same result as 1.
1376-
:return: the approximate quantiles at the given probabilities
1378+
:return: the approximate quantiles at the given probabilities. If
1379+
the input `col` is a string, the output is a list of floats. If the
1380+
input `col` is a list or tuple of strings, the output is also a
1381+
list, but each element in it is a list of floats, i.e., the output
1382+
is a list of list of floats.
1383+
1384+
.. versionchanged:: 2.2
1385+
Added support for multiple columns.
13771386
"""
1378-
if not isinstance(col, str):
1379-
raise ValueError("col should be a string.")
1387+
1388+
if not isinstance(col, (str, list, tuple)):
1389+
raise ValueError("col should be a string, list or tuple, but got %r" % type(col))
1390+
1391+
isStr = isinstance(col, str)
1392+
1393+
if isinstance(col, tuple):
1394+
col = list(col)
1395+
elif isinstance(col, str):
1396+
col = [col]
1397+
1398+
for c in col:
1399+
if not isinstance(c, str):
1400+
raise ValueError("columns should be strings, but got %r" % type(c))
1401+
col = _to_list(self._sc, col)
13801402

13811403
if not isinstance(probabilities, (list, tuple)):
13821404
raise ValueError("probabilities should be a list or tuple")
@@ -1392,7 +1414,8 @@ def approxQuantile(self, col, probabilities, relativeError):
13921414
relativeError = float(relativeError)
13931415

13941416
jaq = self._jdf.stat().approxQuantile(col, probabilities, relativeError)
1395-
return list(jaq)
1417+
jaq_list = [list(j) for j in jaq]
1418+
return jaq_list[0] if isStr else jaq_list
13961419

13971420
@since(1.4)
13981421
def corr(self, col1, col2, method=None):

python/pyspark/sql/tests.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -895,11 +895,32 @@ def test_first_last_ignorenulls(self):
895895
self.assertEqual([Row(a=None, b=1, c=None, d=98)], df3.collect())
896896

897897
def test_approxQuantile(self):
898-
df = self.sc.parallelize([Row(a=i) for i in range(10)]).toDF()
898+
df = self.sc.parallelize([Row(a=i, b=i+10) for i in range(10)]).toDF()
899899
aq = df.stat.approxQuantile("a", [0.1, 0.5, 0.9], 0.1)
900900
self.assertTrue(isinstance(aq, list))
901901
self.assertEqual(len(aq), 3)
902902
self.assertTrue(all(isinstance(q, float) for q in aq))
903+
aqs = df.stat.approxQuantile(["a", "b"], [0.1, 0.5, 0.9], 0.1)
904+
self.assertTrue(isinstance(aqs, list))
905+
self.assertEqual(len(aqs), 2)
906+
self.assertTrue(isinstance(aqs[0], list))
907+
self.assertEqual(len(aqs[0]), 3)
908+
self.assertTrue(all(isinstance(q, float) for q in aqs[0]))
909+
self.assertTrue(isinstance(aqs[1], list))
910+
self.assertEqual(len(aqs[1]), 3)
911+
self.assertTrue(all(isinstance(q, float) for q in aqs[1]))
912+
aqt = df.stat.approxQuantile(("a", "b"), [0.1, 0.5, 0.9], 0.1)
913+
self.assertTrue(isinstance(aqt, list))
914+
self.assertEqual(len(aqt), 2)
915+
self.assertTrue(isinstance(aqt[0], list))
916+
self.assertEqual(len(aqt[0]), 3)
917+
self.assertTrue(all(isinstance(q, float) for q in aqt[0]))
918+
self.assertTrue(isinstance(aqt[1], list))
919+
self.assertEqual(len(aqt[1]), 3)
920+
self.assertTrue(all(isinstance(q, float) for q in aqt[1]))
921+
self.assertRaises(ValueError, lambda: df.stat.approxQuantile(123, [0.1, 0.9], 0.1))
922+
self.assertRaises(ValueError, lambda: df.stat.approxQuantile(("a", 123), [0.1, 0.9], 0.1))
923+
self.assertRaises(ValueError, lambda: df.stat.approxQuantile(["a", 123], [0.1, 0.9], 0.1))
903924

904925
def test_corr(self):
905926
import math

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
2424
import org.apache.spark.annotation.InterfaceStability
2525
import org.apache.spark.sql.catalyst.InternalRow
2626
import org.apache.spark.sql.execution.stat._
27+
import org.apache.spark.sql.functions.col
2728
import org.apache.spark.sql.types._
2829
import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
2930

@@ -74,14 +75,44 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
7475
Seq(col), probabilities, relativeError).head.toArray
7576
}
7677

78+
/**
79+
* Calculates the approximate quantiles of numerical columns of a DataFrame.
80+
* @see [[DataFrameStatFunctions.approxQuantile(col:Str* approxQuantile]] for
81+
* detailed description.
82+
*
83+
* Note that rows containing any null or NaN values will be removed before
84+
* calculation.
85+
* @param cols the names of the numerical columns
86+
* @param probabilities a list of quantile probabilities
87+
* Each number must belong to [0, 1].
88+
* For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
89+
* @param relativeError The relative target precision to achieve (>= 0).
90+
* If set to zero, the exact quantiles are computed, which could be very expensive.
91+
* Note that values greater than 1 are accepted but give the same result as 1.
92+
* @return the approximate quantiles at the given probabilities of each column
93+
*
94+
* @note Rows containing any NaN values will be removed before calculation
95+
*
96+
* @since 2.2.0
97+
*/
98+
def approxQuantile(
99+
cols: Array[String],
100+
probabilities: Array[Double],
101+
relativeError: Double): Array[Array[Double]] = {
102+
StatFunctions.multipleApproxQuantiles(df.select(cols.map(col): _*).na.drop(), cols,
103+
probabilities, relativeError).map(_.toArray).toArray
104+
}
105+
106+
77107
/**
78108
* Python-friendly version of [[approxQuantile()]]
79109
*/
80110
private[spark] def approxQuantile(
81-
col: String,
111+
cols: List[String],
82112
probabilities: List[Double],
83-
relativeError: Double): java.util.List[Double] = {
84-
approxQuantile(col, probabilities.toArray, relativeError).toList.asJava
113+
relativeError: Double): java.util.List[java.util.List[Double]] = {
114+
approxQuantile(cols.toArray, probabilities.toArray, relativeError)
115+
.map(_.toList.asJava).toList.asJava
85116
}
86117

87118
/**

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,11 +149,26 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
149149
assert(math.abs(s2 - q2 * n) < error_single)
150150
assert(math.abs(d1 - 2 * q1 * n) < error_double)
151151
assert(math.abs(d2 - 2 * q2 * n) < error_double)
152+
153+
// Multiple columns
154+
val Array(Array(ms1, ms2), Array(md1, md2)) =
155+
df.stat.approxQuantile(Array("singles", "doubles"), Array(q1, q2), epsilon)
156+
157+
assert(math.abs(ms1 - q1 * n) < error_single)
158+
assert(math.abs(ms2 - q2 * n) < error_single)
159+
assert(math.abs(md1 - 2 * q1 * n) < error_double)
160+
assert(math.abs(md2 - 2 * q2 * n) < error_double)
152161
}
153162
// test approxQuantile on NaN values
154163
val dfNaN = Seq(Double.NaN, 1.0, Double.NaN, Double.NaN).toDF("input")
155164
val resNaN = dfNaN.stat.approxQuantile("input", Array(q1, q2), epsilons.head)
156165
assert(resNaN.count(_.isNaN) === 0)
166+
// test approxQuantile on multi-column NaN values
167+
val dfNaN2 = Seq((Double.NaN, 1.0), (1.0, 1.0), (-1.0, Double.NaN), (Double.NaN, Double.NaN))
168+
.toDF("input1", "input2")
169+
val resNaN2 = dfNaN2.stat.approxQuantile(Array("input1", "input2"),
170+
Array(q1, q2), epsilons.head)
171+
assert(resNaN2.flatten.count(_.isNaN) === 0)
157172
}
158173

159174
test("crosstab") {

0 commit comments

Comments
 (0)