diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala new file mode 100644 index 000000000000..c6057c7f837b --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -0,0 +1,127 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import scala.collection.mutable.ArrayBuilder + +import org.apache.spark.annotation.Experimental +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.rdd.RDD + +/** + * :: Experimental :: + * Chi Squared selector model. + * + * @param selectedFeatures list of indices to select (filter). Must be ordered asc + */ +@Experimental +class ChiSqSelectorModel (val selectedFeatures: Array[Int]) extends VectorTransformer { + + require(isSorted(selectedFeatures), "Array has to be sorted asc") + + protected def isSorted(array: Array[Int]): Boolean = { + var i = 1 + while (i < array.length) { + if (array(i) < array(i-1)) return false + i += 1 + } + true + } + + /** + * Applies transformation on a vector. + * + * @param vector vector to be transformed. + * @return transformed vector. + */ + override def transform(vector: Vector): Vector = { + compress(vector, selectedFeatures) + } + + /** + * Returns a vector with features filtered. + * Preserves the order of filtered features the same as their indices are stored. + * Might be moved to Vector as .slice + * @param features vector + * @param filterIndices indices of features to filter, must be ordered asc + */ + private def compress(features: Vector, filterIndices: Array[Int]): Vector = { + features match { + case SparseVector(size, indices, values) => + val newSize = filterIndices.length + val newValues = new ArrayBuilder.ofDouble + val newIndices = new ArrayBuilder.ofInt + var i = 0 + var j = 0 + var indicesIdx = 0 + var filterIndicesIdx = 0 + while (i < indices.length && j < filterIndices.length) { + indicesIdx = indices(i) + filterIndicesIdx = filterIndices(j) + if (indicesIdx == filterIndicesIdx) { + newIndices += j + newValues += values(i) + j += 1 + i += 1 + } else { + if (indicesIdx > filterIndicesIdx) { + j += 1 + } else { + i += 1 + } + } + } + // TODO: Sparse representation might be ineffective if (newSize ~= newValues.size) + Vectors.sparse(newSize, newIndices.result(), newValues.result()) + case DenseVector(values) => + val values = features.toArray + Vectors.dense(filterIndices.map(i => values(i))) + case other => + throw new UnsupportedOperationException( + s"Only sparse and dense vectors are supported but got ${other.getClass}.") + } + } +} + +/** + * :: Experimental :: + * Creates a ChiSquared feature selector. + * @param numTopFeatures number of features that selector will select + * (ordered by statistic value descending) + */ +@Experimental +class ChiSqSelector (val numTopFeatures: Int) { + + /** + * Returns a ChiSquared feature selector. + * + * @param data an `RDD[LabeledPoint]` containing the labeled dataset with categorical features. + * Real-valued features will be treated as categorical for each distinct value. + * Apply feature discretizer before using this function. + */ + def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { + val indices = Statistics.chiSqTest(data) + .zipWithIndex.sortBy { case (res, _) => -res.statistic } + .take(numTopFeatures) + .map { case (_, indices) => indices } + .sorted + new ChiSqSelectorModel(indices) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala new file mode 100644 index 000000000000..747f5914598e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.feature + +import org.scalatest.FunSuite + +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.util.MLlibTestSparkContext + +class ChiSqSelectorSuite extends FunSuite with MLlibTestSparkContext { + + /* + * Contingency tables + * feature0 = {8.0, 0.0} + * class 0 1 2 + * 8.0||1|0|1| + * 0.0||0|2|0| + * + * feature1 = {7.0, 9.0} + * class 0 1 2 + * 7.0||1|0|0| + * 9.0||0|2|1| + * + * feature2 = {0.0, 6.0, 8.0, 5.0} + * class 0 1 2 + * 0.0||1|0|0| + * 6.0||0|1|0| + * 8.0||0|1|0| + * 5.0||0|0|1| + * + * Use chi-squared calculator from Internet + */ + + test("ChiSqSelector transform test (sparse & dense vector)") { + val labeledDiscreteData = sc.parallelize( + Seq(LabeledPoint(0.0, Vectors.sparse(3, Array((0, 8.0), (1, 7.0)))), + LabeledPoint(1.0, Vectors.sparse(3, Array((1, 9.0), (2, 6.0)))), + LabeledPoint(1.0, Vectors.dense(Array(0.0, 9.0, 8.0))), + LabeledPoint(2.0, Vectors.dense(Array(8.0, 9.0, 5.0)))), 2) + val preFilteredData = + Set(LabeledPoint(0.0, Vectors.dense(Array(0.0))), + LabeledPoint(1.0, Vectors.dense(Array(6.0))), + LabeledPoint(1.0, Vectors.dense(Array(8.0))), + LabeledPoint(2.0, Vectors.dense(Array(5.0)))) + val model = new ChiSqSelector(1).fit(labeledDiscreteData) + val filteredData = labeledDiscreteData.map { lp => + LabeledPoint(lp.label, model.transform(lp.features)) + }.collect().toSet + assert(filteredData == preFilteredData) + } +}