mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala (23 changes: 11 additions & 12 deletions)

@@ -73,7 +73,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
     val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v}
     val pca = new feature.PCA(k = $(k))
     val pcaModel = pca.fit(input)
-    copyValues(new PCAModel(uid, pcaModel).setParent(this))
+    copyValues(new PCAModel(uid, pcaModel.pc).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
@@ -99,18 +99,17 @@ object PCA extends DefaultParamsReadable[PCA] {
 /**
  * :: Experimental ::
  * Model fitted by [[PCA]].
+ *
+ * @param pc A principal components Matrix. Each column is one principal component.
  */
 @Experimental
 class PCAModel private[ml] (
     override val uid: String,
-    pcaModel: feature.PCAModel)
+    val pc: DenseMatrix)
   extends Model[PCAModel] with PCAParams with MLWritable {
 
   import PCAModel._
 
-  /** a principal components Matrix. Each column is one principal component. */
-  val pc: DenseMatrix = pcaModel.pc
-
   /** @group setParam */
   def setInputCol(value: String): this.type = set(inputCol, value)
 
@@ -124,6 +123,7 @@ class PCAModel private[ml] (
    */
   override def transform(dataset: DataFrame): DataFrame = {
     transformSchema(dataset.schema, logging = true)
+    val pcaModel = new feature.PCAModel($(k), pc)
     val pcaOp = udf { pcaModel.transform _ }
     dataset.withColumn($(outputCol), pcaOp(col($(inputCol))))
   }
@@ -139,7 +139,7 @@ class PCAModel private[ml] (
   }
 
   override def copy(extra: ParamMap): PCAModel = {
-    val copied = new PCAModel(uid, pcaModel)
+    val copied = new PCAModel(uid, pc)
     copyValues(copied, extra).setParent(parent)
   }
 
@@ -152,11 +152,11 @@ object PCAModel extends MLReadable[PCAModel] {
 
   private[PCAModel] class PCAModelWriter(instance: PCAModel) extends MLWriter {
 
-    private case class Data(k: Int, pc: DenseMatrix)
+    private case class Data(pc: DenseMatrix)
 
     override protected def saveImpl(path: String): Unit = {
       DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.getK, instance.pc)
+      val data = Data(instance.pc)
       val dataPath = new Path(path, "data").toString
       sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
     }
@@ -169,11 +169,10 @@ object PCAModel extends MLReadable[PCAModel] {
     override def load(path: String): PCAModel = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
       val dataPath = new Path(path, "data").toString
-      val Row(k: Int, pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
-        .select("k", "pc")
+      val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath)
+        .select("pc")
         .head()
-      val oldModel = new feature.PCAModel(k, pc)
-      val model = new PCAModel(metadata.uid, oldModel)
+      val model = new PCAModel(metadata.uid, pc)
       DefaultParamsReader.getAndSetParams(model, metadata)
       model
     }
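Taken together, these hunks flatten the ml `PCAModel` so it stores the principal-components matrix `pc` directly instead of wrapping the old mllib `feature.PCAModel`; `fit` passes `pcaModel.pc` straight to the constructor, and `transform` rebuilds the mllib model on the fly from `$(k)` and `pc`. A minimal usage sketch of the refactored API (the data, column names, and `SQLContext` setup are illustrative assumptions, not part of this diff):

```scala
import org.apache.spark.ml.feature.PCA
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.sql.SQLContext

// Usage sketch only: assumes an existing SparkContext `sc`; the data and
// column names below are made up for illustration.
val sqlContext = new SQLContext(sc)
val df = sqlContext.createDataFrame(Seq(
  Tuple1(Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
  Tuple1(Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)),
  Tuple1(Vectors.dense(4.0, 1.0, 0.0, 6.0, 7.0))
)).toDF("features")

val pca = new PCA()
  .setInputCol("features")
  .setOutputCol("pcaFeatures")
  .setK(2)

// fit() computes the principal components and passes pcaModel.pc (a
// DenseMatrix) straight into the new PCAModel constructor.
val model = pca.fit(df)

// transform() reconstructs a feature.PCAModel from $(k) and pc on each
// call, so the ml model carries no mllib wrapper state of its own.
model.transform(df).select("pcaFeatures").show()
```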
mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala (31 changes: 13 additions & 18 deletions)

@@ -32,7 +32,7 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
   test("params") {
     ParamsSuite.checkParams(new PCA)
     val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix]
-    val model = new PCAModel("pca", new OldPCAModel(2, mat))
+    val model = new PCAModel("pca", mat)
     ParamsSuite.checkParams(model)
   }
 
@@ -66,23 +66,18 @@ class PCASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
     }
   }
 
-  test("read/write") {
-    def checkModelData(model1: PCAModel, model2: PCAModel): Unit = {
-      assert(model1.pc === model2.pc)
-    }
-    val allParams: Map[String, Any] = Map(
-      "k" -> 3,
-      "inputCol" -> "features",
-      "outputCol" -> "pca_features"
-    )
-    val data = Seq(
-      (0.0, Vectors.sparse(5, Seq((1, 1.0), (3, 7.0)))),
-      (1.0, Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0)),
-      (2.0, Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0))
-    )
-    val df = sqlContext.createDataFrame(data).toDF("id", "features")
-    val pca = new PCA().setK(3)
-    testEstimatorAndModelReadWrite(pca, df, allParams, checkModelData)
+  test("PCA read/write") {
+    val t = new PCA()
+      .setInputCol("myInputCol")
+      .setOutputCol("myOutputCol")
+      .setK(3)
+    testDefaultReadWrite(t)
   }
 
+  test("PCAModel read/write") {
+    val instance = new PCAModel("myPCAModel",
+      Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix])
+    val newInstance = testDefaultReadWrite(instance)
+    assert(newInstance.pc === instance.pc)
+  }
 }
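The split into separate estimator and model tests mirrors the new persistence layout: `DefaultParamsWriter` saves `k` alongside the other params in the metadata, the parquet data file now holds only `pc`, and `DefaultParamsReader.getAndSetParams` restores `k` on load. A round-trip sketch under those assumptions, continuing from the `model` fitted in the earlier sketch (the save path is illustrative):

```scala
import org.apache.spark.ml.feature.PCAModel

// Round-trip sketch (illustrative path). Metadata, including k, is written
// by DefaultParamsWriter; the parquet data file stores only the pc matrix.
model.write.overwrite().save("/tmp/pca-model")

// load() reads pc back from parquet, then getAndSetParams restores k,
// inputCol, and outputCol from the metadata.
val restored = PCAModel.load("/tmp/pca-model")
assert(restored.pc.toArray.sameElements(model.pc.toArray))
```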