Skip to content

Commit 74bdf54

Browse files
committed
Initial commit for correlation and covariance matrices
1 parent cf2e0ae commit 74bdf54

File tree

3 files changed

+188
-1
lines changed

3 files changed

+188
-1
lines changed

sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,39 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
5252
StatFunctions.calculateCov(df, Seq(col1, col2))
5353
}
5454

55+
/**
56+
* Calculate the sample covariance between columns of a DataFrame.
57+
*
58+
* @return a covariance matrix as a DataFrame
59+
*
60+
* {{{
61+
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
62+
* .withColumn("rand2", rand(seed=27))
63+
* val covmatrix = df.stat.cov()
64+
* covmatrix.show()
65+
* +---------+------------------+-------------------+-------------------+
66+
* |FieldName| id| rand1| rand2|
67+
* +---------+------------------+-------------------+-------------------+
68+
* | id| 9.166666666666666| 0.4131594565676311| 0.7012982830955725|
69+
* | rand1|0.4131594565676311|0.11982701890603772|0.06500805072758595|
70+
* | rand2|0.7012982830955725|0.06500805072758595|0.09383550706974164|
71+
* +---------+------------------+-------------------+-------------------+
72+
* }}}
73+
*
74+
* @since 1.6.0
75+
*/
76+
def cov(): DataFrame =
  // Single-argument overload: the covariance matrix is computed in StatFunctions.
  StatFunctions.calculateCov(df)
79+
5580
/**
5681
* Calculates the correlation of two columns of a DataFrame. Currently only supports the Pearson
5782
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
5883
* MLlib's Statistics.
5984
*
6085
* @param col1 the name of the column
6186
* @param col2 the name of the column to calculate the correlation against
87+
* @param method the name of the correlation method
6288
* @return The Pearson Correlation Coefficient as a Double.
6389
*
6490
* {{{
@@ -96,6 +122,63 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) {
96122
corr(col1, col2, "pearson")
97123
}
98124

125+
/**
126+
* Calculates the correlation of columns in the DataFrame. Currently only supports the Pearson
127+
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
128+
* MLlib's Statistics.
129+
*
130+
* @param method the name of the correlation method
131+
* @return The Pearson Correlation matrix as a DataFrame.
132+
*
133+
* {{{
134+
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
135+
* .withColumn("rand2", rand(seed=27))
136+
* val corrmatrix = df.stat.corr()
137+
* corrmatrix.show()
138+
* +---------+------------------+------------------+------------------+
139+
* |FieldName| id| rand1| rand2|
140+
* +---------+------------------+------------------+------------------+
141+
* | id| 1.0| 0.3942163209095|0.7561595709319909|
142+
* | rand1| 0.3942163209095| 1.0|0.6130644931298477|
143+
* | rand2|0.7561595709319909|0.6130644931298477| 1.0|
144+
* +---------+------------------+------------------+------------------+
145+
* }}}
146+
*
147+
* @since 1.6.0
148+
*/
149+
def corr(method: String): DataFrame = {
  // Pearson is the only method implemented; any other name is rejected before work starts.
  val isPearson = method == "pearson"
  require(isPearson, "Currently only the calculation of the Pearson Correlation " +
    "coefficient is supported.")
  StatFunctions.pearsonCorrelation(df)
}
154+
155+
/**
156+
* Calculates the correlation of columns in the DataFrame. Currently only supports the Pearson
157+
* Correlation Coefficient. For Spearman Correlation, consider using RDD methods found in
158+
* MLlib's Statistics.
159+
*
160+
* @return The Pearson Correlation matrix as a DataFrame.
161+
*
162+
* {{{
163+
* val df = sc.parallelize(0 until 10).toDF("id").withColumn("rand1", rand(seed=10))
164+
* .withColumn("rand2", rand(seed=27))
165+
* val corrmatrix = df.stat.corr()
166+
* corrmatrix.show()
167+
* +---------+------------------+------------------+------------------+
168+
* |FieldName| id| rand1| rand2|
169+
* +---------+------------------+------------------+------------------+
170+
* | id| 1.0| 0.3942163209095|0.7561595709319909|
171+
* | rand1| 0.3942163209095| 1.0|0.6130644931298477|
172+
* | rand2|0.7561595709319909|0.6130644931298477| 1.0|
173+
* +---------+------------------+------------------+------------------+
174+
* }}}
175+
*
176+
* @since 1.6.0
177+
*/
178+
def corr(): DataFrame =
  // Defaults to the Pearson method via the overload that takes the method name.
  corr("pearson")
181+
99182
/**
100183
* Computes a pair-wise frequency table of the given columns. Also known as a contingency table.
101184
* The number of distinct values for each column should be less than 1e4. At most 1e6 non-zero

sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@ package org.apache.spark.sql.execution.stat
1919

2020
import org.apache.spark.Logging
2121
import org.apache.spark.sql.{Row, Column, DataFrame}
22-
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast}
22+
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast, AttributeReference}
23+
import scala.collection.mutable.ArrayBuffer
2324
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
2425
import org.apache.spark.sql.functions._
2526
import org.apache.spark.sql.types._
@@ -33,6 +34,31 @@ private[sql] object StatFunctions extends Logging {
3334
counts.Ck / math.sqrt(counts.MkX * counts.MkY)
3435
}
3536

37+
/**
 * Calculate the Pearson Correlation matrix for the given DataFrame.
 *
 * The result has one row per column of `df`. Its first field, "FieldName", holds that column's
 * name; every following field holds the correlation between the row's column and the column
 * the field is named after.
 *
 * @param df the DataFrame whose columns are correlated pairwise; every column must be numeric
 * @return the symmetric correlation matrix as a DataFrame backed by a local relation
 * @throws IllegalArgumentException if any column of `df` has a non-numeric data type
 */
private[sql] def pearsonCorrelation(df: DataFrame): DataFrame = {
  val fieldNames = df.schema.fieldNames

  // Fail fast, before any pairwise computation starts. The diagonal is no longer computed,
  // so without this check a single non-numeric column would slip through unvalidated.
  df.schema.fields.foreach { field =>
    require(field.dataType.isInstanceOf[NumericType],
      s"Correlation matrix does not support non-numerical types for column ${field.name}.")
  }

  // Output schema: the row label followed by one nullable double field per input column.
  val dfStructAttrs = Seq(AttributeReference("FieldName", StringType, true)()) ++
    fieldNames.map(field => AttributeReference(field, DoubleType, true)())

  // One mutable row per input column, pre-labelled with the column name in slot 0.
  val rows = fieldNames.map { fname =>
    val countsRow = new GenericMutableRow(fieldNames.length + 1)
    countsRow.update(0, UTF8String.fromString(fname))
    countsRow
  }.toSeq

  for (i <- fieldNames.indices) {
    // Only the strict lower triangle is computed: the matrix is symmetric, so each value is
    // mirrored, and the diagonal is fixed below without a redundant self-correlation pass.
    for (j <- 0 until i) {
      val corr = pearsonCorrelation(df, Seq(fieldNames(i), fieldNames(j)))
      rows(i).setDouble(j + 1, corr)
      rows(j).setDouble(i + 1, corr)
    }
    // The correlation of a column with itself is reported as exactly 1.0.
    rows(i).setDouble(i + 1, 1.0)
  }

  new DataFrame(df.sqlContext, new LocalRelation(dfStructAttrs, rows))
}
61+
3662
/** Helper class to simplify tracking and merging counts. */
3763
private class CovarianceCounter extends Serializable {
3864
var xAvg = 0.0 // the mean of all examples seen so far in col1
@@ -102,6 +128,34 @@ private[sql] object StatFunctions extends Logging {
102128
counts.cov
103129
}
104130

131+
/**
 * Calculate the covariance matrix over all columns of a DataFrame.
 *
 * The result has one row per column of `df`. Its first field, "FieldName", holds that column's
 * name; every following field holds the covariance between the row's column and the column
 * the field is named after.
 *
 * @param df The DataFrame
 * @return the covariance matrix as a DataFrame backed by a local relation.
 */
private[sql] def calculateCov(df: DataFrame): DataFrame = {
  val columnNames = df.schema.fieldNames
  val rowWidth = columnNames.length + 1

  // Schema of the result: a string label field, then one nullable double per input column.
  val labelAttr = AttributeReference("FieldName", StringType, true)()
  val outputAttrs = Seq(labelAttr) ++
    columnNames.map(name => AttributeReference(name, DoubleType, true)())

  // One mutable row per input column, with the column name stored in slot 0.
  val matrixRows = columnNames.map { label =>
    val row = new GenericMutableRow(rowWidth)
    row.update(0, UTF8String.fromString(label))
    row
  }.toSeq

  // Walk the lower triangle, diagonal included, and mirror each value across the diagonal.
  for {
    i <- columnNames.indices
    j <- 0 to i
  } {
    val pairCov = calculateCov(df, Seq(columnNames(i), columnNames(j)))
    matrixRows(i).setDouble(j + 1, pairCov)
    matrixRows(j).setDouble(i + 1, pairCov)
  }

  new DataFrame(df.sqlContext, new LocalRelation(outputAttrs, matrixRows))
}
158+
105159
/** Generate a table of frequencies for the elements of two columns. */
106160
private[sql] def crossTabulate(df: DataFrame, col1: String, col2: String): DataFrame = {
107161
val tableName = s"${col1}_$col2"

sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,35 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
8585
assert(math.abs(corr3 - 0.95723391394758572) < 1e-12)
8686
}
8787

88+
test("pearson correlation matrix") {
  val df1 = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

  intercept[IllegalArgumentException] {
    df1.stat.corr() // doesn't accept non-numerical dataTypes
  }

  // b = 2 * a and c = -a, so every pair of columns is perfectly (anti-)correlated.
  val df2 = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c")
  val results = df2.stat.corr()

  // Expected matrix rows, keyed by the value of the "FieldName" column.
  val expected = Seq(
    "a" -> Seq(1.0, 1.0, -1.0),
    "b" -> Seq(1.0, 1.0, -1.0),
    "c" -> Seq(-1.0, -1.0, 1.0))

  expected.foreach { case (name, values) =>
    val row = results.where($"FieldName" === name).collect()(0)
    assert(row.getString(0) == name)
    values.zipWithIndex.foreach { case (value, idx) =>
      // Off-diagonal entries are computed in floating point, so compare with a tolerance
      // rather than exact equality.
      assert(math.abs(row.getDouble(idx + 1) - value) < 1e-12)
    }
  }
}
116+
88117
test("covariance") {
89118
val df = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")
90119

@@ -98,6 +127,27 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
98127
assert(math.abs(decimalRes) < 1e-12)
99128
}
100129

130+
test("covariance matrix") {
  val df1 = Seq.tabulate(10)(i => (i, 2.0 * i, toLetter(i))).toDF("singles", "doubles", "letters")

  intercept[IllegalArgumentException] {
    df1.stat.cov() // doesn't accept non-numerical dataTypes
  }

  val df2 = Seq.tabulate(10)(i => (i, 2.0 * i)).toDF("singles", "doubles")
  val results = df2.stat.cov()

  // Covariances are computed in floating point, so compare with a tolerance, not exact equality.
  def closeTo(actual: Double, expected: Double): Boolean = math.abs(actual - expected) < 1e-12

  val row1 = results.where($"FieldName" === "singles").collect()(0)
  assert(row1.getString(0) == "singles")
  assert(closeTo(row1.getDouble(1), 9.166666666666666))
  assert(closeTo(row1.getDouble(2), 18.333333333333332))

  val row2 = results.where($"FieldName" === "doubles").collect()(0)
  assert(row2.getString(0) == "doubles")
  // The matrix is symmetric: cov(doubles, singles) must match cov(singles, doubles).
  assert(closeTo(row2.getDouble(1), row1.getDouble(2)))
  assert(closeTo(row2.getDouble(2), 36.666666666666664))
}
150+
101151
test("crosstab") {
102152
val rng = new Random()
103153
val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10)))

0 commit comments

Comments
 (0)