Skip to content

Commit 3f63f08

Browse files
RoyGaomengxr
authored and committed
[SPARK-7013][ML][TEST] Add unit test for spark.ml StandardScaler
I have added a unit test for spark.ml's StandardScaler by comparing its output with R's. Please review. Thanks. Author: RoyGaoVLIS <roygao@zju.edu.cn> Closes #6665 from RoyGao/7013. (cherry picked from commit 67a5132) Signed-off-by: Xiangrui Meng <meng@databricks.com>
1 parent 737f071 commit 3f63f08

File tree

1 file changed

+108
-0
lines changed

1 file changed

+108
-0
lines changed
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.ml.feature
19+
20+
21+
import org.apache.spark.SparkFunSuite
22+
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
23+
import org.apache.spark.mllib.util.MLlibTestSparkContext
24+
import org.apache.spark.mllib.util.TestingUtils._
25+
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
26+
27+
/**
 * Unit tests for [[StandardScaler]].
 *
 * Expected results (`resWithStd`, `resWithMean`, `resWithBoth`) were precomputed
 * externally (per the commit message, by comparing with R's `scale()` output) for
 * each combination of the `withMean` / `withStd` parameters.
 */
class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext {

  // Raw input vectors and the expected standardized outputs for each setting.
  @transient var data: Array[Vector] = _
  @transient var resWithStd: Array[Vector] = _
  @transient var resWithMean: Array[Vector] = _
  @transient var resWithBoth: Array[Vector] = _

  override def beforeAll(): Unit = {
    super.beforeAll()

    data = Array(
      Vectors.dense(-2.0, 2.3, 0.0),
      Vectors.dense(0.0, -5.1, 1.0),
      Vectors.dense(1.7, -0.6, 3.3)
    )
    // Mean-centered only (withMean = true, withStd = false).
    resWithMean = Array(
      Vectors.dense(-1.9, 3.433333333333, -1.433333333333),
      Vectors.dense(0.1, -3.966666666667, -0.433333333333),
      Vectors.dense(1.8, 0.533333333333, 1.866666666667)
    )
    // Scaled to unit standard deviation only (the default: withStd = true).
    resWithStd = Array(
      Vectors.dense(-1.079898494312, 0.616834091415, 0.0),
      Vectors.dense(0.0, -1.367762550529, 0.590968109266),
      Vectors.dense(0.917913720165, -0.160913241239, 1.950194760579)
    )
    // Both centered and scaled (withMean = true, withStd = true).
    resWithBoth = Array(
      Vectors.dense(-1.0259035695965, 0.920781324866, -0.8470542899497),
      Vectors.dense(0.0539949247156, -1.063815317078, -0.256086180682),
      Vectors.dense(0.9719086448809, 0.143033992212, 1.103140470631)
    )
  }

  /**
   * Asserts that every scaled vector in the "standardized_features" column matches
   * the corresponding vector in the "expected" column within an absolute tolerance
   * of 1e-5.
   *
   * Named `assertValues` (not `assertResult`) to avoid shadowing ScalaTest's
   * built-in `assertResult`.
   */
  private def assertValues(dataframe: DataFrame): Unit = {
    dataframe.select("standardized_features", "expected").collect().foreach {
      case Row(vector1: Vector, vector2: Vector) =>
        assert(vector1 ~== vector2 absTol 1E-5,
          "The vector value is not correct after standardization.")
    }
  }

  test("Standardization with default parameter") {
    val df0 = sqlContext.createDataFrame(data.zip(resWithStd)).toDF("features", "expected")

    // Defaults are withStd = true, withMean = false.
    val standardScaler0 = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("standardized_features")
      .fit(df0)

    assertValues(standardScaler0.transform(df0))
  }

  test("Standardization with setter") {
    val df1 = sqlContext.createDataFrame(data.zip(resWithBoth)).toDF("features", "expected")
    val df2 = sqlContext.createDataFrame(data.zip(resWithMean)).toDF("features", "expected")
    // With both flags off the transform is the identity, so input == expected.
    val df3 = sqlContext.createDataFrame(data.zip(data)).toDF("features", "expected")

    val standardScaler1 = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("standardized_features")
      .setWithMean(true)
      .setWithStd(true)
      .fit(df1)

    val standardScaler2 = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("standardized_features")
      .setWithMean(true)
      .setWithStd(false)
      .fit(df2)

    val standardScaler3 = new StandardScaler()
      .setInputCol("features")
      .setOutputCol("standardized_features")
      .setWithMean(false)
      .setWithStd(false)
      .fit(df3)

    assertValues(standardScaler1.transform(df1))
    assertValues(standardScaler2.transform(df2))
    assertValues(standardScaler3.transform(df3))
  }
}

0 commit comments

Comments
 (0)