
Commit 0bd7765

viirya authored and jkbradley committed
[SPARK-23377][ML] Fixes Bucketizer with multiple columns persistence bug
## What changes were proposed in this pull request?

#### Problem:

Since 2.3, `Bucketizer` supports multiple input/output columns, and exclusive params are checked during transformation: e.g., if `inputCols` and `outputCol` are both set, an error is thrown. However, when a `Bucketizer` is written, the default params and user-supplied params are merged. All saved params are then loaded back and set on the created instance, so the default `outputCol` param from the `HasOutputCol` trait ends up in `paramMap` and becomes a user-supplied param. That makes the check of exclusive params fail.

#### Fix:

This changes the saving logic of `Bucketizer` to handle this case. It is a quick fix to make the 2.3 release; we should consider modifying the persistence mechanism later. Please see the discussion in the JIRA.

Note: the multi-column `QuantileDiscretizer` has the same issue.

## How was this patch tested?

Modified tests.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #20594 from viirya/SPARK-23377-2.

(cherry picked from commit db45daa)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
1 parent 03960fa commit 0bd7765
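
For context, here is a minimal sketch of the failure mode this commit addresses, assuming the Spark 2.3 multi-column API; the save path, data, and splits are illustrative:

```scala
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("spark-23377").getOrCreate()
import spark.implicits._

val df = Seq((0.2, 5.0), (0.7, 15.0)).toDF("f1", "f2")

// Multi-column usage: only `inputCols`/`outputCols` are set explicitly, but
// `outputCol` still carries a default value from the HasOutputCol trait.
val bucketizer = new Bucketizer()
  .setInputCols(Array("f1", "f2"))
  .setOutputCols(Array("b1", "b2"))
  .setSplitsArray(Array(Array(0.0, 0.5, 1.0), Array(0.0, 10.0, 20.0)))

// Works: the default `outputCol` does not count as user-supplied here.
bucketizer.transform(df).show()

bucketizer.write.overwrite().save("/tmp/spark-23377-bucketizer")
val loaded = Bucketizer.load("/tmp/spark-23377-bucketizer")

// Before this fix, saving merged default and user-supplied params, so the
// reloaded `outputCol` looked user-supplied and this transform failed the
// exclusive-params check.
loaded.transform(df).show()
```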

File tree

4 files changed: +78, -4 lines


mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala

Lines changed: 28 additions & 0 deletions
```diff
@@ -19,6 +19,10 @@ package org.apache.spark.ml.feature
 
 import java.{util => ju}
 
+import org.json4s.JsonDSL._
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.ml.Model
@@ -213,6 +217,8 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   override def copy(extra: ParamMap): Bucketizer = {
     defaultCopy[Bucketizer](extra).setParent(parent)
   }
+
+  override def write: MLWriter = new Bucketizer.BucketizerWriter(this)
 }
 
 @Since("1.6.0")
@@ -290,6 +296,28 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
     }
   }
 
+
+  private[Bucketizer] class BucketizerWriter(instance: Bucketizer) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      // SPARK-23377: The default params will be saved and loaded as user-supplied params.
+      // Once `inputCols` is set, the default value of the `outputCol` param causes an error
+      // when checking exclusive params. As a temporary fix, we skip the default value
+      // of `outputCol` if `inputCols` is set when saving the metadata.
+      // TODO: If we modify the persistence mechanism later to better handle default params,
+      // we can get rid of this.
+      var paramWithoutOutputCol: Option[JValue] = None
+      if (instance.isSet(instance.inputCols)) {
+        val params = instance.extractParamMap().toSeq
+        val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) =>
+          p.name -> parse(p.jsonEncode(v))
+        }.toList
+        paramWithoutOutputCol = Some(render(jsonParams))
+      }
+      DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol)
+    }
+  }
+
   @Since("1.6.0")
   override def load(path: String): Bucketizer = super.load(path)
 }
```
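
The writer's core move is to re-encode every param except `outputCol` as `(name, JValue)` pairs and hand the rendered JSON object to `DefaultParamsWriter.saveMetadata` as an explicit `paramMap`. Here is a standalone sketch of that json4s idiom, with made-up param names and pre-encoded JSON strings standing in for `p.jsonEncode(v)`:

```scala
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._

// Stand-ins for the filtered ParamPairs: each param name paired with its
// JSON-encoded value, as produced after dropping `outputCol`.
val encoded = Seq(
  "inputCols" -> """["f1","f2"]""",
  "outputCols" -> """["b1","b2"]"""
)

// Same shape as in BucketizerWriter: a List[(String, JValue)] rendered to a JSON object.
val jsonParams = encoded.map { case (name, json) => name -> parse(json) }.toList

println(compact(render(jsonParams)))
// {"inputCols":["f1","f2"],"outputCols":["b1","b2"]}
```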

mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala

Lines changed: 28 additions & 0 deletions
```diff
@@ -17,6 +17,10 @@
 
 package org.apache.spark.ml.feature
 
+import org.json4s.JsonDSL._
+import org.json4s.JValue
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.annotation.Since
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml._
@@ -249,11 +253,35 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val uid: String
 
   @Since("1.6.0")
   override def copy(extra: ParamMap): QuantileDiscretizer = defaultCopy(extra)
+
+  override def write: MLWriter = new QuantileDiscretizer.QuantileDiscretizerWriter(this)
 }
 
 @Since("1.6.0")
 object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
 
+  private[QuantileDiscretizer]
+  class QuantileDiscretizerWriter(instance: QuantileDiscretizer) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      // SPARK-23377: The default params will be saved and loaded as user-supplied params.
+      // Once `inputCols` is set, the default value of the `outputCol` param causes an error
+      // when checking exclusive params. As a temporary fix, we skip the default value
+      // of `outputCol` if `inputCols` is set when saving the metadata.
+      // TODO: If we modify the persistence mechanism later to better handle default params,
+      // we can get rid of this.
+      var paramWithoutOutputCol: Option[JValue] = None
+      if (instance.isSet(instance.inputCols)) {
+        val params = instance.extractParamMap().toSeq
+        val jsonParams = params.filter(_.param != instance.outputCol).map { case ParamPair(p, v) =>
+          p.name -> parse(p.jsonEncode(v))
+        }.toList
+        paramWithoutOutputCol = Some(render(jsonParams))
+      }
+      DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = paramWithoutOutputCol)
+    }
+  }
+
   @Since("1.6.0")
   override def load(path: String): QuantileDiscretizer = super.load(path)
 }
```
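
With the analogous writer in place, a multi-column `QuantileDiscretizer` survives the save/load round trip. A usage sketch under the same assumptions as the earlier example (illustrative path and data, reusing its `spark` session and implicits):

```scala
import org.apache.spark.ml.feature.QuantileDiscretizer

val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2")

val discretizer = new QuantileDiscretizer()
  .setInputCols(Array("input1", "input2"))
  .setOutputCols(Array("result1", "result2"))
  .setNumBucketsArray(Array(5, 10))

discretizer.write.overwrite().save("/tmp/spark-23377-qd")
val loaded = QuantileDiscretizer.load("/tmp/spark-23377-qd")

// The default `outputCol` is no longer persisted as user-supplied, so
// fitting the reloaded estimator passes the exclusive-params check.
val model = loaded.fit(data)
model.transform(data).show()
```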

mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala

Lines changed: 10 additions & 2 deletions
```diff
@@ -172,7 +172,10 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")
       .setSplits(Array(0.1, 0.8, 0.9))
-    testDefaultReadWrite(t)
+
+    val bucketizer = testDefaultReadWrite(t)
+    val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
+    bucketizer.transform(data)
   }
 
   test("Bucket numeric features") {
@@ -327,7 +330,12 @@
       .setInputCols(Array("myInputCol"))
       .setOutputCols(Array("myOutputCol"))
       .setSplitsArray(Array(Array(0.1, 0.8, 0.9)))
-    testDefaultReadWrite(t)
+
+    val bucketizer = testDefaultReadWrite(t)
+    val data = Seq((1.0, 2.0), (10.0, 100.0), (101.0, -1.0)).toDF("myInputCol", "myInputCol2")
+    bucketizer.transform(data)
+    assert(t.hasDefault(t.outputCol))
+    assert(bucketizer.hasDefault(bucketizer.outputCol))
   }
 
   test("Bucketizer in a pipeline") {
```

mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala

Lines changed: 12 additions & 2 deletions
```diff
@@ -27,6 +27,8 @@ import org.apache.spark.sql.functions.udf
 class QuantileDiscretizerSuite
   extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
+  import testImplicits._
+
   test("Test observed number of buckets and their sizes match expected values") {
     val spark = this.spark
     import spark.implicits._
@@ -132,7 +134,10 @@
       .setInputCol("myInputCol")
       .setOutputCol("myOutputCol")
       .setNumBuckets(6)
-    testDefaultReadWrite(t)
+
+    val readDiscretizer = testDefaultReadWrite(t)
+    val data = sc.parallelize(1 to 100).map(Tuple1.apply).toDF("myInputCol")
+    readDiscretizer.fit(data)
   }
 
   test("Verify resulting model has parent") {
@@ -379,7 +384,12 @@
       .setInputCols(Array("input1", "input2"))
       .setOutputCols(Array("result1", "result2"))
       .setNumBucketsArray(Array(5, 10))
-    testDefaultReadWrite(discretizer)
+
+    val readDiscretizer = testDefaultReadWrite(discretizer)
+    val data = Seq((1.0, 2.0), (2.0, 3.0), (3.0, 4.0)).toDF("input1", "input2")
+    readDiscretizer.fit(data)
+    assert(discretizer.hasDefault(discretizer.outputCol))
+    assert(readDiscretizer.hasDefault(readDiscretizer.outputCol))
   }
 
   test("Multiple Columns: Both inputCol and inputCols are set") {
```
