From 0c7e552bae0fad20dd269476924f0b3d1256b67a Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Wed, 21 Sep 2016 09:46:58 +0100 Subject: [PATCH 1/3] To match Tokenizer and for compatibility with Word2Vec, output a nullable string array type in NGram --- mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 4463aea0097e..d67660ca612e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -68,7 +68,7 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) s"Input type must be ArrayType(StringType) but got $inputType.") } - override protected def outputDataType: DataType = new ArrayType(StringType, false) + override protected def outputDataType: DataType = new ArrayType(StringType, true) } @Since("1.6.0") From 43352373d949e97f1c6bff65fc00fbd4a86e1db5 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 22 Sep 2016 11:29:38 +0100 Subject: [PATCH 2/3] Let Word2Vec accept non-nullable string array input in additional to currently supported nullable string array type --- mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala | 2 +- .../src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index d67660ca612e..4463aea0097e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -68,7 +68,7 @@ class NGram @Since("1.5.0") (@Since("1.5.0") override val uid: String) s"Input type must be ArrayType(StringType) but got $inputType.") } - override protected def outputDataType: DataType = new ArrayType(StringType, true) + override protected def outputDataType: DataType = new ArrayType(StringType, false) } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 14c05123c62e..d53f3df514df 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -108,7 +108,8 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) + val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false)) + SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } } From 76af236ab516026d43bdbce679b49a3629108ef3 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Fri, 23 Sep 2016 10:13:46 +0100 Subject: [PATCH 3/3] Add unit test --- .../spark/ml/feature/Word2VecSuite.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 0b441f8b8081..613cc3d60b22 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -207,5 +207,26 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul val newInstance = testDefaultReadWrite(instance) assert(newInstance.getVectors.collect() === instance.getVectors.collect()) } + + test("Word2Vec works with input that is non-nullable (NGram)") { + val spark = this.spark + import spark.implicits._ + + val sentence = "a q s t q s t b b b s t m s t m q " + val docDF = sc.parallelize(Seq(sentence, sentence)).map(_.split(" ")).toDF("text") + + val ngram = new NGram().setN(2).setInputCol("text").setOutputCol("ngrams") + val ngramDF = ngram.transform(docDF) + + val model = new Word2Vec() + .setVectorSize(2) + .setInputCol("ngrams") + .setOutputCol("result") + .fit(ngramDF) + + // Just test that this transformation succeeds + model.transform(ngramDF).collect() + } + }