Commit c875628

Wayne Zhang authored and Yanbo Liang committed
[SPARK-20574][ML] Allow Bucketizer to handle non-Double numeric column
## What changes were proposed in this pull request?

Bucketizer currently requires the input column to be Double, but the logic should work on any numeric data type. Many practical problems have integer/float data types, and it can get very tedious to manually cast them to Double before calling Bucketizer. This PR extends Bucketizer to handle all numeric types.

## How was this patch tested?

New test.

Author: Wayne Zhang <actuaryzhang@uber.com>

Closes #17840 from actuaryzhang/bucketizer.

(cherry picked from commit 0d16faa)
Signed-off-by: Yanbo Liang <ybliang8@gmail.com>
1 parent 425ed26 · commit c875628

2 files changed, +27 −2 lines

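As a quick illustration of what this change enables, here is a minimal sketch in spark-shell style (the session setup, DataFrame contents, and column names are assumptions for illustration, not part of the patch): after this commit, Bucketizer accepts an integer input column directly, with no manual cast to Double.

import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("BucketizerNumericExample").getOrCreate()
import spark.implicits._

// An IntegerType column: before this patch, transformSchema rejected
// any input column that was not DoubleType.
val df = Seq(-2, -1, 0, 1, 2).toDF("feature")

val bucketizer = new Bucketizer()
  .setInputCol("feature")
  .setOutputCol("bucket")
  .setSplits(Array(-3.0, 0.0, 3.0))

// -2 and -1 fall into bucket 0.0; 0, 1 and 2 fall into bucket 1.0,
// matching the expectations in the new test added below.
bucketizer.transform(df).show()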

mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 2 additions & 2 deletions
@@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
       Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
     }
 
-    val newCol = bucketizer(filteredDataset($(inputCol)))
+    val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType))
     val newField = prepOutputField(filteredDataset.schema)
     filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
   }
@@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
+    SchemaUtils.checkNumericType(schema, $(inputCol))
     SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }
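The second hunk replaces the strict DoubleType check with SchemaUtils.checkNumericType. As a rough sketch of what such a numeric-type validation amounts to (an assumption-level illustration, not the actual SchemaUtils source), it accepts any subtype of NumericType and fails fast at transformSchema time otherwise:

import org.apache.spark.sql.types.{NumericType, StructType}

// Sketch of a numeric-column check: accept any NumericType
// (Byte/Short/Integer/Long/Float/Double/Decimal) and reject the rest.
def requireNumericColumn(schema: StructType, colName: String): Unit = {
  val actualType = schema(colName).dataType
  require(actualType.isInstanceOf[NumericType],
    s"Column $colName must be of a numeric type but was actually $actualType.")
}

The cast to DoubleType then happens exactly once, in transform, so Bucketizer.binarySearchForBuckets still only ever sees Double values.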

mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 25 additions & 0 deletions
@@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
 
 class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
@@ -162,6 +164,29 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
       .setSplits(Array(0.1, 0.8, 0.9))
     testDefaultReadWrite(t)
   }
+
+  test("Bucket numeric features") {
+    val splits = Array(-3.0, 0.0, 3.0)
+    val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0)
+    val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0)
+    val dataFrame: DataFrame = data.zip(expectedBuckets).toSeq.toDF("feature", "expected")
+
+    val bucketizer: Bucketizer = new Bucketizer()
+      .setInputCol("feature")
+      .setOutputCol("result")
+      .setSplits(splits)
+
+    val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType,
+      ByteType, DecimalType(10, 0))
+    for (mType <- types) {
+      val df = dataFrame.withColumn("feature", col("feature").cast(mType))
+      bucketizer.transform(df).select("result", "expected").collect().foreach {
+        case Row(x: Double, y: Double) =>
+          assert(x === y, "The result is not correct after bucketing in type " +
+            mType.toString + ". " + s"Expected $y but found $x.")
+      }
+    }
+  }
 }
 
 private object BucketizerSuite extends SparkFunSuite {
