Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import breeze.linalg.{DenseVector => BDV}

import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.rdd.RDD

Expand Down Expand Up @@ -161,6 +162,8 @@ private object IDF {
@Experimental
class IDFModel private[spark] (val idf: Vector) extends Serializable {

// Cached broadcast handle for `idf`. Starts as None; `transform` fills it in the first
// time it runs so that later calls can reuse the same broadcast.
private var bcIdf: Option[Broadcast[Vector]] = None

/**
* Transforms term frequency (TF) vectors to TF-IDF vectors.
*
Expand All @@ -172,8 +175,12 @@ class IDFModel private[spark] (val idf: Vector) extends Serializable {
* @return an RDD of TF-IDF vectors
*/
def transform(dataset: RDD[Vector]): RDD[Vector] = {
  // Broadcast the IDF vector only once: reuse the handle stored in the `bcIdf` field when
  // it is already set, otherwise create the broadcast now and remember it for later calls.
  val idfBroadcast: Broadcast[Vector] = bcIdf.getOrElse {
    val created = dataset.context.broadcast(idf)
    bcIdf = Some(created)
    created
  }
  // Reference the local handle (not the field) inside the closure so the task closure
  // captures only the broadcast and not the enclosing model instance.
  dataset.mapPartitions(iter => iter.map(v => IDFModel.transform(idfBroadcast.value, v)))
}

/**
Expand Down