From c0c0174e30a3e4473e0d5c4872c5b33144e5b015 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 29 Sep 2016 09:31:01 -0400 Subject: [PATCH 1/2] Partial revert of #15277 to instead sort and store input to model rather than require sorted input --- .../spark/ml/feature/ChiSqSelector.scala | 2 +- .../spark/mllib/feature/ChiSqSelector.scala | 22 +++++++++---------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 9c131a41850c..d0385e220e1e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -193,7 +193,7 @@ final class ChiSqSelectorModel private[ml] ( import ChiSqSelectorModel._ - /** list of indices to select (filter). Must be ordered asc */ + /** list of indices to select (filter). */ @Since("1.6.0") val selectedFeatures: Array[Int] = chiSqSelector.selectedFeatures diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index 706ce78f260a..c305b36278e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -35,14 +35,15 @@ import org.apache.spark.sql.{Row, SparkSession} /** * Chi Squared selector model. * - * @param selectedFeatures list of indices to select (filter). Must be ordered asc + * @param selectedFeatures list of indices to select (filter). */ @Since("1.3.0") class ChiSqSelectorModel @Since("1.3.0") ( @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { - require(isSorted(selectedFeatures), "Array has to be sorted asc") + private val filterIndices = selectedFeatures.sorted + @deprecated("not intended for subclasses to use", "2.1.0") protected def isSorted(array: Array[Int]): Boolean = { var i = 1 val len = array.length @@ -61,7 +62,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( */ @Since("1.3.0") override def transform(vector: Vector): Vector = { - compress(vector, selectedFeatures) + compress(vector) } /** @@ -69,9 +70,8 @@ class ChiSqSelectorModel @Since("1.3.0") ( * Preserves the order of filtered features the same as their indices are stored. * Might be moved to Vector as .slice * @param features vector - * @param filterIndices indices of features to filter, must be ordered asc */ - private def compress(features: Vector, filterIndices: Array[Int]): Vector = { + private def compress(features: Vector): Vector = { features match { case SparseVector(size, indices, values) => val newSize = filterIndices.length @@ -230,23 +230,23 @@ class ChiSqSelector @Since("2.1.0") () extends Serializable { */ @Since("1.3.0") def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { - val chiSqTestResult = Statistics.chiSqTest(data) + val chiSqTestResult = Statistics.chiSqTest(data).zipWithIndex val features = selectorType match { case ChiSqSelector.KBest => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take(numTopFeatures) case ChiSqSelector.Percentile => - chiSqTestResult.zipWithIndex + chiSqTestResult .sortBy { case (res, _) => -res.statistic } .take((chiSqTestResult.length * percentile).toInt) case ChiSqSelector.FPR => - chiSqTestResult.zipWithIndex - .filter{ case (res, _) => res.pValue < alpha } + chiSqTestResult + .filter { case (res, _) => res.pValue < alpha } case errorType => throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") } - val indices = features.map { case (_, indices) => indices }.sorted + val indices = features.map { case (_, index) => index } new ChiSqSelectorModel(indices) } } From 02a2b40e0df63371f0c9cb2b7741a7323bc5c163 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Thu, 29 Sep 2016 11:23:01 -0400 Subject: [PATCH 2/2] Fix selectedFeatures doc --- python/pyspark/ml/feature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 12a13849dc9b..64b21caa616e 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2705,7 +2705,7 @@ class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): @since("2.0.0") def selectedFeatures(self): """ - List of indices to select (filter). Must be ordered asc. + List of indices to select (filter). """ return self._call_java("selectedFeatures")