|
17 | 17 |
|
18 | 18 | package org.apache.spark.ml.feature |
19 | 19 |
|
20 | | -import org.apache.spark.annotation.Experimental |
21 | 20 | import org.apache.spark.SparkContext |
| 21 | +import org.apache.spark.annotation.Experimental |
22 | 22 | import org.apache.spark.ml.{Estimator, Model} |
23 | 23 | import org.apache.spark.ml.param._ |
24 | 24 | import org.apache.spark.ml.param.shared._ |
25 | 25 | import org.apache.spark.ml.util.{Identifiable, SchemaUtils} |
26 | 26 | import org.apache.spark.mllib.feature |
27 | | -import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors} |
28 | | -import org.apache.spark.mllib.linalg.BLAS._ |
29 | | -import org.apache.spark.sql.DataFrame |
| 27 | +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} |
| 28 | +import org.apache.spark.sql.{DataFrame, SQLContext} |
30 | 29 | import org.apache.spark.sql.functions._ |
31 | | -import org.apache.spark.sql.SQLContext |
32 | 30 | import org.apache.spark.sql.types._ |
33 | 31 |
|
34 | 32 | /** |
@@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] |
148 | 146 | @Experimental |
149 | 147 | class Word2VecModel private[ml] ( |
150 | 148 | override val uid: String, |
151 | | - wordVectors: feature.Word2VecModel) |
| 149 | + @transient wordVectors: feature.Word2VecModel) |
152 | 150 | extends Model[Word2VecModel] with Word2VecBase { |
153 | 151 |
|
154 | | - |
155 | 152 | /** |
156 | 153 | * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and |
157 | 154 | * and the vector the DenseVector that it is mapped to. |
@@ -197,22 +194,23 @@ class Word2VecModel private[ml] ( |
197 | 194 | */ |
198 | 195 | override def transform(dataset: DataFrame): DataFrame = { |
199 | 196 | transformSchema(dataset.schema, logging = true) |
200 | | - val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) |
| 197 | + val vectors = wordVectors.getVectors |
| 198 | + .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) |
| 199 | + .map(identity) // mapValues doesn't return a serializable map (SI-7005) |
| 200 | + val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors) |
| 201 | + val d = $(vectorSize) |
201 | 202 | val word2Vec = udf { sentence: Seq[String] => |
202 | 203 | if (sentence.size == 0) { |
203 | | - Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double]) |
| 204 | + Vectors.sparse(d, Array.empty[Int], Array.empty[Double]) |
204 | 205 | } else { |
205 | | - val cum = Vectors.zeros($(vectorSize)) |
206 | | - val model = bWordVectors.value.getVectors |
207 | | - for (word <- sentence) { |
208 | | - if (model.contains(word)) { |
209 | | - axpy(1.0, bWordVectors.value.transform(word), cum) |
210 | | - } else { |
211 | | - // pass words which not belong to model |
| 206 | + val sum = Vectors.zeros(d) |
| 207 | + sentence.foreach { word => |
| 208 | + bVectors.value.get(word).foreach { v => |
| 209 | + BLAS.axpy(1.0, v, sum) |
212 | 210 | } |
213 | 211 | } |
214 | | - scal(1.0 / sentence.size, cum) |
215 | | - cum |
| 212 | + BLAS.scal(1.0 / sentence.size, sum) |
| 213 | + sum |
216 | 214 | } |
217 | 215 | } |
218 | 216 | dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) |
|
0 commit comments