Skip to content

Commit b89d055

Browse files
wojtek-szymanskiyanboliang
authored andcommitted
[SPARK-18210][ML] Pipeline.copy does not create an instance with the same UID
## What changes were proposed in this pull request? Motivation: `org.apache.spark.ml.Pipeline.copy(extra: ParamMap)` does not create an instance with the same UID. It does not conform to the method specification from its base class `org.apache.spark.ml.param.Params.copy(extra: ParamMap)` Solution: - fix for Pipeline UID - introduced new tests for `org.apache.spark.ml.Pipeline.copy` - minor improvements in test for `org.apache.spark.ml.PipelineModel.copy` ## How was this patch tested? Introduced new unit test: `org.apache.spark.ml.PipelineSuite."Pipeline.copy"` Improved existing unit test: `org.apache.spark.ml.PipelineSuite."PipelineModel.copy"` Author: Wojciech Szymanski <wk.szymanski@gmail.com> Closes #15759 from wojtek-szymanski/SPARK-18210.
1 parent 340f09d commit b89d055

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ class Pipeline @Since("1.4.0") (
169169
override def copy(extra: ParamMap): Pipeline = {
170170
val map = extractParamMap(extra)
171171
val newStages = map(stages).map(_.copy(extra))
172-
new Pipeline().setStages(newStages)
172+
new Pipeline(uid).setStages(newStages)
173173
}
174174

175175
@Since("1.2.0")

mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,31 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
101101
}
102102
}
103103

104+
test("Pipeline.copy") {
105+
val hashingTF = new HashingTF()
106+
.setNumFeatures(100)
107+
val pipeline = new Pipeline("pipeline").setStages(Array[Transformer](hashingTF))
108+
val copied = pipeline.copy(ParamMap(hashingTF.numFeatures -> 10))
109+
110+
assert(copied.uid === pipeline.uid,
111+
"copy should create an instance with the same UID")
112+
assert(copied.getStages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
113+
"copy should handle extra stage params")
114+
}
115+
104116
test("PipelineModel.copy") {
105117
val hashingTF = new HashingTF()
106118
.setNumFeatures(100)
107-
val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
119+
val model = new PipelineModel("pipelineModel", Array[Transformer](hashingTF))
120+
.setParent(new Pipeline())
108121
val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
109-
require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
122+
123+
assert(copied.uid === model.uid,
124+
"copy should create an instance with the same UID")
125+
assert(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
110126
"copy should handle extra stage params")
127+
assert(copied.parent === model.parent,
128+
"copy should create an instance with the same parent")
111129
}
112130

113131
test("pipeline model constructors") {

0 commit comments

Comments
 (0)