Skip to content

Commit

Permalink
Fix bug microsoft#1869, DML .setFitIntercept should be set to true
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanw-oss committed Mar 16, 2023
1 parent 129abde commit bcbb648
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import org.apache.spark.ml.param.{DoubleArrayParam, ParamMap}
import org.apache.spark.ml.param.shared.{HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, HasWeightCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.{BooleanType, DoubleType, IntegerType, LongType, StructType}
import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, IntegerType, LongType, StructType}
import org.apache.commons.math3.stat.inference.TestUtils

import scala.concurrent.Future
Expand Down Expand Up @@ -77,21 +77,9 @@ class DoubleMLEstimator(override val uid: String)
logFit({
require(getMaxIter > 0, "maxIter should be larger than 0!")
val treatmentColType = dataset.schema(getTreatmentCol).dataType
require(treatmentColType == DoubleType || treatmentColType == LongType
|| treatmentColType == IntegerType || treatmentColType == BooleanType,
s"TreatmentCol must be of type DoubleType, LongType, IntegerType or BooleanType but got $treatmentColType")

if (treatmentColType != IntegerType && treatmentColType != BooleanType
&& getTreatmentType == TreatmentTypes.Binary) {
throw new Exception("TreatmentModel was set to use classifier " +
"but treatment column in dataset isn't integer or boolean type.")
}

if (treatmentColType != DoubleType && treatmentColType != LongType
&& getTreatmentType == TreatmentTypes.Continuous) {
throw new Exception("TreatmentModel was set to use regression " +
"but treatment column in dataset isn't continuous data type.")
}
validateColTypeWithModel(treatmentColType, getTreatmentCol, getTreatmentModel)
val outcomeColType = dataset.schema(getOutcomeCol).dataType
validateColTypeWithModel(outcomeColType, getOutcomeCol, getOutcomeModel)

if (get(weightCol).isDefined) {
getTreatmentModel match {
Expand Down Expand Up @@ -256,7 +244,7 @@ class DoubleMLEstimator(override val uid: String)
.setFeaturesCol(treatmentResidualVecCol)
.setFamily("gaussian")
.setLink("identity")
.setFitIntercept(false)
.setFitIntercept(true)

val coefficients = Array(residualsDF1, residualsDF2).map(regressor.fit).map(_.coefficients(0))
val ate = coefficients.sum / coefficients.length
Expand All @@ -273,6 +261,23 @@ class DoubleMLEstimator(override val uid: String)
/**
 * Derives the output schema for the given input schema by delegating to
 * `DoubleMLEstimator.validateTransformSchema`.
 *
 * @param schema schema of the input dataset
 * @return whatever schema the companion object's validation returns
 */
override def transformSchema(schema: StructType): StructType =
  DoubleMLEstimator.validateTransformSchema(schema)

/**
 * Checks that a column's data type is compatible with the kind of model that will be fit on it.
 *
 * The model kind is resolved through `getDoubleMLModelType`. Integer and boolean columns are
 * rejected when paired with a regression (continuous) model, double and long columns are
 * rejected when paired with a classification (binary) model, and every other column type is
 * rejected outright.
 *
 * @param colType data type of the column, as found in the dataset schema
 * @param colName name of the column; used only to build error messages
 * @param model   estimator that will be fit on the column
 * @throws Exception if the column type is unsupported or does not match the model kind
 */
protected def validateColTypeWithModel(colType: DataType, colName: String, model: Estimator[_]): Unit = {
  val modelType = getDoubleMLModelType(model)
  colType match {
    case IntegerType | BooleanType =>
      if (modelType == DoubleMLModelTypes.Continuous) {
        throw new Exception(s"column $colName in dataset is integer or boolean data type " +
          "but you set to use a regression model for it.")
      }
    case DoubleType | LongType =>
      if (modelType == DoubleMLModelTypes.Binary) {
        // The trailing space keeps "data type" and "but" from running together in the message.
        throw new Exception(s"column $colName in dataset is double or long data type " +
          "but you set to use a classification model for it.")
      }
    case _ =>
      throw new Exception(s"column $colName must be of type DoubleType, LongType, " +
        s"IntegerType or BooleanType but got $colType")
  }
}
}

object DoubleMLEstimator extends ComplexParamsReadable[DoubleMLEstimator] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,24 @@ trait DoubleMLParams extends Params
*/
def setSampleSplitRatio(value: Array[Double]): this.type = set(sampleSplitRatio, value)

private[causal] object TreatmentTypes extends Enumeration {
// Kinds of model distinguished by getDoubleMLModelType: Binary for probabilistic classifiers,
// Continuous for regressors.
// NOTE(review): the alias below is still named TreatmentType although the enumeration is also
// used for outcome models - confirm whether anything relies on that name before renaming it.
private[causal] object DoubleMLModelTypes extends Enumeration {
type TreatmentType = Value
val Binary, Continuous = Value
}

private[causal] def getTreatmentType: TreatmentTypes.Value = {
val treatmentType =
getTreatmentModel match {
case _: ProbabilisticClassifier[_, _, _] =>
TreatmentTypes.Binary
case _: Regressor[_, _, _] =>
TreatmentTypes.Continuous
}
treatmentType
/**
 * Maps a model instance to the kind of prediction task it performs.
 *
 * @param model the model to classify
 * @return `DoubleMLModelTypes.Binary` for a `ProbabilisticClassifier`,
 *         `DoubleMLModelTypes.Continuous` for a `Regressor`
 * @throws IllegalArgumentException if the model is neither of those (including `null`)
 */
private[causal] def getDoubleMLModelType(model: Any): DoubleMLModelTypes.Value = {
  model match {
    case _: ProbabilisticClassifier[_, _, _] =>
      DoubleMLModelTypes.Binary
    case _: Regressor[_, _, _] =>
      DoubleMLModelTypes.Continuous
    case _ =>
      // A null model also lands here; Option(...) avoids a NullPointerException from
      // getClass so the caller still receives the intended IllegalArgumentException.
      val modelClassName = Option(model).map(_.getClass.getName).getOrElse("null")
      throw new IllegalArgumentException(s"Invalid model type: $modelClassName")
  }
}



val confidenceLevel = new DoubleParam(
this,
"confidenceLevel",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class VerifyDoubleMLEstimator extends EstimatorFuzzing[DoubleMLEstimator] {
assert(ateLow < ateHigh && ateLow > -130 && ateHigh < 130)
}

test("Invalid treatment model will throw exception.") {
test("Mismatch treatment model and treatment column will throw exception.") {
assertThrows[Exception] {
val ldml = new DoubleMLEstimator()
.setTreatmentModel(new LinearRegression())
Expand All @@ -101,7 +101,21 @@ class VerifyDoubleMLEstimator extends EstimatorFuzzing[DoubleMLEstimator] {
.setOutcomeCol("col2")
.setMaxIter(20)

val ldmlModel = ldml.fit(mockDataset)
ldml.fit(mockDataset)
}
}

test("Mismatch outcome model and outcome column will throw exception.") {
  // A regression outcome model is paired with the "col1" outcome column; the whole pipeline,
  // from configuration through reading the average treatment effect, must raise an exception.
  // NOTE(review): presumably "col1" is an integer or boolean column in mockDataset - confirm.
  assertThrows[Exception] {
    new DoubleMLEstimator()
      .setTreatmentModel(new LogisticRegression())
      .setTreatmentCol(mockLabelColumn)
      .setOutcomeModel(new LinearRegression())
      .setOutcomeCol("col1")
      .setMaxIter(5)
      .fit(mockDataset)
      .getAvgTreatmentEffect
  }
}

Expand Down

0 comments on commit bcbb648

Please sign in to comment.