feat: Causal DoubleMLEstimator (#8) (#1715)
* feat: Causal dmlestimator (#8)

Add package '' and implementation LinearDMLEstimator

* resolve comments and fix build pipeline failures

* add //scalastyle:off method.length for trainInternal

* fix DML Python API missing problem

* fix build pipeline failures

* rename LinearDMLEstimator to DoubleMLEstimator

* rename DML APIs

* forgot to change DoubleMLModel name in fuzzingTest

* use findUnusedColumnName for treatmentResidual and outcomeResidual columns

* address a few small comments

* remove HasExcludeFeatureCols and use HasInputCols

* fix style issue, Params.scala: File must end with newline character

* set DoubleMLEstimator transform as experimental

Co-authored-by: Jason Wang <>
dylanw-oss and memoryz authored Dec 19, 2022
1 parent 7ab63a1 commit d0a9f20
Showing 16 changed files with 914 additions and 79 deletions.
26 changes: 26 additions & 0 deletions core/src/main/python/synapse/ml/causal/
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (C) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See LICENSE in project root for information.

import sys

if sys.version >= "3":
basestring = str

from import _DoubleMLModel
from import inherit_doc
import numpy as np

class DoubleMLModel(_DoubleMLModel):
    """Python-side wrapper around the generated ``_DoubleMLModel``.

    Adds summary statistics computed from the raw treatment effects
    returned by ``getRawTreatmentEffects()``.
    """

    def getAvgTreatmentEffect(self):
        """Return the arithmetic mean of the raw treatment effects."""
        # Read the effects once instead of calling the getter twice.
        effects = self.getRawTreatmentEffects()
        return sum(effects) / len(effects)

    def getConfidenceInterval(self):
        """Return ``[lower, upper]`` percentile bounds of the raw treatment effects.

        The lower bound is the ``100 * (1 - confidenceLevel)`` percentile and the
        upper bound is the ``100 * confidenceLevel`` percentile, where
        ``confidenceLevel`` comes from ``getConfidenceLevel()``.
        """
        effects = self.getRawTreatmentEffects()
        confidenceLevel = self.getConfidenceLevel()
        ciLowerBound = np.percentile(effects, 100 * (1 - confidenceLevel))
        ciUpperBound = np.percentile(effects, confidenceLevel * 100)
        return [ciLowerBound, ciUpperBound]
Original file line number Diff line number Diff line change
@@ -0,0 +1,312 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.


import{TrainClassifier, TrainRegressor}
import{DatasetExtensions, SchemaConstants}
import jdk.jfr.Experimental
import org.apache.commons.math3.stat.descriptive.rank.Percentile
import org.apache.spark.annotation.DeveloperApi
import{ComplexParamsReadable, ComplexParamsWritable, Estimator, Model, Pipeline}
import{GeneralizedLinearRegression, Regressor}
import{DoubleArrayParam, ParamMap}
import{HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasWeightCol}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

import scala.concurrent.Future

/** Double ML estimator. The estimator follows a two-stage process,
* where a set of nuisance functions is estimated in the first stage in a cross-fitting manner
* and a final stage estimates the average treatment effect (ATE) model.
* Our goal is to estimate the constant marginal ATE Theta(X).
* In this estimator, the ATE is estimated by using the following estimating equation:
* .. math ::
* Y - \\E[Y | X, W] = \\Theta(X) \\cdot (T - \\E[T | X, W]) + \\epsilon
* Thus if we estimate the nuisance functions :math:`q(X, W) = \\E[Y | X, W]` and
* :math:`f(X, W)=\\E[T | X, W]` in the first stage, we can estimate the final stage ate for each
* treatment t, by running a regression, minimizing the residual on residual square loss
* (estimating Theta(X) is a final regression problem, regressing tilde{Y} on X and tilde{T}):
* .. math ::
* \\hat{\\theta} = \\arg\\min_{\\Theta}
* \\E_n\\left[ (\\tilde{Y} - \\Theta(X) \\cdot \\tilde{T})^2 \\right]
* Where
* :math:`\\tilde{Y}=Y - \\E[Y | X, W]` and :math:`\\tilde{T}=T-\\E[T | X, W]` denote the
* residual outcome and residual treatment.
* The nuisance function :math:`q` is a simple machine learning problem and
* user can use setOutcomeModel to set an arbitrary sparkML model
* that is internally used to solve this problem
* The problem of estimating the nuisance function :math:`f` is also a machine learning problem and
* user can use setTreatmentModel to set an arbitrary sparkML model
* that is internally used to solve this problem.
*/
//noinspection ScalaDocParserErrorInspection,ScalaDocUnclosedTagWithoutParser
class DoubleMLEstimator(override val uid: String)
extends Estimator[DoubleMLModel] with ComplexParamsWritable
with DoubleMLParams with SynapseMLLogging with Wrappable {


def this() = this(Identifiable.randomUID("DoubleMLEstimator"))

/** Fits the DoubleML model.
* @param dataset The input dataset to train.
* @return The trained DoubleML model, from which you can get the ATE and CI values.
*/
override def fit(dataset: Dataset[_]): DoubleMLModel = {
require(getMaxIter > 0, "maxIter should be larger than 0!")
if (get(weightCol).isDefined) {
getTreatmentModel match {
case w: HasWeightCol => w.set(w.weightCol, getWeightCol)
case _ => throw new Exception("""The selected treatment model does not support sample weight,
but the weightCol parameter was set for the DoubleMLEstimator.
Please select a treatment model that supports sample weight.""".stripMargin)
getOutcomeModel match {
case w: HasWeightCol => w.set(w.weightCol, getWeightCol)
case _ => throw new Exception("""The selected outcome model does not support sample weight,
but the weightCol parameter was set for the DoubleMLEstimator.
Please select a outcome model that supports sample weight.""".stripMargin)

// Bootstrap: sample with replacement to redraw the data and compute one treatment effect (TE) value per draw.
// Run this multiple times in parallel to collect a number of TE values, then
// use their average as the ATE value, and the low-end/high-end percentiles (e.g. 2.5% and 97.5%) as the CI.
// Create execution context based on $(parallelism)"Parallelism: $getParallelism")
val executionContext = getExecutionContextProxy

val ateFutures =(1 to getMaxIter) { index =>
Future[Option[Double]] {"Executing ATE calculation on iteration: $index")
// If the algorithm runs for only 1 iteration, use the dataset as-is (no bootstrap);
// otherwise, redraw a sample with replacement
val redrewDF = if (getMaxIter == 1) dataset else dataset.sample(withReplacement = true, fraction = 1)
val ate: Option[Double] =
try {
val totalTime = new StopWatch
val oneAte = totalTime.measure {
}"Completed ATE calculation on iteration $index and got ATE value: $oneAte, " +
s"time elapsed: ${totalTime.elapsed() / 6e10} minutes")
} catch {
case ex: Throwable =>
log.warn(s"ATE calculation got exception on iteration $index with the redrew sample data. " +
s"Exception details: $ex")

val ates = awaitFutures(ateFutures).flatten
if (ates.isEmpty) {
throw new Exception("ATE calculation failed on all iterations. Please check the log for details.")
val dmlModel = this.copyValues(new DoubleMLModel(uid)).setRawTreatmentEffects(ates.toArray)

//scalastyle:off method.length
private def trainInternal(dataset: Dataset[_]): Double = {

def getModel(model: Estimator[_ <: Model[_]], labelColName: String) = {
model match {
case classifier: ProbabilisticClassifier[_, _, _] => (
new TrainClassifier()
case regressor: Regressor[_, _, _] => (
new TrainRegressor()

// Collects the names of the prediction-related output columns exposed by the given estimator.
// Each of the three traits is checked independently; a column name is included only when the
// estimator mixes in the corresponding trait. The result order is always
// raw prediction, prediction, probability.
def getPredictedCols(model: Estimator[_ <: Model[_]]): Array[String] = {
// Raw prediction column name, if the estimator has one.
val rawPredictionCol = model match {
case rp: HasRawPredictionCol => Some(rp.getRawPredictionCol)
case _ => None

// Prediction column name, if the estimator has one.
val predictionCol = model match {
case p: HasPredictionCol => Some(p.getPredictionCol)
case _ => None

// Probability column name, if the estimator has one.
val probabilityCol = model match {
case pr: HasProbabilityCol => Some(pr.getProbabilityCol)
case _ => None

// Drop the None entries and keep only the column names that are present.
(rawPredictionCol :: predictionCol :: probabilityCol :: Nil).flatten.toArray

val (treatmentEstimator, treatmentResidualPredictionColName) = getModel(
val treatmentPredictionColsToDrop = getPredictedCols(getTreatmentModel)

val (outcomeEstimator, outcomeResidualPredictionColName) = getModel(
val outcomePredictionColsToDrop = getPredictedCols(getOutcomeModel)

val treatmentResidualCol = DatasetExtensions.findUnusedColumnName(SchemaConstants.TreatmentResidualColumn, dataset)
val outcomeResidualCol = DatasetExtensions.findUnusedColumnName(SchemaConstants.OutcomeResidualColumn, dataset)
val treatmentResidualVecCol = DatasetExtensions.findUnusedColumnName("treatmentResidualVec", dataset)

def calculateResiduals(train: Dataset[_], test: Dataset[_]): DataFrame = {
val treatmentModel = treatmentEstimator.setInputCols(train.columns.filterNot(_ == getOutcomeCol)).fit(train)
val outcomeModel = outcomeEstimator.setInputCols(train.columns.filterNot(_ == getTreatmentCol)).fit(train)

val treatmentResidual =
new ResidualTransformer()
val dropTreatmentPredictedColumns = new DropColumns().setCols(treatmentPredictionColsToDrop.toArray)
val outcomeResidual =
new ResidualTransformer()
val dropOutcomePredictedColumns = new DropColumns().setCols(outcomePredictionColsToDrop.toArray)
val treatmentResidualVA =
new VectorAssembler()
val pipeline = new Pipeline().setStages(Array(
treatmentModel, treatmentResidual, dropTreatmentPredictedColumns,
outcomeModel, outcomeResidual, dropOutcomePredictedColumns,

// Note, we perform these steps to get ATE:
// 1. Split sample, e.g. 50/50
// 2. Use the first split to fit the treatment model and the outcome model.
// 3. Use the two models to fit a residual model on the second split.
// 4. Cross-fit treatment and outcome models with the second split, residual model with the first split.
// 5. Average slopes from the two residual models.
val splits = dataset.randomSplit(getSampleSplitRatio)
val (train, test) = (splits(0).cache, splits(1).cache)
val residualsDF1 = calculateResiduals(train, test)
val residualsDF2 = calculateResiduals(test, train)

// Average slopes from the two residual models.
val regressor = new GeneralizedLinearRegression()

val coefficients = Array(residualsDF1, residualsDF2).map(
val ate = coefficients.sum / coefficients.length

Seq(train, test).foreach(_.unpersist)

override def copy(extra: ParamMap): Estimator[DoubleMLModel] = {

override def transformSchema(schema: StructType): StructType = {

object DoubleMLEstimator extends ComplexParamsReadable[DoubleMLEstimator] {

def validateTransformSchema(schema: StructType): StructType = {

/** Model produced by [[DoubleMLEstimator]]. */
class DoubleMLModel(val uid: String)
extends Model[DoubleMLModel] with DoubleMLParams with ComplexParamsWritable with Wrappable with SynapseMLLogging {

override protected lazy val pyInternalWrapper = true

def this() = this(Identifiable.randomUID("DoubleMLModel"))

val rawTreatmentEffects = new DoubleArrayParam(
"raw treatment effect results for all iterations")
def getRawTreatmentEffects: Array[Double] = $(rawTreatmentEffects)
def setRawTreatmentEffects(v: Array[Double]): this.type = set(rawTreatmentEffects, v)

def getAvgTreatmentEffect: Double = {
val finalAte = $(rawTreatmentEffects).sum / $(rawTreatmentEffects).length

/** Computes percentile bounds over the raw treatment effects.
* @return A two-element array: the percentile at `100 * (1 - confidenceLevel)` followed by
*         the percentile at `100 * confidenceLevel`, both taken over `rawTreatmentEffects`.
*/
def getConfidenceInterval: Array[Double] = {
// First element: percentile at 100 * (1 - confidenceLevel).
val ciLowerBound = percentile($(rawTreatmentEffects), 100 * (1 - getConfidenceLevel))
// Second element: percentile at 100 * confidenceLevel.
val ciUpperBound = percentile($(rawTreatmentEffects), getConfidenceLevel * 100)
Array(ciLowerBound, ciUpperBound)

private def percentile(values: Seq[Double], quantile: Double): Double = {
val sortedValues = values.sorted
val percentile = new Percentile()

override def copy(extra: ParamMap): DoubleMLModel = defaultCopy(extra)

/**
* :: Experimental ::
* DoubleMLEstimator transform function is still experimental, and its behavior could change in the future.
*/
override def transform(dataset: Dataset[_]): DataFrame = {

override def transformSchema(schema: StructType): StructType =

object DoubleMLModel extends ComplexParamsReadable[DoubleMLModel]

