feat: Causal DMLEstimator (#8)

* Add package 'com.microsoft.azure.synapse.ml.causal' and implementation LinearDMLEstimator
* Resolve comments and fix build pipeline failures
* Add //scalastyle:off method.length for trainInternal
* Fix missing DML Python API
* Fix build pipeline failures
* Rename LinearDMLEstimator to DoubleMLEstimator
* Rename DML APIs
* Change the forgotten DoubleMLModel name in fuzzingTest
* Use findUnusedColumnName for the treatmentResidual and outcomeResidual columns
* Address a few small comments
* Remove HasExcludeFeatureCols and use HasInputCols
* Fix style issue (Params.scala: file must end with a newline character)
* Set DoubleMLEstimator transform as experimental

Co-authored-by: Jason Wang <jasowang@microsoft.com>
1 parent: 7ab63a1 · commit: d0a9f20
16 changed files with 914 additions and 79 deletions.
@@ -0,0 +1,26 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys

if sys.version >= "3":
    basestring = str

from synapse.ml.causal._DoubleMLModel import _DoubleMLModel
from pyspark.ml.common import inherit_doc
import numpy as np


@inherit_doc
class DoubleMLModel(_DoubleMLModel):
    def getAvgTreatmentEffect(self):
        return sum(self.getRawTreatmentEffects()) / len(self.getRawTreatmentEffects())

    def getConfidenceInterval(self):
        ciLowerBound = np.percentile(
            self.getRawTreatmentEffects(), 100 * (1 - self.getConfidenceLevel())
        )
        ciUpperBound = np.percentile(
            self.getRawTreatmentEffects(), self.getConfidenceLevel() * 100
        )
        return [ciLowerBound, ciUpperBound]
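The confidence interval above is a simple percentile interval over the bootstrapped treatment effects. A minimal sketch of the same arithmetic in plain NumPy, assuming a confidence level of 0.975 and a hypothetical array of raw ATE values:

```python
import numpy as np

# Hypothetical values standing in for getRawTreatmentEffects()
raw_effects = np.array([0.8, 1.1, 0.9, 1.3, 1.0, 0.7, 1.2])
confidence_level = 0.975  # assumed value of getConfidenceLevel()

# Mirrors DoubleMLModel.getConfidenceInterval: with 0.975 this takes the
# 2.5th and 97.5th percentiles of the bootstrap distribution.
ci_lower = np.percentile(raw_effects, 100 * (1 - confidence_level))
ci_upper = np.percentile(raw_effects, confidence_level * 100)
print([ci_lower, ci_upper])
```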
core/src/main/scala/com/microsoft/azure/synapse/ml/causal/DoubleMLEstimator.scala
312 additions, 0 deletions
@@ -0,0 +1,312 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.azure.synapse.ml.causal

import com.microsoft.azure.synapse.ml.codegen.Wrappable
import com.microsoft.azure.synapse.ml.train.{TrainClassifier, TrainRegressor}
import com.microsoft.azure.synapse.ml.core.schema.{DatasetExtensions, SchemaConstants}
import com.microsoft.azure.synapse.ml.core.utils.StopWatch
import com.microsoft.azure.synapse.ml.logging.SynapseMLLogging
import com.microsoft.azure.synapse.ml.stages.DropColumns
import jdk.jfr.Experimental
import org.apache.commons.math3.stat.descriptive.rank.Percentile
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model, Pipeline}
import org.apache.spark.ml.classification.ProbabilisticClassifier
import org.apache.spark.ml.regression.{GeneralizedLinearRegression, Regressor}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap}
import org.apache.spark.ml.param.shared.{HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

import scala.concurrent.Future
/** Double ML estimator. The estimator follows a two-stage process:
  * a set of nuisance functions is estimated in the first stage in a cross-fitting manner,
  * and the final stage estimates the average treatment effect (ATE) model.
  * Our goal is to estimate the constant marginal ATE Theta(X).
  *
  * In this estimator, the ATE is estimated by using the following estimating equation:
  * .. math ::
  *   Y - \\E[Y | X, W] = \\Theta(X) \\cdot (T - \\E[T | X, W]) + \\epsilon
  *
  * Thus if we estimate the nuisance functions :math:`q(X, W) = \\E[Y | X, W]` and
  * :math:`f(X, W) = \\E[T | X, W]` in the first stage, we can estimate the final-stage ATE for each
  * treatment t by running a regression that minimizes the residual-on-residual square loss;
  * that is, estimating Theta(X) is a final regression problem, regressing :math:`\\tilde{Y}` on X and :math:`\\tilde{T}`:
  *
  * .. math ::
  *   \\hat{\\theta} = \\arg\\min_{\\Theta} \\E_n\\left[ (\\tilde{Y} - \\Theta(X) \\cdot \\tilde{T})^2 \\right]
  *
  * where :math:`\\tilde{Y} = Y - \\E[Y | X, W]` and :math:`\\tilde{T} = T - \\E[T | X, W]` denote the
  * residual outcome and residual treatment.
  *
  * The nuisance function :math:`q` is a simple machine learning problem; the
  * user can use setOutcomeModel to set an arbitrary Spark ML model
  * that is internally used to solve this problem.
  *
  * The problem of estimating the nuisance function :math:`f` is also a machine learning problem; the
  * user can use setTreatmentModel to set an arbitrary Spark ML model
  * that is internally used to solve this problem.
  */
//noinspection ScalaDocParserErrorInspection,ScalaDocUnclosedTagWithoutParser
class DoubleMLEstimator(override val uid: String)
  extends Estimator[DoubleMLModel] with ComplexParamsWritable
    with DoubleMLParams with SynapseMLLogging with Wrappable {

  logClass()

  def this() = this(Identifiable.randomUID("DoubleMLEstimator"))

  /** Fits the DoubleML model.
    *
    * @param dataset The input dataset to train.
    * @return The trained DoubleML model, from which you can get the ATE and CI values.
    */
  override def fit(dataset: Dataset[_]): DoubleMLModel = {
    logFit({
      require(getMaxIter > 0, "maxIter should be larger than 0!")
      if (get(weightCol).isDefined) {
        getTreatmentModel match {
          case w: HasWeightCol => w.set(w.weightCol, getWeightCol)
          case _ => throw new Exception("""The selected treatment model does not support sample weight,
            but the weightCol parameter was set for the DoubleMLEstimator.
            Please select a treatment model that supports sample weight.""".stripMargin)
        }
        getOutcomeModel match {
          case w: HasWeightCol => w.set(w.weightCol, getWeightCol)
          case _ => throw new Exception("""The selected outcome model does not support sample weight,
            but the weightCol parameter was set for the DoubleMLEstimator.
            Please select an outcome model that supports sample weight.""".stripMargin)
        }
      }

      // Sample with replacement to redraw data and get a TE value.
      // Run it multiple times in parallel to get a number of TE values; use their
      // average as the ATE value, and the 2.5% low end and 97.5% high end as the CI value.
      // Create the execution context based on $(parallelism).
      log.info(s"Parallelism: $getParallelism")
      val executionContext = getExecutionContextProxy

      val ateFutures = (1 to getMaxIter).toArray.map { index =>
        Future[Option[Double]] {
          log.info(s"Executing ATE calculation on iteration: $index")
          // If the algorithm runs for only 1 iteration, do not bootstrap from the dataset;
          // otherwise, redraw a sample with replacement.
          val redrewDF = if (getMaxIter == 1) dataset else dataset.sample(withReplacement = true, fraction = 1)
          val ate: Option[Double] =
            try {
              val totalTime = new StopWatch
              val oneAte = totalTime.measure {
                trainInternal(redrewDF)
              }
              log.info(s"Completed ATE calculation on iteration $index and got ATE value: $oneAte, " +
                s"time elapsed: ${totalTime.elapsed() / 6e10} minutes")
              Some(oneAte)
            } catch {
              case ex: Throwable =>
                log.warn(s"ATE calculation got an exception on iteration $index with the redrawn sample data. " +
                  s"Exception details: $ex")
                None
            }
          ate
        }(executionContext)
      }

      val ates = awaitFutures(ateFutures).flatten
      if (ates.isEmpty) {
        throw new Exception("ATE calculation failed on all iterations. Please check the log for details.")
      }
      val dmlModel = this.copyValues(new DoubleMLModel(uid)).setRawTreatmentEffects(ates.toArray)
      dmlModel
    })
  }
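  // The Option[Double] futures above make individual bootstrap iterations fail-soft:
  // awaitFutures(ateFutures).flatten keeps only the successful ones, and the surviving
  // raw ATE values are stored on the model, which derives its point estimate (their mean)
  // and a percentile confidence interval from that bootstrap distribution.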
  //scalastyle:off method.length
  private def trainInternal(dataset: Dataset[_]): Double = {

    def getModel(model: Estimator[_ <: Model[_]], labelColName: String) = {
      model match {
        case classifier: ProbabilisticClassifier[_, _, _] => (
          new TrainClassifier()
            .setModel(model)
            .setLabelCol(labelColName),
          classifier.getProbabilityCol
        )
        case regressor: Regressor[_, _, _] => (
          new TrainRegressor()
            .setModel(model)
            .setLabelCol(labelColName),
          regressor.getPredictionCol
        )
      }
    }

    def getPredictedCols(model: Estimator[_ <: Model[_]]): Array[String] = {
      val rawPredictionCol = model match {
        case rp: HasRawPredictionCol => Some(rp.getRawPredictionCol)
        case _ => None
      }

      val predictionCol = model match {
        case p: HasPredictionCol => Some(p.getPredictionCol)
        case _ => None
      }

      val probabilityCol = model match {
        case pr: HasProbabilityCol => Some(pr.getProbabilityCol)
        case _ => None
      }

      (rawPredictionCol :: predictionCol :: probabilityCol :: Nil).flatten.toArray
    }
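    // Together these helpers wrap the user-supplied nuisance model in TrainClassifier or
    // TrainRegressor (which featurize the columns selected via setInputCols) and record
    // every prediction column the wrapped model can emit, so those columns can be dropped
    // once the residuals are computed and never leak into the second-stage regression.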
    val (treatmentEstimator, treatmentResidualPredictionColName) = getModel(
      getTreatmentModel.copy(getTreatmentModel.extractParamMap()),
      getTreatmentCol
    )
    val treatmentPredictionColsToDrop = getPredictedCols(getTreatmentModel)

    val (outcomeEstimator, outcomeResidualPredictionColName) = getModel(
      getOutcomeModel.copy(getOutcomeModel.extractParamMap()),
      getOutcomeCol
    )
    val outcomePredictionColsToDrop = getPredictedCols(getOutcomeModel)

    val treatmentResidualCol = DatasetExtensions.findUnusedColumnName(SchemaConstants.TreatmentResidualColumn, dataset)
    val outcomeResidualCol = DatasetExtensions.findUnusedColumnName(SchemaConstants.OutcomeResidualColumn, dataset)
    val treatmentResidualVecCol = DatasetExtensions.findUnusedColumnName("treatmentResidualVec", dataset)

    def calculateResiduals(train: Dataset[_], test: Dataset[_]): DataFrame = {
      val treatmentModel = treatmentEstimator.setInputCols(train.columns.filterNot(_ == getOutcomeCol)).fit(train)
      val outcomeModel = outcomeEstimator.setInputCols(train.columns.filterNot(_ == getTreatmentCol)).fit(train)

      val treatmentResidual =
        new ResidualTransformer()
          .setObservedCol(getTreatmentCol)
          .setPredictedCol(treatmentResidualPredictionColName)
          .setOutputCol(treatmentResidualCol)
      val dropTreatmentPredictedColumns = new DropColumns().setCols(treatmentPredictionColsToDrop.toArray)
      val outcomeResidual =
        new ResidualTransformer()
          .setObservedCol(getOutcomeCol)
          .setPredictedCol(outcomeResidualPredictionColName)
          .setOutputCol(outcomeResidualCol)
      val dropOutcomePredictedColumns = new DropColumns().setCols(outcomePredictionColsToDrop.toArray)
      val treatmentResidualVA =
        new VectorAssembler()
          .setInputCols(Array(treatmentResidualCol))
          .setOutputCol(treatmentResidualVecCol)
          .setHandleInvalid("skip")
      val pipeline = new Pipeline().setStages(Array(
        treatmentModel, treatmentResidual, dropTreatmentPredictedColumns,
        outcomeModel, outcomeResidual, dropOutcomePredictedColumns,
        treatmentResidualVA))

      pipeline.fit(test).transform(test)
    }

    // Note: we perform these steps to get the ATE.
    /*
      1. Split the sample, e.g. 50/50.
      2. Use the first split to fit the treatment model and the outcome model.
      3. Use the two models to fit a residual model on the second split.
      4. Cross-fit: treatment and outcome models with the second split, residual model with the first split.
      5. Average the slopes from the two residual models.
    */
    val splits = dataset.randomSplit(getSampleSplitRatio)
    val (train, test) = (splits(0).cache, splits(1).cache)
    val residualsDF1 = calculateResiduals(train, test)
    val residualsDF2 = calculateResiduals(test, train)

    // Average the slopes from the two residual models.
    val regressor = new GeneralizedLinearRegression()
      .setLabelCol(outcomeResidualCol)
      .setFeaturesCol(treatmentResidualVecCol)
      .setFamily("gaussian")
      .setLink("identity")
      .setFitIntercept(false)

    val coefficients = Array(residualsDF1, residualsDF2).map(regressor.fit).map(_.coefficients(0))
    val ate = coefficients.sum / coefficients.length

    Seq(train, test).foreach(_.unpersist)
    ate
  }
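  // With a single feature (the treatment residual vector) and no intercept,
  // coefficients(0) of each GLM fit is the residual-on-residual slope, i.e. the
  // estimated constant marginal Theta for that half of the data; averaging the
  // two halves completes the cross-fitting step.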
  override def copy(extra: ParamMap): Estimator[DoubleMLModel] = {
    defaultCopy(extra)
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType = {
    DoubleMLEstimator.validateTransformSchema(schema)
  }
}

object DoubleMLEstimator extends ComplexParamsReadable[DoubleMLEstimator] {

  def validateTransformSchema(schema: StructType): StructType = {
    StructType(schema.fields)
  }
}

/** Model produced by [[DoubleMLEstimator]]. */
class DoubleMLModel(val uid: String)
  extends Model[DoubleMLModel] with DoubleMLParams with ComplexParamsWritable with Wrappable with SynapseMLLogging {
  logClass()

  override protected lazy val pyInternalWrapper = true

  def this() = this(Identifiable.randomUID("DoubleMLModel"))

  val rawTreatmentEffects = new DoubleArrayParam(
    this,
    "rawTreatmentEffects",
    "raw treatment effect results for all iterations")

  def getRawTreatmentEffects: Array[Double] = $(rawTreatmentEffects)

  def setRawTreatmentEffects(v: Array[Double]): this.type = set(rawTreatmentEffects, v)

  def getAvgTreatmentEffect: Double = {
    val finalAte = $(rawTreatmentEffects).sum / $(rawTreatmentEffects).length
    finalAte
  }

  def getConfidenceInterval: Array[Double] = {
    val ciLowerBound = percentile($(rawTreatmentEffects), 100 * (1 - getConfidenceLevel))
    val ciUpperBound = percentile($(rawTreatmentEffects), getConfidenceLevel * 100)
    Array(ciLowerBound, ciUpperBound)
  }

  private def percentile(values: Seq[Double], quantile: Double): Double = {
    val sortedValues = values.sorted
    val percentile = new Percentile()
    percentile.setData(sortedValues.toArray)
    percentile.evaluate(quantile)
  }
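  // Note: commons-math Percentile.evaluate expects p on a (0, 100] scale, which is
  // why getConfidenceInterval passes 100 * (1 - level) and level * 100 rather than
  // raw fractions.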
  override def copy(extra: ParamMap): DoubleMLModel = defaultCopy(extra)

  /**
    * :: Experimental ::
    * The DoubleMLEstimator transform function is still experimental, and its behavior could change in the future.
    */
  @Experimental
  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      dataset.toDF()
    })
  }

  @DeveloperApi
  override def transformSchema(schema: StructType): StructType =
    StructType(schema.fields)
}

object DoubleMLModel extends ComplexParamsReadable[DoubleMLModel]
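Taken together, the new estimator is driven like any other Spark ML stage. A minimal PySpark sketch under stated assumptions: `df`, the column names, and the iteration count are hypothetical, and any Spark ML classifier or regressor can serve as the nuisance models, as the ScalaDoc above describes:

```python
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.regression import LinearRegression
from synapse.ml.causal import DoubleMLEstimator

# df is a hypothetical DataFrame with a binary "Treatment" column,
# a numeric "Outcome" column, and additional confounder/feature columns.
dml = (DoubleMLEstimator()
       .setTreatmentModel(LogisticRegression())  # nuisance model for E[T | X, W]
       .setTreatmentCol("Treatment")
       .setOutcomeModel(LinearRegression())      # nuisance model for E[Y | X, W]
       .setOutcomeCol("Outcome")
       .setMaxIter(20))                          # bootstrap iterations

model = dml.fit(df)
print(model.getAvgTreatmentEffect())  # mean of the raw bootstrap ATEs
print(model.getConfidenceInterval())  # percentile-based [lower, upper]
```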