Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -559,13 +559,26 @@ trait Params extends Identifiable with Serializable {

/**
* Copies param values from this instance to another instance for params shared by them.
* @param to the target instance
* @param extra extra params to be copied
*
* This handles default Params and explicitly set Params separately.
* Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are
* copied from and to [[paramMap]].
* Warning: This implicitly assumes that this [[Params]] instance and the target instance
* share the same set of default Params.
*
* @param to the target instance, which should work with the same set of default Params as this
* source instance
* @param extra extra params to be copied to the target's [[paramMap]]
* @return the target instance with param values copied
*/
protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
val map = extractParamMap(extra)
val map = paramMap ++ extra
params.foreach { param =>
// copy default Params
if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
}
// copy explicitly set Params
if (map.contains(param) && to.hasParam(param.name)) {
to.set(param.name, map(param))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ class ParamsSuite extends SparkFunSuite {
val inArray = ParamValidators.inArray[Int](Array(1, 2))
assert(inArray(1) && inArray(2) && !inArray(0))
}

test("Params.copyValues") {
val t = new TestParams()
val t2 = t.copy(ParamMap.empty)
assert(!t2.isSet(t2.maxIter))
val t3 = t.copy(ParamMap(t.maxIter -> 20))
assert(t3.isSet(t3.maxIter))
}
}

object ParamsSuite extends SparkFunSuite {
Expand Down