
feat: new LIME and KernelSHAP explainers #1077

Merged Jun 18, 2021 (87 commits)

Changes from all commits

Commits
830336e
Adding LassoRegression and LeastSquaresRegression, both supporting fi…
memoryz May 22, 2021
b5d6130
Speeding up lasso regression
memoryz May 23, 2021
a161fb0
Simpify the recurse function
memoryz May 23, 2021
370dad8
Add code comment
memoryz May 23, 2021
246a5a3
Fixing r-squared calculation; add loss to result.
memoryz May 23, 2021
ac3af96
WIP: LIME
memoryz May 26, 2021
1080899
WIP: LIME tabular
memoryz May 27, 2021
c30c5cd
Unit tests: Simple regression model with continuous variables working
memoryz May 27, 2021
df878e3
Unit tests: Categorical variable working
memoryz May 27, 2021
9cc5a7e
Outputting fitting metrics r2
memoryz May 27, 2021
6e3234c
Vector lime implementation
memoryz May 27, 2021
24e4b04
Adding unit tests for vector lime
memoryz May 27, 2021
f666ef8
Clean up and reorganize classes
memoryz May 28, 2021
d270f2b
LIME image sampler
memoryz May 28, 2021
f6d9a5a
WIP: Image LIME
memoryz May 29, 2021
be65511
Image LIME: unit test passing
memoryz May 29, 2021
7ed4de8
Efficiency fix for ImageLIME
memoryz May 30, 2021
2c22da5
Validate input schema
memoryz May 30, 2021
14c47f4
More refactoring
memoryz May 30, 2021
ba64847
WIP: TextLIME
memoryz May 30, 2021
069569f
Unit test for Text LIME
memoryz May 30, 2021
267e3d1
Add support for binary type images
memoryz May 30, 2021
f8c8ab4
Reorganize classes
memoryz May 30, 2021
dd7e5c6
More restructuring
memoryz May 30, 2021
6295659
renames
memoryz May 31, 2021
10216e9
KernelSHAP sampler
memoryz May 31, 2021
7de5756
WIP: TabularSHAP
memoryz May 31, 2021
cf861f7
Unit test for KernelSHAP tabular
memoryz Jun 1, 2021
c3e0a7c
KernelSHAP Vector Sampler
memoryz Jun 1, 2021
06a3a5c
WIP: Vector SHAP
memoryz Jun 1, 2021
4e6d0a2
Unit test for vector kernel shap
memoryz Jun 1, 2021
1aaed01
Simplify sampler class structure
memoryz Jun 2, 2021
c34a9ec
Restructure the samplers
memoryz Jun 2, 2021
ce5c101
Unit tests for image and text samplers
memoryz Jun 2, 2021
1091482
ImageSHAP and TextSHAP
memoryz Jun 2, 2021
8d8efcd
ImageLIME unit test
memoryz Jun 2, 2021
39fef38
Unit test for TextSHAP
memoryz Jun 2, 2021
7cda288
Logging
memoryz Jun 2, 2021
6ca7a50
pyspark layer
memoryz Jun 4, 2021
12a89e7
change LocalExplainer to internal class
memoryz Jun 4, 2021
c9ad1ad
Bug fix
memoryz Jun 4, 2021
fcd0b61
Bug fix
memoryz Jun 4, 2021
a972dba
Bug fix for tabular LIME categorical features
memoryz Jun 4, 2021
6dd2d77
Update copyright
memoryz Jun 4, 2021
caff75f
Explainers inherit from Transformer, implements readable/writable traits
memoryz Jun 6, 2021
7b1634e
Performance: Repartition samples
memoryz Jun 7, 2021
00fd239
JDK 1.8 compatible
memoryz Jun 7, 2021
c799722
Support multiple explain targets
memoryz Jun 7, 2021
d15f3be
Return coefficients as array of row vectors rather than matrix
memoryz Jun 8, 2021
9800d02
Performance: join hint for LHS of the join
memoryz Jun 9, 2021
a6cea8a
Adding explainers unit tests to CI pipeline
memoryz Jun 9, 2021
bd74c17
Style fix
memoryz Jun 9, 2021
d0b0953
Reorganize the tests and add fuzzing test support
memoryz Jun 9, 2021
f772374
Unit test issues
memoryz Jun 10, 2021
a54692b
More unit test fixes
memoryz Jun 10, 2021
1bbacfd
More unit test and style fix
memoryz Jun 10, 2021
e2cc08d
Style fix
memoryz Jun 10, 2021
bdcdd47
clean up
memoryz Jun 10, 2021
155b4eb
unit test OOM issue
memoryz Jun 10, 2021
4a910ea
oom issue
memoryz Jun 10, 2021
89f59fa
debugging python test failure
memoryz Jun 10, 2021
2a0e449
Rename test suites to avoid name conflicts with old classes
memoryz Jun 10, 2021
9691bd2
Change backgroundData to DataFrameParam and remove unneeded python wr…
memoryz Jun 10, 2021
032157d
Change to transformer fuzzing
memoryz Jun 10, 2021
17fac69
Set mini batch size to 1 to prevent OOM in unit test
memoryz Jun 10, 2021
fccdc21
OOM in unit test
memoryz Jun 11, 2021
f151952
Excluding SerializationFuzzing for SHAP suites due to error caused by…
memoryz Jun 11, 2021
c245975
Addressing code review comments
memoryz Jun 11, 2021
a7504ae
Code review feedback
memoryz Jun 11, 2021
6b86149
Code review feedback
memoryz Jun 12, 2021
54e8b1e
code review comments
memoryz Jun 12, 2021
ba73b50
code review feedbacks
memoryz Jun 12, 2021
a6ae582
more...
memoryz Jun 12, 2021
00aa703
more...
memoryz Jun 12, 2021
75ee15f
sort
memoryz Jun 12, 2021
9dc3ca8
Rename Spark vector imports
memoryz Jun 14, 2021
5f85475
use string constants
memoryz Jun 15, 2021
c4a9c43
Change regression base to support sparse vector as well
memoryz Jun 15, 2021
a85b879
Clean up printlns
memoryz Jun 15, 2021
77f0e3e
background dataframe should be mandatory.
memoryz Jun 15, 2021
392ac1e
Extracting slicer function
memoryz Jun 15, 2021
3106557
WIP: Rewrite sampler for kernel SHAP
memoryz Jun 16, 2021
352defd
Rewrite tabular LIME sampler to support non-numerial types
memoryz Jun 17, 2021
bf10732
Add file header, fixing unit tests.
memoryz Jun 17, 2021
8bc3ce3
Add header
memoryz Jun 17, 2021
df7f736
Add unit test to compare shap explainer with kernel explainer from ht…
memoryz Jun 17, 2021
1364d30
Fixing unit test
memoryz Jun 17, 2021
2 changes: 2 additions & 0 deletions pipeline.yaml
@@ -347,6 +347,8 @@ jobs:
PACKAGE: "core"
downloader:
PACKAGE: "downloader"
explainers:
PACKAGE: "explainers"
featurize:
PACKAGE: "featurize"
image:
10 changes: 7 additions & 3 deletions src/main/scala/com/microsoft/ml/spark/codegen/Wrappable.scala
@@ -6,12 +6,12 @@ package com.microsoft.ml.spark.codegen
import java.lang.reflect.ParameterizedType
import java.nio.charset.StandardCharsets
import java.nio.file.Files

import com.microsoft.ml.spark.core.env.FileUtilities
import com.microsoft.ml.spark.core.serialize.ComplexParam
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.commons.lang.StringEscapeUtils

import scala.collection.Iterator.iterate
import scala.reflect.ClassTag
@@ -119,7 +119,7 @@ trait PythonWrappable extends BaseWrappable {
}
}

protected lazy val pyInheritedClasses =
protected lazy val pyInheritedClasses: Seq[String] =
Seq("ComplexParamsMixin", "JavaMLReadable", "JavaMLWritable", pyObjectBaseClass)

// TODO add default values
@@ -134,10 +134,14 @@ trait PythonWrappable extends BaseWrappable {
|""".stripMargin
}

private def escape(raw: String): String = {
StringEscapeUtils.escapeJava(raw)
}

protected lazy val pyParamsDefinitions: String = {
this.params.map { p =>
val typeConverterString = getParamInfo(p).pyTypeConverter.map(", typeConverter=" + _).getOrElse("")
s"""|${p.name} = Param(Params._dummy(), "${p.name}", "${p.doc}"$typeConverterString)
s"""|${p.name} = Param(Params._dummy(), "${p.name}", "${escape(p.doc)}"$typeConverterString)
Collaborator comment: TY!
|""".stripMargin
}.mkString("\n")
}
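The `escape` helper above exists because a param's `doc` string may contain quotes, backslashes, or newlines that would otherwise break the generated Python `Param` definition. A minimal Python sketch of the same idea (the `escape_java` helper here is hypothetical, standing in for `StringEscapeUtils.escapeJava`):

```python
def escape_java(raw: str) -> str:
    # Escape characters that would terminate or corrupt a double-quoted
    # string literal in generated source code.
    replacements = {"\\": "\\\\", '"': '\\"', "\n": "\\n", "\r": "\\r", "\t": "\\t"}
    return "".join(replacements.get(ch, ch) for ch in raw)

# A doc string with embedded quotes and a newline:
doc = 'Threshold for the "positive" label.\nDefault: 0.5'
definition = f'threshold = Param(Params._dummy(), "threshold", "{escape_java(doc)}")'
# The generated definition is now a single, valid source line.
```

Without the escaping step, the embedded quote would close the string literal early in the generated code.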
@@ -7,8 +7,6 @@ import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType

import scala.collection.mutable

/** Contains methods for manipulating spark dataframes and datasets. */
object DatasetExtensions {

Expand All @@ -20,8 +18,7 @@ object DatasetExtensions {
* @return The unused column name.
*/
def withDerivativeCol(prefix: String): String = {
val columnNamesSet = mutable.HashSet(df.columns: _*)
findUnusedColumnName(prefix)(columnNamesSet)
findUnusedColumnName(prefix)(df.columns.toSet)
}

/** Gets the column values as the given type.
Expand All @@ -36,12 +33,12 @@ object DatasetExtensions {
/** Gets the spark sparse vector column.
* @return The spark sparse vector column.
*/
def getSVCol: String => Seq[SparseVector] = getColAs[SparseVector] _
def getSVCol: String => Seq[SparseVector] = getColAs[SparseVector]

/** Gets the spark dense vector column.
* @return The spark dense vector column.
*/
def getDVCol: String => Seq[DenseVector] = getColAs[DenseVector] _
def getDVCol: String => Seq[DenseVector] = getColAs[DenseVector]
}

/** Finds an unused column name given initial column name and a list of existing column names.
Expand All @@ -51,13 +48,8 @@ object DatasetExtensions {
* @return The unused column name.
*/
def findUnusedColumnName(prefix: String)(columnNames: scala.collection.Set[String]): String = {
var counter = 2
var unusedColumnName = prefix
while (columnNames.contains(unusedColumnName)) {
unusedColumnName += "_" + counter
counter += 1
}
unusedColumnName
val stream = Iterator(prefix) ++ Iterator.from(1, 1).map(prefix + "_" + _)
Collaborator comment: 😎
stream.dropWhile(columnNames.contains).next()
}

def findUnusedColumnName(prefix: String, schema: StructType): String = {
Expand All @@ -67,5 +59,4 @@ object DatasetExtensions {
def findUnusedColumnName(prefix: String, df: Dataset[_]): String = {
findUnusedColumnName(prefix, df.schema)
}

}
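The rewritten `findUnusedColumnName` replaces the mutable counter loop with a lazy stream of candidates. The same idea in Python (a sketch, not the library's API):

```python
from itertools import chain, count

def find_unused_column_name(prefix, column_names):
    # Lazily generate prefix, prefix_1, prefix_2, ... and take the first
    # candidate that is not already used as a column name.
    candidates = chain([prefix], (f"{prefix}_{i}" for i in count(1)))
    return next(c for c in candidates if c not in column_names)
```

Note that the old loop mutated `unusedColumnName` in place (`unusedColumnName += "_" + counter`), so repeated collisions would accumulate suffixes like `prefix_2_3`; the stream version produces a clean `prefix_N` candidate each time.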
@@ -22,7 +22,7 @@ object ModelEquality {
}

def assertEqual(m1: PipelineStage, m2: PipelineStage): Unit = {
assert(m1.getClass === m2.getClass)
assert(m1.getClass === m2.getClass, s"${m1.getClass} != ${m2.getClass}, assertion failed.")
val m1Params = m1.extractParamMap().toSeq.map(pp => pp.param.name).toSet
val m2Params = m2.extractParamMap().toSeq.map(pp => pp.param.name).toSet
assert(m1Params === m2Params)
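The added message makes a class mismatch self-describing instead of a bare assertion failure. The same pattern could be applied to the param-name comparison below it; a Python sketch (helper name hypothetical):

```python
def assert_same_params(left, right):
    # Compare two param-name sets and report exactly which names differ,
    # rather than failing with an opaque "assertion failed".
    left, right = set(left), set(right)
    assert left == right, (
        f"param mismatch: only-left={sorted(left - right)}, "
        f"only-right={sorted(right - left)}"
    )
```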
@@ -6,10 +6,13 @@ package com.microsoft.ml.spark.core.utils
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRow

// This class currently has no usage. Should we just remove it?
@deprecated("This is a copy of Row.merge function from Spark, which was marked deprecated.", "1.0.0-rc3")
Collaborator comment: yes we can remove
class RowUtils {

//TODO Deprecate later
def merge(rows: Row*): Row = {
Row.merge()
new GenericRow(rows.flatMap(_.toSeq).toArray)
}
}
@@ -0,0 +1,62 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.core.utils

import org.apache.spark.injections.UDFUtils
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.types._

private[spark] object SlicerFunctions {
private def slice[T](values: Int => T, indices: Seq[Int])(implicit num: Numeric[_]): Vector = {
val n = num.asInstanceOf[Numeric[T]]
Vectors.dense(indices.map(values.apply).map(n.toDouble).toArray)
}

private val DataTypeToNumericMap: Map[NumericType, Numeric[_]] = Map(
FloatType -> implicitly[Numeric[Float]],
DoubleType -> implicitly[Numeric[Double]],
ByteType -> implicitly[Numeric[Byte]],
ShortType -> implicitly[Numeric[Short]],
IntegerType -> implicitly[Numeric[Int]],
LongType -> implicitly[Numeric[Long]]
)

/**
* A UDF that takes a vector, and a seq of indices. The function slices the given vector at given indices,
* and returns the result in a Vector.
*/
def vectorSlicer: UserDefinedFunction = {
implicit val num: Numeric[_] = DataTypeToNumericMap(DoubleType)
UDFUtils.oldUdf(
(v: Vector, indices: Seq[Int]) => slice(v.apply, indices),
VectorType
)
}

/**
* A UDF that takes an array of numeric types, and a seq of indices.
* The function slices the given array at given indices, and returns the result in a Vector.
*/
def arraySlicer(elementType: NumericType): UserDefinedFunction = {
implicit val num: Numeric[_] = DataTypeToNumericMap(elementType)
UDFUtils.oldUdf(
(v: Seq[Any], indices: Seq[Int]) => slice(v.apply, indices),
VectorType
)
}

/**
* A UDF that takes a map of integer keys and numeric values, and a seq of keys.
* The function slices the given map at the given keys, and returns the result in a Vector.
*/
def mapSlicer(valueType: NumericType): UserDefinedFunction = {
implicit val num: Numeric[_] = DataTypeToNumericMap(valueType)
UDFUtils.oldUdf(
(m: Map[Int, Any], indices: Seq[Int]) => slice(m.apply, indices),
VectorType
)
}
}
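All three UDFs above reduce to the same operation: look up each requested index in an indexable container and pack the results into a dense vector. A plain-Python sketch of that shared `slice` logic, with no Spark types involved:

```python
def slice_values(values, indices):
    # `values` can be anything indexable by int: a list (array column),
    # a dict with int keys (map column), or a vector lookup.
    # Results are converted to float, mirroring Numeric.toDouble,
    # and returned in the order the indices were requested.
    return [float(values[i]) for i in indices]
```

Missing indices raise, just as slicing past the end of a Spark array or looking up an absent map key would fail in the UDF.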
@@ -0,0 +1,36 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.explainers

import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV, DenseMatrix => BDM}
import org.apache.spark.ml.linalg.{Vector, Vectors, Matrix, Matrices}

object BreezeUtils {
implicit class SparkVectorCanConvertToBreeze(sv: Vector) {
def toBreeze: BDV[Double] = {
BDV(sv.toArray)
}
}

implicit class SparkMatrixCanConvertToBreeze(mat: Matrix) {
def toBreeze: BDM[Double] = {
BDM(mat.rowIter.map(_.toBreeze).toArray: _*)
}
}

implicit class BreezeVectorCanConvertToSpark(bv: BV[Double]) {
def toSpark: Vector = {
bv match {
case v: BDV[Double] => Vectors.dense(v.toArray)
case v: BSV[Double] => Vectors.sparse(v.size, v.activeIterator.toSeq).compressed
}
}
}

implicit class BreezeMatrixCanConvertToSpark(bm: BDM[Double]) {
def toSpark: Matrix = {
Matrices.dense(bm.rows, bm.cols, bm.data)
}
}
}
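The `.compressed` call in `BreezeVectorCanConvertToSpark` lets Spark keep whichever of the dense or sparse encodings is smaller. A stdlib-only Python sketch of the two encodings and the round trip between them (no Breeze or Spark involved; function names are illustrative):

```python
def to_sparse(dense):
    # Keep only the non-zero entries, as (size, [(index, value), ...]),
    # analogous to Vectors.sparse(size, activeEntries).
    return len(dense), [(i, v) for i, v in enumerate(dense) if v != 0.0]

def to_dense(size, active):
    # Expand the active entries back into a full vector of zeros.
    out = [0.0] * size
    for i, v in active:
        out[i] = v
    return out
```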
@@ -0,0 +1,64 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.explainers

import breeze.numerics.abs
import breeze.stats.distributions.{RandBasis, Uniform}

private[explainers] trait FeatureStats[T] {

def getRandomState(instance: T)(implicit randBasis: RandBasis): Double

def sample(state: Double): T

def getDistance(instance: T, sample: T): Double
}

private[explainers] final case class ContinuousFeatureStats(stddev: Double)
extends FeatureStats[Double] {
override def getRandomState(instance: Double)(implicit randBasis: RandBasis): Double = {
randBasis.gaussian(instance, this.stddev).sample
}

override def sample(state: Double): Double = {
state
}

override def getDistance(instance: Double, sample: Double): Double = {
if (this.stddev == 0d) {
0d
} else {
// Normalize by stddev
abs(sample - instance) / this.stddev
}
}
}

private[explainers] final case class DiscreteFeatureStats[V](freq: Map[V, Double])
extends FeatureStats[V] {

/**
* Returns the cumulative density function (CDF) of the given frequency table.
*/
private def cdf[T](freq: Seq[(T, Double)]): Seq[(T, Double)] = {
freq.map(_._1) zip freq.map(_._2).scanLeft(0d)(_ + _).drop(1)
}

private lazy val cdfTable: Seq[(V, Double)] = {
val freq = this.freq.toSeq
cdf(freq)
}

override def getRandomState(instance: V)(implicit randBasis: RandBasis): Double = {
Uniform(0d, freq.values.sum).sample
}

override def sample(state: Double): V = {
cdfTable.find(state <= _._2).get._1
}

override def getDistance(instance: V, sample: V): Double = {
if (instance == sample) 0d else 1d
}
}
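`DiscreteFeatureStats` samples a categorical value by drawing a uniform state in `[0, total_weight)` and walking the cumulative table to the first entry whose cumulative weight covers the state. A Python sketch of the same inverse-CDF sampling:

```python
import random
from itertools import accumulate

def build_cdf(freq):
    # freq: list of (value, weight) pairs.
    # Returns [(value, cumulative_weight), ...], like cdfTable above.
    values = [v for v, _ in freq]
    cums = list(accumulate(w for _, w in freq))
    return list(zip(values, cums))

def sample_discrete(freq, rng=random.random):
    # Draw a uniform state, then find the first entry whose
    # cumulative weight reaches it (cdfTable.find(state <= _._2)).
    cdf = build_cdf(freq)
    total = cdf[-1][1]
    state = rng() * total
    return next(v for v, c in cdf if state <= c)
```

Values with larger weights occupy wider intervals of the cumulative range, so they are drawn proportionally more often; the 0/1 `getDistance` simply records whether the sampled category changed.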
@@ -0,0 +1,29 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.ml.spark.explainers

import com.microsoft.ml.spark.lime.{HasCellSize, HasModifier, SuperpixelTransformer}
import org.apache.spark.ml.param.shared.HasInputCol
import org.apache.spark.sql.DataFrame

/**
* Common preprocessing logic for image explainers
*/
trait ImageExplainer {
self: LocalExplainer
with HasCellSize
with HasModifier
with HasInputCol
with HasSuperpixelCol =>

protected override def preprocess(df: DataFrame): DataFrame = {
// Dataframe with new column containing superpixels (Array[Cluster]) for each row (image to explain)
new SuperpixelTransformer()
.setCellSize(getCellSize)
.setModifier(getModifier)
.setInputCol(getInputCol)
.setOutputCol(getSuperpixelCol)
.transform(df)
}
}