feat: new LIME and KernelSHAP explainers #1077
Status: Merged

Commits (87, all authored by memoryz):
830336e  Adding LassoRegression and LeastSquaresRegression, both supporting fi…
b5d6130  Speeding up lasso regression
a161fb0  Simpify the recurse function
370dad8  Add code comment
246a5a3  Fixing r-squared calculation; add loss to result.
ac3af96  WIP: LIME
1080899  WIP: LIME tabular
c30c5cd  Unit tests: Simple regression model with continuous variables working
df878e3  Unit tests: Categorical variable working
9cc5a7e  Outputting fitting metrics r2
6e3234c  Vector lime implementation
24e4b04  Adding unit tests for vector lime
f666ef8  Clean up and reorganize classes
d270f2b  LIME image sampler
f6d9a5a  WIP: Image LIME
be65511  Image LIME: unit test passing
7ed4de8  Efficiency fix for ImageLIME
2c22da5  Validate input schema
14c47f4  More refactoring
ba64847  WIP: TextLIME
069569f  Unit test for Text LIME
267e3d1  Add support for binary type images
f8c8ab4  Reorganize classes
dd7e5c6  More restructuring
6295659  renames
10216e9  KernelSHAP sampler
7de5756  WIP: TabularSHAP
cf861f7  Unit test for KernelSHAP tabular
c3e0a7c  KernelSHAP Vector Sampler
06a3a5c  WIP: Vector SHAP
4e6d0a2  Unit test for vector kernel shap
1aaed01  Simplify sampler class structure
c34a9ec  Restructure the samplers
ce5c101  Unit tests for image and text samplers
1091482  ImageSHAP and TextSHAP
8d8efcd  ImageLIME unit test
39fef38  Unit test for TextSHAP
7cda288  Logging
6ca7a50  pyspark layer
12a89e7  change LocalExplainer to internal class
c9ad1ad  Bug fix
fcd0b61  Bug fix
a972dba  Bug fix for tabular LIME categorical features
6dd2d77  Update copyright
caff75f  Explainers inherit from Transformer, implements readable/writable traits
7b1634e  Performance: Repartition samples
00fd239  JDK 1.8 compatible
c799722  Support multiple explain targets
d15f3be  Return coefficients as array of row vectors rather than matrix
9800d02  Performance: join hint for LHS of the join
a6cea8a  Adding explainers unit tests to CI pipeline
bd74c17  Style fix
d0b0953  Reorganize the tests and add fuzzing test support
f772374  Unit test issues
a54692b  More unit test fixes
1bbacfd  More unit test and style fix
e2cc08d  Style fix
bdcdd47  clean up
155b4eb  unit test OOM issue
4a910ea  oom issue
89f59fa  debugging python test failure
2a0e449  Rename test suites to avoid name conflicts with old classes
9691bd2  Change backgroundData to DataFrameParam and remove unneeded python wr…
032157d  Change to transformer fuzzing
17fac69  Set mini batch size to 1 to prevent OOM in unit test
fccdc21  OOM in unit test
f151952  Excluding SerializationFuzzing for SHAP suites due to error caused by…
c245975  Addressing code review comments
a7504ae  Code review feedback
6b86149  Code review feedback
54e8b1e  code review comments
ba73b50  code review feedbacks
a6ae582  more...
00aa703  more...
75ee15f  sort
9dc3ca8  Rename Spark vector imports
5f85475  use string constants
c4a9c43  Change regression base to support sparse vector as well
a85b879  Clean up printlns
77f0e3e  background dataframe should be mandatory.
392ac1e  Extracting slicer function
3106557  WIP: Rewrite sampler for kernel SHAP
352defd  Rewrite tabular LIME sampler to support non-numerial types
bf10732  Add file header, fixing unit tests.
8bc3ce3  Add header
df7f736  Add unit test to compare shap explainer with kernel explainer from ht…
1364d30  Fixing unit test
Files changed

DatasetExtensions.scala (modified)

```diff
@@ -7,8 +7,6 @@ import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
 import org.apache.spark.sql.Dataset
 import org.apache.spark.sql.types.StructType
 
-import scala.collection.mutable
-
 /** Contains methods for manipulating spark dataframes and datasets. */
 object DatasetExtensions {
 
@@ -20,8 +18,7 @@ object DatasetExtensions {
       * @return The unused column name.
       */
     def withDerivativeCol(prefix: String): String = {
-      val columnNamesSet = mutable.HashSet(df.columns: _*)
-      findUnusedColumnName(prefix)(columnNamesSet)
+      findUnusedColumnName(prefix)(df.columns.toSet)
     }
 
     /** Gets the column values as the given type.
@@ -36,12 +33,12 @@ object DatasetExtensions {
     /** Gets the spark sparse vector column.
       * @return The spark sparse vector column.
       */
-    def getSVCol: String => Seq[SparseVector] = getColAs[SparseVector] _
+    def getSVCol: String => Seq[SparseVector] = getColAs[SparseVector]
 
     /** Gets the spark dense vector column.
       * @return The spark dense vector column.
       */
-    def getDVCol: String => Seq[DenseVector] = getColAs[DenseVector] _
+    def getDVCol: String => Seq[DenseVector] = getColAs[DenseVector]
   }
 
   /** Finds an unused column name given initial column name and a list of existing column names.
@@ -51,13 +48,8 @@ object DatasetExtensions {
    * @return The unused column name.
    */
   def findUnusedColumnName(prefix: String)(columnNames: scala.collection.Set[String]): String = {
-    var counter = 2
-    var unusedColumnName = prefix
-    while (columnNames.contains(unusedColumnName)) {
-      unusedColumnName += "_" + counter
-      counter += 1
-    }
-    unusedColumnName
+    val stream = Iterator(prefix) ++ Iterator.from(1, 1).map(prefix + "_" + _)
+    stream.dropWhile(columnNames.contains).next()
   }
 
   def findUnusedColumnName(prefix: String, schema: StructType): String = {
@@ -67,5 +59,4 @@ object DatasetExtensions {
   def findUnusedColumnName(prefix: String, df: Dataset[_]): String = {
     findUnusedColumnName(prefix, df.schema)
   }
-
 }
```

Review comment on the Iterator-based rewrite: 😎
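For clarity, here is a minimal standalone sketch (plain Scala, not project code) of how the new Iterator-based lookup behaves: candidates are generated lazily as prefix, prefix_1, prefix_2, ... and the first name absent from the existing columns is returned.

```scala
object FindUnusedColumnNameDemo extends App {
  def findUnusedColumnName(prefix: String)(columnNames: Set[String]): String = {
    // Lazily generate prefix, prefix_1, prefix_2, ... and take the first unused name.
    val candidates = Iterator(prefix) ++ Iterator.from(1).map(prefix + "_" + _)
    candidates.dropWhile(columnNames.contains).next()
  }

  println(findUnusedColumnName("temp")(Set("a", "b")))          // temp
  println(findUnusedColumnName("temp")(Set("temp", "temp_1")))  // temp_2
}
```

Unlike the old while loop, which appended suffixes cumulatively (temp, temp_2, temp_2_3, ...), every candidate here is derived from the original prefix.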
RowUtils.scala (modified): the class is now marked deprecated, and merge inlines the body of Spark's deprecated Row.merge instead of delegating to it.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow

@deprecated("This is a copy of Row.merge function from Spark, which was marked deprecated.", "1.0.0-rc3")
class RowUtils {
  def merge(rows: Row*): Row = {
    new GenericRow(rows.flatMap(_.toSeq).toArray)
  }
}
```

Review thread: "This class currently has no usage. Should we just remove it?"
Reply: "yes we can remove"
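As a quick illustration (a hypothetical snippet, not part of the PR), the copied merge logic simply concatenates the field values of the input rows:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow

// Merging Row(1, "a") and Row(2.0) yields a single row holding all three values.
val merged = new GenericRow(Seq(Row(1, "a"), Row(2.0)).flatMap(_.toSeq).toArray)
assert(merged == Row(1, "a", 2.0))
```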
src/main/scala/com/microsoft/ml/spark/core/utils/SlicerFunctions.scala (new file, 62 additions)
```scala
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.core.utils

import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._

private[spark] object SlicerFunctions {
  private def slice[T](values: Int => T, indices: Seq[Int])(implicit num: Numeric[_]): Vector = {
    val n = num.asInstanceOf[Numeric[T]]
    Vectors.dense(indices.map(values.apply).map(n.toDouble).toArray)
  }

  private val DataTypeToNumericMap: Map[NumericType, Numeric[_]] = Map(
    FloatType -> implicitly[Numeric[Float]],
    DoubleType -> implicitly[Numeric[Double]],
    ByteType -> implicitly[Numeric[Byte]],
    ShortType -> implicitly[Numeric[Short]],
    IntegerType -> implicitly[Numeric[Int]],
    LongType -> implicitly[Numeric[Long]]
  )

  /**
    * A UDF that takes a vector and a seq of indices. The function slices the given vector
    * at the given indices and returns the result in a Vector.
    */
  def vectorSlicer: UserDefinedFunction = {
    implicit val num: Numeric[_] = DataTypeToNumericMap(DoubleType)
    UDFUtils.oldUdf(
      (v: Vector, indices: Seq[Int]) => slice(v.apply, indices),
      VectorType
    )
  }

  /**
    * A UDF that takes an array of numeric types and a seq of indices. The function slices
    * the given array at the given indices and returns the result in a Vector.
    */
  def arraySlicer(elementType: NumericType): UserDefinedFunction = {
    implicit val num: Numeric[_] = DataTypeToNumericMap(elementType)
    UDFUtils.oldUdf(
      (v: Seq[Any], indices: Seq[Int]) => slice(v.apply, indices),
      VectorType
    )
  }

  /**
    * A UDF that takes a map with integer keys and numeric values, and a seq of keys.
    * The function slices the given map at the given keys and returns the result in a Vector.
    */
  def mapSlicer(valueType: NumericType): UserDefinedFunction = {
    implicit val num: Numeric[_] = DataTypeToNumericMap(valueType)
    UDFUtils.oldUdf(
      (m: Map[Int, Any], indices: Seq[Int]) => slice(m.apply, indices),
      VectorType
    )
  }
}
```
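A hedged usage sketch (the SparkSession setup and column names are made up, and the private[spark] access modifier is ignored for illustration):

```scala
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array, col, lit}

val spark = SparkSession.builder.master("local[*]").getOrCreate()
import spark.implicits._

// Slice a 4-element vector column down to indices 0 and 2.
val df = Seq(Tuple1(Vectors.dense(1.0, 2.0, 3.0, 4.0))).toDF("features")
df.select(SlicerFunctions.vectorSlicer(col("features"), array(lit(0), lit(2))).as("sliced"))
  .show() // [1.0,3.0]
```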
src/main/scala/com/microsoft/ml/spark/explainers/BreezeUtils.scala (new file, 36 additions)
```scala
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.explainers

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV, DenseMatrix => BDM}
import org.apache.spark.ml.linalg.{Vector, Vectors, Matrix, Matrices}

object BreezeUtils {
  implicit class SparkVectorCanConvertToBreeze(sv: Vector) {
    def toBreeze: BDV[Double] = {
      BDV(sv.toArray)
    }
  }

  implicit class SparkMatrixCanConvertToBreeze(mat: Matrix) {
    def toBreeze: BDM[Double] = {
      BDM(mat.rowIter.map(_.toBreeze).toArray: _*)
    }
  }

  implicit class BreezeVectorCanConvertToSpark(bv: BV[Double]) {
    def toSpark: Vector = {
      bv match {
        case v: BDV[Double] => Vectors.dense(v.toArray)
        case v: BSV[Double] => Vectors.sparse(v.size, v.activeIterator.toSeq).compressed
      }
    }
  }

  implicit class BreezeMatrixCanConvertToSpark(bm: BDM[Double]) {
    def toSpark: Matrix = {
      Matrices.dense(bm.rows, bm.cols, bm.data)
    }
  }
}
```
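A small round-trip sketch (hypothetical values) of what these implicit conversions enable:

```scala
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.ml.linalg.Vectors
import com.microsoft.ml.spark.explainers.BreezeUtils._

val sv = Vectors.dense(1.0, 2.0, 3.0)
val bv: BDV[Double] = sv.toBreeze   // Spark -> Breeze
val doubled = (bv * 2.0).toSpark    // do the math in Breeze, convert back to Spark
// doubled == Vectors.dense(2.0, 4.0, 6.0)
```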
src/main/scala/com/microsoft/ml/spark/explainers/FeatureStats.scala (new file, 64 additions)
```scala
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.explainers

import breeze.numerics.abs
import breeze.stats.distributions.{RandBasis, Uniform}

private[explainers] trait FeatureStats[T] {

  def getRandomState(instance: T)(implicit randBasis: RandBasis): Double

  def sample(state: Double): T

  def getDistance(instance: T, sample: T): Double
}

private[explainers] final case class ContinuousFeatureStats(stddev: Double)
  extends FeatureStats[Double] {
  override def getRandomState(instance: Double)(implicit randBasis: RandBasis): Double = {
    randBasis.gaussian(instance, this.stddev).sample
  }

  override def sample(state: Double): Double = {
    state
  }

  override def getDistance(instance: Double, sample: Double): Double = {
    if (this.stddev == 0d) {
      0d
    } else {
      // Normalize by stddev
      abs(sample - instance) / this.stddev
    }
  }
}

private[explainers] final case class DiscreteFeatureStats[V](freq: Map[V, Double])
  extends FeatureStats[V] {

  /**
    * Returns the cumulative distribution function (CDF) of the given frequency table.
    */
  private def cdf[T](freq: Seq[(T, Double)]): Seq[(T, Double)] = {
    freq.map(_._1) zip freq.map(_._2).scanLeft(0d)(_ + _).drop(1)
  }

  private lazy val cdfTable: Seq[(V, Double)] = {
    val freq = this.freq.toSeq
    cdf(freq)
  }

  override def getRandomState(instance: V)(implicit randBasis: RandBasis): Double = {
    Uniform(0d, freq.values.sum).sample
  }

  override def sample(state: Double): V = {
    cdfTable.find(state <= _._2).get._1
  }

  override def getDistance(instance: V, sample: V): Double = {
    if (instance == sample) 0d else 1d
  }
}
```
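A brief demo (hypothetical frequencies, not project code, and ignoring the private[explainers] modifier) of the sampling logic: for Map("a" -> 2.0, "b" -> 1.0, "c" -> 1.0) the cumulative table is ("a", 2.0), ("b", 3.0), ("c", 4.0), and a uniform draw over [0, 4) maps to the first entry whose cumulative weight covers it, so each value is drawn in proportion to its frequency.

```scala
import breeze.stats.distributions.RandBasis

implicit val rand: RandBasis = RandBasis.withSeed(42)

val discrete = DiscreteFeatureStats(Map("a" -> 2.0, "b" -> 1.0, "c" -> 1.0))
val state = discrete.getRandomState("a")   // uniform draw in [0, 4.0)
val drawn = discrete.sample(state)         // "a" with p = 0.5; "b" and "c" with p = 0.25 each
println(s"state=$state drawn=$drawn distance=${discrete.getDistance("a", drawn)}")

val continuous = ContinuousFeatureStats(stddev = 1.5)
val s = continuous.getRandomState(10.0)    // gaussian draw centered on the instance value
println(s"perturbed=$s distance=${continuous.getDistance(10.0, s)}")
```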
src/main/scala/com/microsoft/ml/spark/explainers/ImageExplainer.scala (new file, 29 additions)
```scala
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.explainers

import com.microsoft.ml.spark.lime.{HasCellSize, HasModifier, SuperpixelTransformer}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.sql.DataFrame

/**
  * Common preprocessing logic for image explainers
  */
trait ImageExplainer {
  self: LocalExplainer
    with HasCellSize
    with HasModifier
    with HasInputCol
    with HasSuperpixelCol =>

  protected override def preprocess(df: DataFrame): DataFrame = {
    // DataFrame with a new column containing superpixels (Array[Cluster]) for each row (image to explain)
    new SuperpixelTransformer()
      .setCellSize(getCellSize)
      .setModifier(getModifier)
      .setInputCol(getInputCol)
      .setOutputCol(getSuperpixelCol)
      .transform(df)
  }
}
```
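The self-type annotation (self: LocalExplainer with HasCellSize with HasModifier with HasInputCol with HasSuperpixelCol =>) restricts this trait to explainers that already carry the cell-size, modifier, input-column, and superpixel-column params, which lets the image LIME and image SHAP explainers introduced in this PR share the superpixel preprocessing step instead of duplicating it.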
Review reply: TY!