From 59c5611c0d1897a2446f27fec9e9c7ae0db0f4a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20K=C3=B6se?= Date: Sat, 7 May 2016 01:54:22 +0300 Subject: [PATCH 1/2] locale support to StopWords --- .../spark/ml/feature/StopWordsRemover.scala | 30 +++++++++++++++---- .../ml/feature/StopWordsRemoverSuite.scala | 1 + python/pyspark/ml/feature.py | 30 +++++++++++++++---- python/pyspark/ml/tests.py | 7 +++++ 4 files changed, 57 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 11864cb8f439..5975013c8049 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,9 +17,11 @@ package org.apache.spark.ml.feature +import java.util.Locale + import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer -import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} @@ -73,7 +75,22 @@ class StopWordsRemover(override val uid: String) /** @group getParam */ def getCaseSensitive: Boolean = $(caseSensitive) - setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), caseSensitive -> false) + /** + * Locale for doing a case sensitive comparison + * Default: English locale ("en") + * @group param + */ + val locale: Param[String] = new Param[String](this, "locale", + "locale for doing a case sensitive comparison") + + /** @group setParam */ + def setLocale(value: String): this.type = set(locale, value) + + /** @group getParam */ + def getLocale: String = $(locale) + + setDefault(stopWords -> StopWordsRemover.loadDefaultStopWords("english"), + caseSensitive -> false, locale -> "en") @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { @@ -81,14 +98,14 @@ class StopWordsRemover(override val uid: String) val t = if ($(caseSensitive)) { val stopWordsSet = $(stopWords).toSet udf { terms: Seq[String] => - terms.filter(s => !stopWordsSet.contains(s)) + terms.filterNot(stopWordsSet.contains) } } else { - // TODO: support user locale (SPARK-15064) - val toLower = (s: String) => if (s != null) s.toLowerCase else s + val loadedLocale = StopWordsRemover.loadLocale($(locale)) + val toLower = (s: String) => if (s != null) s.toLowerCase(loadedLocale) else s val lowerStopWords = $(stopWords).map(toLower(_)).toSet udf { terms: Seq[String] => - terms.filter(s => !lowerStopWords.contains(toLower(s))) + terms.filterNot(term => lowerStopWords.contains(toLower(term))) } } val metadata = outputSchema($(outputCol)).metadata @@ -109,6 +126,7 @@ class StopWordsRemover(override val uid: String) object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { private[feature] + def loadLocale(value : String) = new Locale(value) val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german", "hungarian", "italian", "norwegian", "portuguese", "russian", "spanish", "swedish", "turkish") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala index 8e7e000fbc11..71684543df86 100755 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StopWordsRemoverSuite.scala @@ -98,6 +98,7 @@ class StopWordsRemoverSuite .setInputCol("raw") .setOutputCol("filtered") .setStopWords(stopWords) + .setLocale("tr") val dataSet = sqlContext.createDataFrame(Seq( (Seq("acaba", "ama", "biri"), Seq()), (Seq("hep", "her", "scala"), Seq("scala")) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index d2989fa4cdb0..6685a379bbaf 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1736,25 +1736,31 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl typeConverter=TypeConverters.toListString) caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " + "comparison over the stop words", typeConverter=TypeConverters.toBoolean) + locale = Param(Params._dummy(), "locale", "locale for doing a case sensitive comparison", + typeConverter=TypeConverters.toString) @keyword_only - def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def __init__(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False, locale="en"): """ - __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + __init__(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=false, locale="en") """ super(StopWordsRemover, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StopWordsRemover", self.uid) self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"), - caseSensitive=False) + caseSensitive=False, locale="en") kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=False): + def setParams(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=False, locale="en"): """ - setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) + setParams(self, inputCol=None, outputCol=None, stopWords=None, + caseSensitive=false, locale="en") Sets params for this StopWordRemover. """ kwargs = self.setParams._input_kwargs @@ -1788,6 +1794,20 @@ def getCaseSensitive(self): """ return self.getOrDefault(self.caseSensitive) + @since("2.0.0") + def setLocale(self, value): + """ + Sets the value of :py:attr:`locale`. + """ + return self._set(locale=value) + + @since("2.0.0") + def getLocale(self): + """ + Gets the value of :py:attr:`locale`. + """ + return self.getOrDefault(self.locale) + @staticmethod @since("2.0.0") def loadDefaultStopWords(language): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index ad1631fb5baa..88be3377dd5b 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -424,6 +424,13 @@ def test_stopwordsremover(self): self.assertEqual(stopWordRemover.getStopWords(), stopwords) transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, []) + # with locale + stopwords = ["BİRİ"] + dataset = sqlContext.createDataFrame([Row(input=["biri"])]) + stopWordRemover.setStopWords(stopwords).setLocale("tr") + self.assertEqual(stopWordRemover.getStopWords(), stopwords) + transformedDF = stopWordRemover.transform(dataset) + self.assertEqual(transformedDF.head().output, []) def test_count_vectorizer_with_binary(self): sqlContext = SQLContext(self.sc) From 13f99d9ba175cf5cf5d1c11a8a3023b86d8e7cbd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Burak=20K=C3=B6se?= Date: Fri, 13 May 2016 09:30:12 +0300 Subject: [PATCH 2/2] add return type and reference --- .../scala/org/apache/spark/ml/feature/StopWordsRemover.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 5975013c8049..adcb2cca9211 100755 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -78,6 +78,7 @@ class StopWordsRemover(override val uid: String) /** * Locale for doing a case sensitive comparison * Default: English locale ("en") + * @see [[http://www.localeplanet.com/java/]] * @group param */ val locale: Param[String] = new Param[String](this, "locale", @@ -126,7 +127,7 @@ class StopWordsRemover(override val uid: String) object StopWordsRemover extends DefaultParamsReadable[StopWordsRemover] { private[feature] - def loadLocale(value : String) = new Locale(value) + def loadLocale(value : String): java.util.Locale = new Locale(value) val supportedLanguages = Set("danish", "dutch", "english", "finnish", "french", "german", "hungarian", "italian", "norwegian", "portuguese", "russian", "spanish", "swedish", "turkish")