-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-25765][ML] Add training cost to BisectingKMeans summary #22764
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6460fe7
5919d3c
de5cadd
b7a6b51
0c74a09
e44adff
a3247a6
4454412
8ef04db
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -41,11 +41,12 @@ import org.apache.spark.sql.{Row, SparkSession} | |
| @Since("1.6.0") | ||
| class BisectingKMeansModel private[clustering] ( | ||
| private[clustering] val root: ClusteringTreeNode, | ||
| @Since("2.4.0") val distanceMeasure: String | ||
| @Since("2.4.0") val distanceMeasure: String, | ||
| @Since("3.0.0") val trainingCost: Double | ||
| ) extends Serializable with Saveable with Logging { | ||
|
|
||
| @Since("1.6.0") | ||
| def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN) | ||
| def this(root: ClusteringTreeNode) = this(root, DistanceMeasure.EUCLIDEAN, 0.0) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On the other hand, we did preserve this old constructor, and that's fine to keep. The other issue I see here is that the cost is 0, when the cost is really unknown.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, because this is public, so users may rely on it. The idea is that this is indeed a "new feature" (previously it was not accessible) and we are not guaranteeing new features in the MLlib API. I just followed the same approach which was used for KMeans. |
||
|
|
||
| private val distanceMeasureInstance: DistanceMeasure = | ||
| DistanceMeasure.decodeFromString(distanceMeasure) | ||
|
|
@@ -109,10 +110,10 @@ class BisectingKMeansModel private[clustering] ( | |
|
|
||
| @Since("2.0.0") | ||
| override def save(sc: SparkContext, path: String): Unit = { | ||
| BisectingKMeansModel.SaveLoadV2_0.save(sc, this, path) | ||
| BisectingKMeansModel.SaveLoadV3_0.save(sc, this, path) | ||
| } | ||
|
|
||
| override protected def formatVersion: String = "2.0" | ||
| override protected def formatVersion: String = "3.0" | ||
| } | ||
|
|
||
| @Since("2.0.0") | ||
|
|
@@ -128,11 +129,15 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { | |
| case (SaveLoadV2_0.thisClassName, SaveLoadV2_0.thisFormatVersion) => | ||
| val model = SaveLoadV2_0.load(sc, path) | ||
| model | ||
| case (SaveLoadV3_0.thisClassName, SaveLoadV3_0.thisFormatVersion) => | ||
| val model = SaveLoadV3_0.load(sc, path) | ||
| model | ||
| case _ => throw new Exception( | ||
| s"BisectingKMeansModel.load did not recognize model with (className, format version):" + | ||
| s"($loadedClassName, $formatVersion). Supported:\n" + | ||
| s" (${SaveLoadV1_0.thisClassName}, ${SaveLoadV1_0.thisClassName}\n" + | ||
| s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})") | ||
| s" (${SaveLoadV2_0.thisClassName}, ${SaveLoadV2_0.thisClassName})\n" + | ||
| s" (${SaveLoadV3_0.thisClassName}, ${SaveLoadV3_0.thisClassName})") | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -195,7 +200,8 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { | |
| val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") | ||
| val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap | ||
| val rootNode = buildTree(rootId, nodes) | ||
| new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN) | ||
| val totalCost = rootNode.leafNodes.map(_.cost).sum | ||
| new BisectingKMeansModel(rootNode, DistanceMeasure.EUCLIDEAN, totalCost) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -231,7 +237,46 @@ object BisectingKMeansModel extends Loader[BisectingKMeansModel] { | |
| val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") | ||
| val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap | ||
| val rootNode = buildTree(rootId, nodes) | ||
| new BisectingKMeansModel(rootNode, distanceMeasure) | ||
| val totalCost = rootNode.leafNodes.map(_.cost).sum | ||
| new BisectingKMeansModel(rootNode, distanceMeasure, totalCost) | ||
| } | ||
| } | ||
|
|
||
| private[clustering] object SaveLoadV3_0 { | ||
| private[clustering] val thisFormatVersion = "3.0" | ||
|
|
||
| private[clustering] | ||
| val thisClassName = "org.apache.spark.mllib.clustering.BisectingKMeansModel" | ||
|
|
||
| def save(sc: SparkContext, model: BisectingKMeansModel, path: String): Unit = { | ||
| val spark = SparkSession.builder().sparkContext(sc).getOrCreate() | ||
| val metadata = compact(render( | ||
| ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) | ||
| ~ ("rootId" -> model.root.index) ~ ("distanceMeasure" -> model.distanceMeasure) | ||
| ~ ("trainingCost" -> model.trainingCost))) | ||
| sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) | ||
|
|
||
| val data = getNodes(model.root).map(node => Data(node.index, node.size, | ||
| node.centerWithNorm.vector, node.centerWithNorm.norm, node.cost, node.height, | ||
| node.children.map(_.index))) | ||
| spark.createDataFrame(data).write.parquet(Loader.dataPath(path)) | ||
| } | ||
|
|
||
| def load(sc: SparkContext, path: String): BisectingKMeansModel = { | ||
| implicit val formats: DefaultFormats = DefaultFormats | ||
| val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) | ||
| assert(className == thisClassName) | ||
| assert(formatVersion == thisFormatVersion) | ||
| val rootId = (metadata \ "rootId").extract[Int] | ||
| val distanceMeasure = (metadata \ "distanceMeasure").extract[String] | ||
| val trainingCost = (metadata \ "trainingCost").extract[Double] | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, can this read an old model from a previous version?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do other models have this problem? I was told that this change just follows what we did for other models before.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you all for the comments and sorry for the late answer. Just a couple of notes on your comments @WeichenXu123 (I may be missing something, so please correct me if I am wrong):
@cloud-fan Yes, let me link the PR for KMeans doing the same, which is: #20629. Just a final comment, which I hope clarifies the source of the confusion here and the reason for the above comments by @viirya and @WeichenXu123: Hope this clarifies (sorry for being so verbose).
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I meant whether it can read an old model from previous versions, not whether this model can be read from previous versions. In other words, when reading a previous model without "trainingCost" in metadata, can this line work well? val trainingCost = (metadata \ "trainingCost").extract[Double]
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @mgaido91
I am suspicious of this line in When loading an old version spark(e.g. spark 2.3.1) saved
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @WeichenXu123 I have explained it in and #22764 (comment). If you don't agree with or believe what I said, you can try it. A model saved in 2.3.1 will have "1.0" as version. So this code is not run. Every model from 2.4.0 on will be saved with "2.0" as version, so it will have this stored. As mentioned, please notice that Hope this clarifies. Thanks.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry, I am more confused...
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes @WeichenXu123 , you're right, that line is a bug. Thanks for noticing it. Anyway, that is going to be addressed in another PR and it is not (strictly) related to this one. The other option, as I mentioned, is that if we agree that this doesn't need to be restored after model persistence, we can just ignore it in save/load.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK. After #22790 is merged, I think this PR can work. |
||
| val spark = SparkSession.builder().sparkContext(sc).getOrCreate() | ||
| val rows = spark.read.parquet(Loader.dataPath(path)) | ||
| Loader.checkSchema[Data](rows.schema) | ||
| val data = rows.select("index", "size", "center", "norm", "cost", "height", "children") | ||
| val nodes = data.rdd.map(Data.apply).collect().map(d => (d.index, d)).toMap | ||
| val rootNode = buildTree(rootId, nodes) | ||
| new BisectingKMeansModel(rootNode, distanceMeasure, trainingCost) | ||
| } | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it's a big deal for 3.0, but we lose the constructor without the new param. That's probably OK as the summary kind of needs this value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this constructor is private so I don't think it is a problem to avoid having the previous one.