Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: [DistributionBalanceMeasure] Add implementation + unit tests for custom reference distribution #1885

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@ import breeze.stats.distributions.ChiSquared
import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.param.ArrayMapParam
import org.apache.spark.ml.param._
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

import java.util
import scala.collection.JavaConverters._
import scala.language.postfixOps

/** This transformer computes data balance measures based on a reference distribution.
Expand Down Expand Up @@ -56,6 +59,27 @@ class DistributionBalanceMeasure(override val uid: String)

def setFeatureNameCol(value: String): this.type = set(featureNameCol, value)

val referenceDistribution = new ArrayMapParam(
this,
"referenceDistribution",
"An ordered list of reference distributions that correspond to each of the sensitive columns."
)

val emptyReferenceDistribution: Array[Map[String, Double]] = Array.empty

def getReferenceDistribution: Array[Map[String, Double]] =
if (isDefined(referenceDistribution))
$(referenceDistribution).map(_.mapValues(_.asInstanceOf[Double]).map(identity))
else emptyReferenceDistribution

def setReferenceDistribution(value: Array[Map[String, Double]]): this.type =
set(referenceDistribution, value.map(_.mapValues(_.asInstanceOf[Any])))

def setReferenceDistribution(value: util.ArrayList[util.HashMap[String, Double]]): this.type = {
val arrayMap = value.asScala.toArray.map(_.asScala.toMap.mapValues(_.asInstanceOf[Any]))
set(referenceDistribution, arrayMap)
}

setDefault(
featureNameCol -> "FeatureName",
outputCol -> "DistributionBalanceMeasure"
Expand All @@ -68,6 +92,15 @@ class DistributionBalanceMeasure(override val uid: String)
}
}

private val customDistribution: Map[String, Double] => String => Double = {
dist: Map[String, Double] => {
// NOTE: If the custom distribution doesn't have the col value, return a default probability of 0
// This assumes that the reference distribution does not contain the col value at all
s: String =>
dist.getOrElse(s, 0d)
}
}

override def transform(dataset: Dataset[_]): DataFrame = {
logTransform[DataFrame]({
validateSchema(dataset.schema)
Expand All @@ -89,30 +122,30 @@ class DistributionBalanceMeasure(override val uid: String)
if (getVerbose)
featureStats.cache.show(numRows = 20, truncate = false) //scalastyle:ignore magic.number

// TODO (for v2): Introduce a referenceDistribution function param for user to override the uniform distribution
val referenceDistribution = uniformDistribution

df.unpersist
calculateDistributionMeasures(featureStats, featureProbCol, featureCountCol, numRows, referenceDistribution)
calculateDistributionMeasures(featureStats, featureProbCol, featureCountCol, numRows)
})
}

private def calculateDistributionMeasures(featureStats: DataFrame,
obsFeatureProbCol: String,
obsFeatureCountCol: String,
numRows: Double,
referenceDistribution: Int => String => Double): DataFrame = {
val distributionMeasures = getSensitiveCols.map {
sensitiveCol =>
numRows: Double): DataFrame = {
val distributionMeasures = getSensitiveCols.zipWithIndex.map {
case (sensitiveCol, i) =>
val observed = featureStats
.groupBy(sensitiveCol)
.agg(sum(obsFeatureProbCol).alias(obsFeatureProbCol), sum(obsFeatureCountCol).alias(obsFeatureCountCol))

val numFeatures = observed.count.toInt
val refDistFunc = udf(referenceDistribution(numFeatures))
val refFeatureProbCol = DatasetExtensions.findUnusedColumnName("refFeatureProb", featureStats.schema)
val refFeatureCountCol = DatasetExtensions.findUnusedColumnName("refFeatureCount", featureStats.schema)

val refDist: String => Double =
if (!isDefined(referenceDistribution) || getReferenceDistribution(i).isEmpty) uniformDistribution(numFeatures)
else customDistribution(getReferenceDistribution(i))
val refDistFunc = udf(refDist)

val observedWithRef = observed
.withColumn(refFeatureProbCol, refDistFunc(col(sensitiveCol)))
.withColumn(refFeatureCountCol, refDistFunc(col(sensitiveCol)) * lit(numRows))
Expand Down Expand Up @@ -146,6 +179,15 @@ class DistributionBalanceMeasure(override val uid: String)
Nil
)
}

override def validateSchema(schema: StructType): Unit = {
super.validateSchema(schema)

if (isDefined(referenceDistribution) && getReferenceDistribution.length != getSensitiveCols.length) {
throw new Exception("The reference distribution must have the same length and order as the sensitive columns: "
+ getSensitiveCols.mkString(", "))
}
}
}

object DistributionBalanceMeasure extends ComplexParamsReadable[DistributionBalanceMeasure]
Expand Down Expand Up @@ -212,23 +254,32 @@ private[exploratory] case class DistributionMetrics(numFeatures: Int,
}

// Calculates Pearson's chi-squared statistic
def chiSquaredTestStatistic: Column =
sum(pow(col(obsFeatureCountCol) - col(refFeatureCountCol), 2) / col(refFeatureCountCol))
def chiSquaredTestStatistic: Column = sum(
// If expected is zero and observed is not zero, the test assumes observed is impossible so Chi^2 value becomes +inf
when(col(refFeatureCountCol) === 0 && col(obsFeatureCountCol) =!= 0, lit(Double.PositiveInfinity))
.otherwise(pow(col(obsFeatureCountCol) - col(refFeatureCountCol), 2) / col(refFeatureCountCol)))

// Calculates left-tailed p-value from degrees of freedom and chi-squared test statistic
def chiSquaredPValue: Column = {
val degOfFreedom = numFeatures - 1
val scoreCol = chiSquaredTestStatistic
val chiSqPValueUdf = udf({
score: Double =>
1d - ChiSquared(degOfFreedom).cdf(score)
})
val chiSqPValueUdf = udf(
(score: Double) => score match {
// limit of CDF as x approaches +inf is 1 (https://en.wikipedia.org/wiki/Cumulative_distribution_function)
case Double.PositiveInfinity => 1d
case _ => 1 - ChiSquared(degOfFreedom).cdf(score)
}
)
chiSqPValueUdf(scoreCol)
}

private def entropy(distA: Column, distB: Option[Column] = None): Column = {
if (distB.isDefined) {
sum(distA * log(distA / distB.get))
// Using same cases as scipy (https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html)
val entropies = when(distA === 0d && distB.get >= 0d, lit(0d))
.when(distA > 0d && distB.get > 0d, distA * log(distA / distB.get))
.otherwise(lit(Double.PositiveInfinity))
sum(entropies)
} else {
sum(distA * log(distA)) * -1d
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,11 @@ case class AggregateMetricsCalculator(featureProbabilities: Array[Double], epsil
}
}

case class DistributionMetricsCalculator(obsFeatureProbabilities: Array[Double],
case class DistributionMetricsCalculator(refFeatureProbabilities: Array[Double],
refFeatureCounts: Array[Double],
obsFeatureProbabilities: Array[Double],
obsFeatureCounts: Array[Double],
numRows: Double) {
val numFeatures: Double = obsFeatureProbabilities.length
val refFeatureProbabilities: Array[Double] = Array.fill(numFeatures.toInt)(1d / numFeatures)

numFeatures: Double) {
val absDiffObsRef: Array[Double] = (obsFeatureProbabilities, refFeatureProbabilities).zipped.map((a, b) => abs(a - b))

val klDivergence: Double = entropy(obsFeatureProbabilities, Some(refFeatureProbabilities))
Expand All @@ -126,16 +125,22 @@ case class DistributionMetricsCalculator(obsFeatureProbabilities: Array[Double],
val infNormDistance: Double = absDiffObsRef.max
val totalVariationDistance: Double = 0.5d * absDiffObsRef.sum
val wassersteinDistance: Double = absDiffObsRef.sum / absDiffObsRef.length
val chiSquaredTestStatistic: Double = {
val refFeatureCount = numRows / numFeatures
obsFeatureCounts.map(o => pow(o - refFeatureCount, 2) / refFeatureCount).sum
val chiSquaredTestStatistic: Double = (obsFeatureCounts, refFeatureCounts).zipped.map((a, b) => pow(a - b, 2) / b).sum
val chiSquaredPValue: Double = chiSquaredTestStatistic match {
// limit of CDF as x approaches +inf is 1 (https://en.wikipedia.org/wiki/Cumulative_distribution_function)
case Double.PositiveInfinity => 1
case _ => 1 - ChiSquared(numFeatures - 1).cdf(chiSquaredTestStatistic)
}
val chiSquaredPValue: Double = 1 - ChiSquared(numFeatures - 1).cdf(chiSquaredTestStatistic)

def entropy(distA: Array[Double], distB: Option[Array[Double]] = None): Double = {
if (distB.isDefined) {
val logQuotient = (distA, distB.get).zipped.map((a, b) => log(a / b))
(distA, logQuotient).zipped.map(_ * _).sum
(distA, distB.get).zipped.map((a, b) =>
// Using cases from scipy.special.rel_entr, which scipy.stats.entropy directly calls
// https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.rel_entr.html
if (a == 0.0 && b >= 0.0) 0.0
else if (a > 0.0 && b > 0) a * log(a / b)
else Double.PositiveInfinity
).sum
} else {
-1d * distA.map(x => x * log(x)).sum
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,10 @@ class DistributionBalanceMeasureSuite extends DataBalanceTestBase with Transform

private object ExpectedFeature1 {
// Values were computed using:
// val CALC =
// DistributionMetricsCalculator(expectedFeature1.map(_._1), expectedFeature1.map(_._2), sensitiveFeaturesDf.count)
// val (numRows, numFeatures) = (sensitiveFeaturesDf.count.toDouble, expectedFeature1.length)
// val (obsProbs, obsCounts) = expectedFeature1.unzip
// val (refProbs, refCounts) = Array.fill(numFeatures.toInt)(numFeatures).map(n => (1d / n, numRows / n)).unzip
// val CALC = DistributionMetricsCalculator(refProbs, refCounts, obsProbs, obsCounts, numFeatures)
val KLDIVERGENCE = 0.03775534151008829
val JSDISTANCE = 0.09785224086736323
val INFNORMDISTANCE = 0.1111111111111111
Expand Down Expand Up @@ -83,14 +85,16 @@ class DistributionBalanceMeasureSuite extends DataBalanceTestBase with Transform

private object ExpectedFeature2 {
// Values were computed using:
// val CALC =
// DistributionMetricsCalculator(expectedFeature2.map(_._1), expectedFeature2.map(_._2), sensitiveFeaturesDf.count)
// val (numRows, numFeatures) = (sensitiveFeaturesDf.count.toDouble, expectedFeature2.length)
// val (obsProbs, obsCounts) = expectedFeature2.unzip
// val (refProbs, refCounts) = Array.fill(numFeatures.toInt)(numFeatures).map(n => (1d / n, numRows / n)).unzip
// val CALC = DistributionMetricsCalculator(refProbs, refCounts, obsProbs, obsCounts, numFeatures)
val KLDIVERGENCE = 0.07551068302017659
val JSDISTANCE = 0.14172745151398888
val INFNORMDISTANCE = 0.1388888888888889
val TOTALVARIATIONDISTANCE = 0.16666666666666666
val WASSERSTEINDISTANCE = 0.08333333333333333
val CHISQUAREDTESTSTATISTIC = 1.222222222222222
val CHISQUAREDTESTSTATISTIC = 1.2222222222222223
val CHISQUAREDPVALUE = 0.7476795872877147
}

Expand All @@ -105,4 +109,139 @@ class DistributionBalanceMeasureSuite extends DataBalanceTestBase with Transform
assert(actual(CHISQUAREDTESTSTATISTIC) === expected.CHISQUAREDTESTSTATISTIC)
assert(actual(CHISQUAREDPVALUE) === expected.CHISQUAREDPVALUE)
}

// For each feature in sensitiveFeaturesDf (["Gender", "Ethnicity"]), need to specify its corresponding distribution
private def customDistribution: Array[Map[String, Double]] = Array(
// Index 0: Gender (all unique values included)
Map("Male" -> 0.25, "Female" -> 0.4, "Other" -> 0.35),
// Index 1: Ethnicity ('Other' value purposefully left out, which signals a probability of 0.0)
Map("Asian" -> 1/3d, "White" -> 1/3d, "Black" -> 1/3d)
)

test("DistributionBalanceMeasure can use a custom reference distribution for multiple cols") {
val df = distributionBalanceMeasure
.setReferenceDistribution(customDistribution)
.transform(sensitiveFeaturesDf)

df.show(truncate = false)
df.printSchema()
}

test("DistributionBalanceMeasure can use a custom distribution for one col and uniform for another") {
val customDist = customDistribution
// Keep custom distribution for Gender (index 0), and use uniform distribution for Ethnicity (index 1)
// Specifying empty map defaults to the uniform distribution
customDist.update(1, Map())

val df = distributionBalanceMeasure
.setReferenceDistribution(customDist)
.transform(sensitiveFeaturesDf)

df.show(truncate = false)
df.printSchema()
}

test("DistributionBalanceMeasure expects the custom distribution to be the same length as sensitive columns") {
val emptyDist: Array[Map[String, Double]] = Array.empty
assertThrows[Exception] {
distributionBalanceMeasure
.setReferenceDistribution(emptyDist)
.transform(sensitiveFeaturesDf)
}

val mismatchedLenDist = Array(Map("ColA" -> 0.25))
assertThrows[Exception] {
distributionBalanceMeasure
.setReferenceDistribution(mismatchedLenDist)
.transform(sensitiveFeaturesDf)
}
}

private def actualCustomDist: DataFrame =
new DistributionBalanceMeasure()
.setSensitiveCols(features)
.setVerbose(true)
.setReferenceDistribution(customDistribution)
.transform(sensitiveFeaturesDf)

private def actualCustomDistFeature1: Map[String, Double] =
METRICS zip actualCustomDist.filter(col("FeatureName") === feature1)
.select(array(col("DistributionBalanceMeasure.*")))
.as[Array[Double]]
.head toMap

private def expectedCustomDistFeature1 = getFeatureStats(sensitiveFeaturesDf.groupBy(feature1))
.select(feature1, featureProbCol, featureCountCol)
.as[(String, Double, Double)].collect()

private object ExpectedCustomDistFeature1 {
// Values were computed using:
// val (numRows, numFeatures) = (sensitiveFeaturesDf.count.toDouble, expectedCustomDistFeature1.length)
// val (featureValues, obsProbs, obsCounts) = expectedCustomDistFeature1.unzip3
// val refProbs = featureValues.map(customDistribution.get(0).getOrDefault(_, 0.0)) // idx 0 = Gender
// val refCounts = refProbs.map(_ * numRows)
// val CALC = DistributionMetricsCalculator(refProbs, refCounts, obsProbs, obsCounts, numFeatures)
val KLDIVERGENCE = 0.09399792940857671
val JSDISTANCE = 0.15001917759832653
val INFNORMDISTANCE = 0.19444444444444442
val TOTALVARIATIONDISTANCE = 0.19444444444444445
val WASSERSTEINDISTANCE = 0.12962962962962962
val CHISQUAREDTESTSTATISTIC = 1.880952380952381
val CHISQUAREDPVALUE = 0.3904418663854293
}

test(s"DistributionBalanceMeasure can use a custom reference distribution with all values ($feature1)") {
// The custom reference distribution for Gender is Map("Male" -> 0.25, "Female" -> 0.4, "Other" -> 0.35)
// This includes all unique values of Gender in the dataframe being transformed
val actual = actualCustomDistFeature1
val expected = ExpectedCustomDistFeature1
assert(actual(KLDIVERGENCE) === expected.KLDIVERGENCE)
assert(actual(JSDISTANCE) === expected.JSDISTANCE)
assert(actual(INFNORMDISTANCE) === expected.INFNORMDISTANCE)
assert(actual(TOTALVARIATIONDISTANCE) === expected.TOTALVARIATIONDISTANCE)
assert(actual(WASSERSTEINDISTANCE) === expected.WASSERSTEINDISTANCE)
assert(actual(CHISQUAREDTESTSTATISTIC) === expected.CHISQUAREDTESTSTATISTIC)
assert(actual(CHISQUAREDPVALUE) === expected.CHISQUAREDPVALUE)
}

private def actualCustomDistFeature2: Map[String, Double] =
METRICS zip actualCustomDist.filter(col("FeatureName") === feature2)
.select(array(col("DistributionBalanceMeasure.*")))
.as[Array[Double]]
.head toMap

private def expectedCustomDistFeature2 = getFeatureStats(sensitiveFeaturesDf.groupBy(feature2))
.select(feature2, featureProbCol, featureCountCol)
.as[(String, Double, Double)].collect()

private object ExpectedCustomDistFeature2 {
// Values were computed using:
// val (numRows, numFeatures) = (sensitiveFeaturesDf.count.toDouble, expectedCustomDistFeature2.length)
// val (featureValues, obsProbs, obsCounts) = expectedCustomDistFeature2.unzip3
// val refProbs = featureValues.map(customDistribution.get(1).getOrDefault(_, 0.0)) // idx 1 = Ethnicity
// val refCounts = refProbs.map(_ * numRows)
// val CALC = DistributionMetricsCalculator(refProbs, refCounts, obsProbs, obsCounts, numFeatures)
val KLDIVERGENCE = Double.PositiveInfinity
val JSDISTANCE = 0.2100032735609124
val INFNORMDISTANCE = 0.1111111111111111
val TOTALVARIATIONDISTANCE = 0.1111111111111111
val WASSERSTEINDISTANCE = 0.05555555555555555
val CHISQUAREDTESTSTATISTIC = Double.PositiveInfinity
val CHISQUAREDPVALUE = 1d
}

test(s"DistributionBalanceMeasure can a custom reference distribution with missing values ($feature2)") {
// The custom reference distribution for Ethnicity is Map("Asian" -> 0.33, "White" -> 0.33, "Black" -> 0.33)
// This does NOT include all unique values in the dataframe being transformed; 'Other' is left out
// which means that it should default to a reference probability of 0.00
val actual = actualCustomDistFeature2
val expected = ExpectedCustomDistFeature2
assert(actual(KLDIVERGENCE) === expected.KLDIVERGENCE)
assert(actual(JSDISTANCE) === expected.JSDISTANCE)
assert(actual(INFNORMDISTANCE) === expected.INFNORMDISTANCE)
assert(actual(TOTALVARIATIONDISTANCE) === expected.TOTALVARIATIONDISTANCE)
assert(actual(WASSERSTEINDISTANCE) === expected.WASSERSTEINDISTANCE)
assert(actual(CHISQUAREDTESTSTATISTIC) === expected.CHISQUAREDTESTSTATISTIC)
assert(actual(CHISQUAREDPVALUE) === expected.CHISQUAREDPVALUE)
}
}