-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-17017][MLLIB][ML] add a chiSquare Selector based on False Positive Rate (FPR) test #14597
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2adebe8
04053ca
7623563
3d6aecb
026ac85
5305709
1e8d83a
85a17dd
61b71c8
d7b2892
6699396
b8986b5
5c2e44c
0d3967a
1dc6a8e
9908871
bbccac7
c35bcf1
e8f03ed
ec74dde
6398f4c
6cc4c92
1d2f67f
6220dd5
ce3f8fb
88d2143
24f26f2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,27 +32,21 @@ import org.apache.spark.rdd.RDD | |
| import org.apache.spark.SparkContext | ||
| import org.apache.spark.sql.{Row, SparkSession} | ||
|
|
||
| /** | ||
| * Enumeration of the selection methods a chi-squared selector can use: | ||
| * `KBest`, `Percentile` and `FPR`. | ||
| */ | ||
| @Since("2.1.0") | ||
| private[spark] object ChiSqSelectorType extends Enumeration { | ||
| type SelectorType = Value | ||
| val KBest = Value | ||
| val Percentile = Value | ||
| val FPR = Value | ||
| } | ||
|
|
||
| /** | ||
| * Chi Squared selector model. | ||
| * | ||
| * @param selectedFeatures list of indices to select (filter). Must be ordered asc | ||
| * @param selectedFeatures list of indices to select (filter). | ||
| */ | ||
| @Since("1.3.0") | ||
| class ChiSqSelectorModel @Since("1.3.0") ( | ||
| @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable { | ||
|
|
||
| require(isSorted(selectedFeatures), "Array has to be sorted asc") | ||
|
|
||
| protected def isSorted(array: Array[Int]): Boolean = { | ||
| var i = 1 | ||
| val len = array.length | ||
| while (i < len) { | ||
| if (array(i) < array(i-1)) return false | ||
| i += 1 | ||
| } | ||
| true | ||
| } | ||
|
|
||
| /** | ||
| * Applies transformation on a vector. | ||
| * | ||
|
|
@@ -69,21 +63,22 @@ class ChiSqSelectorModel @Since("1.3.0") ( | |
| * The filtered features are returned in ascending order of their indices. | ||
| * Might be moved to Vector as .slice | ||
| * @param features vector | ||
| * @param filterIndices indices of features to filter, must be ordered asc | ||
| * @param filterIndices indices of features to filter | ||
| */ | ||
| private def compress(features: Vector, filterIndices: Array[Int]): Vector = { | ||
| val orderedIndices = filterIndices.sorted | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This can be computed once and stored, rather than storing unsorted indices and re-sorting them. |
||
| features match { | ||
| case SparseVector(size, indices, values) => | ||
| val newSize = filterIndices.length | ||
| val newSize = orderedIndices.length | ||
| val newValues = new ArrayBuilder.ofDouble | ||
| val newIndices = new ArrayBuilder.ofInt | ||
| var i = 0 | ||
| var j = 0 | ||
| var indicesIdx = 0 | ||
| var filterIndicesIdx = 0 | ||
| while (i < indices.length && j < filterIndices.length) { | ||
| while (i < indices.length && j < orderedIndices.length) { | ||
| indicesIdx = indices(i) | ||
| filterIndicesIdx = filterIndices(j) | ||
| filterIndicesIdx = orderedIndices(j) | ||
| if (indicesIdx == filterIndicesIdx) { | ||
| newIndices += j | ||
| newValues += values(i) | ||
|
|
@@ -101,7 +96,7 @@ class ChiSqSelectorModel @Since("1.3.0") ( | |
| Vectors.sparse(newSize, newIndices.result(), newValues.result()) | ||
| case DenseVector(values) => | ||
| val values = features.toArray | ||
| Vectors.dense(filterIndices.map(i => values(i))) | ||
| Vectors.dense(orderedIndices.map(i => values(i))) | ||
| case other => | ||
| throw new UnsupportedOperationException( | ||
| s"Only sparse and dense vectors are supported but got ${other.getClass}.") | ||
|
|
@@ -171,14 +166,57 @@ object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] { | |
|
|
||
| /** | ||
| * Creates a ChiSquared feature selector. | ||
| * @param numTopFeatures number of features that selector will select | ||
| * (ordered by statistic value descending) | ||
| * Note that if the number of features is less than numTopFeatures, | ||
| * then this will select all features. | ||
| * The selector supports three selection methods: `KBest`, `Percentile` and `FPR`. | ||
| * `KBest` chooses the `k` top features according to a chi-squared test. | ||
| * `Percentile` is similar but chooses a fraction of all features instead of a fixed number. | ||
| * `FPR` chooses all features whose false positive rate meets some threshold. | ||
| * By default, the selection method is `KBest` and the default number of top features is 50. | ||
| * Users can call setNumTopFeatures, setPercentile or setAlpha to choose a different selection method. | ||
| */ | ||
| @Since("1.3.0") | ||
| class ChiSqSelector @Since("1.3.0") ( | ||
| @Since("1.3.0") val numTopFeatures: Int) extends Serializable { | ||
| class ChiSqSelector @Since("2.1.0") () extends Serializable { | ||
| var numTopFeatures: Int = 50 | ||
| var percentile: Double = 0.1 | ||
| var alpha: Double = 0.05 | ||
| var selectorType = ChiSqSelectorType.KBest | ||
|
|
||
| /** | ||
| * This is the same as calling this() and setNumTopFeatures(numTopFeatures). | ||
| */ | ||
| @Since("1.3.0") | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The existing constructor should still have Javadoc, maybe pointing to the setNumTopFeatures method to say that is the effect it has. |
||
| def this(numTopFeatures: Int) { | ||
| this() | ||
| this.numTopFeatures = numTopFeatures | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This should call `setNumTopFeatures` (inline code lost in extraction; reconstructed from context).
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is not necessary, because the default selectorType is KBest.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK. It seemed to split the logic a bit here, but it's not bad. The default behavior needs to be documented then. Now there is effectively a default numTopFeatures. |
||
| } | ||
|
|
||
| /** | ||
| * Sets the number of top features to keep and switches the selection method to `KBest`. | ||
| * | ||
| * @param value number of features to select | ||
| * @return this selector, to allow chained calls | ||
| */ | ||
| @Since("1.6.0") | ||
| def setNumTopFeatures(value: Int): this.type = { | ||
| this.selectorType = ChiSqSelectorType.KBest | ||
| this.numTopFeatures = value | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Sets the fraction of features to keep and switches the selection method to `Percentile`. | ||
| * | ||
| * @param value fraction of features to select; must lie in [0, 1] | ||
| * @return this selector, to allow chained calls | ||
| */ | ||
| @Since("2.1.0") | ||
| def setPercentile(value: Double): this.type = { | ||
| require(value >= 0.0 && value <= 1.0, "Percentile must be in [0,1]") | ||
| this.selectorType = ChiSqSelectorType.Percentile | ||
| this.percentile = value | ||
| this | ||
| } | ||
|
|
||
| @Since("2.1.0") | ||
| def setAlpha(value: Double): this.type = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Does it need a `require` check? (inline code lost in extraction; reconstructed from context)
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. `require` is added, thanks. |
||
| require(0.0 <= value && value <= 1.0, "Alpha must be in [0,1]") | ||
| alpha = value | ||
| selectorType = ChiSqSelectorType.FPR | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Sets the selection method directly; only the selector type field is modified. | ||
| * | ||
| * @param value selection method to use | ||
| * @return this selector, to allow chained calls | ||
| */ | ||
| @Since("2.1.0") | ||
| def setChiSqSelectorType(value: ChiSqSelectorType.Value): this.type = { | ||
| this.selectorType = value | ||
| this | ||
| } | ||
|
|
||
| /** | ||
| * Returns a ChiSquared feature selector. | ||
|
|
@@ -189,11 +227,20 @@ class ChiSqSelector @Since("1.3.0") ( | |
| */ | ||
| @Since("1.3.0") | ||
| def fit(data: RDD[LabeledPoint]): ChiSqSelectorModel = { | ||
| val indices = Statistics.chiSqTest(data) | ||
| val chiSqTestResult = Statistics.chiSqTest(data) | ||
| .zipWithIndex.sortBy { case (res, _) => -res.statistic } | ||
| .take(numTopFeatures) | ||
| .map { case (_, indices) => indices } | ||
| .sorted | ||
| val features = selectorType match { | ||
| case ChiSqSelectorType.KBest => chiSqTestResult | ||
| .take(numTopFeatures) | ||
| case ChiSqSelectorType.Percentile => chiSqTestResult | ||
| .take((chiSqTestResult.length * percentile).toInt) | ||
| case ChiSqSelectorType.FPR => chiSqTestResult | ||
| .filter{ case (res, _) => res.pValue < alpha } | ||
| case errorType => | ||
| throw new IllegalStateException(s"Unknown ChiSqSelector Type: $errorType") | ||
| } | ||
| val indices = features.map { case (_, indices) => indices } | ||
| new ChiSqSelectorModel(indices) | ||
| } | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Somewhere, there should be a few brief sentences describing how the types relate to the parameters to this class.